|
| 1 | +""" |
| 2 | +Unit tests for the SDK exception hierarchy and transient-error classification. |
| 3 | +
|
| 4 | +`APIFailure.is_transient_error()` tells consumers whether retrying the same request may |
| 5 | +succeed (gateway/connection-level failures: HTTP 408/502/503/504, dropped or reset |
| 6 | +connections, client-side timeouts) or whether the failure is deterministic (400/401/403/ |
| 7 | +404/429 and similar). Classification is based on the `status_code` recorded at raise time |
| 8 | +inside `API.do_request`, so these tests cover both the exception classes themselves and |
| 9 | +the status codes `do_request` attaches when raising them. |
| 10 | +
|
| 11 | +Run with: python -m pytest tests/unit/ -v |
| 12 | +""" |
| 13 | + |
| 14 | +import unittest |
| 15 | +from unittest.mock import MagicMock, patch |
| 16 | + |
| 17 | +import requests |
| 18 | + |
| 19 | +from socketdev.core.api import API |
| 20 | +from socketdev.exceptions import ( |
| 21 | + APIAccessDenied, |
| 22 | + APIBadGateway, |
| 23 | + APIConnectionError, |
| 24 | + APIFailure, |
| 25 | + APIInsufficientPermissions, |
| 26 | + APIInsufficientQuota, |
| 27 | + APIOrganizationNotAllowed, |
| 28 | + APIResourceNotFound, |
| 29 | + APITimeout, |
| 30 | +) |
| 31 | + |
| 32 | + |
| 33 | +class TestIsTransientError(unittest.TestCase): |
| 34 | + """Classification of exceptions constructed directly.""" |
| 35 | + |
| 36 | + def test_transient_statuses_on_catch_all_failure(self): |
| 37 | + for status in (408, 502, 503, 504): |
| 38 | + self.assertTrue(APIFailure("boom", status_code=status).is_transient_error()) |
| 39 | + |
| 40 | + def test_deterministic_statuses_on_catch_all_failure(self): |
| 41 | + for status in (400, 401, 403, 404, 422, 429, 500): |
| 42 | + self.assertFalse(APIFailure("boom", status_code=status).is_transient_error()) |
| 43 | + |
| 44 | + def test_no_status_code_is_not_transient(self): |
| 45 | + # The wrapped-unexpected-error case: do_request raises a bare APIFailure(). |
| 46 | + self.assertFalse(APIFailure().is_transient_error()) |
| 47 | + self.assertFalse(APIFailure("boom").is_transient_error()) |
| 48 | + |
| 49 | + def test_connection_level_classes_are_transient(self): |
| 50 | + self.assertTrue(APITimeout().is_transient_error()) |
| 51 | + self.assertTrue(APIConnectionError().is_transient_error()) |
| 52 | + self.assertTrue(APIBadGateway().is_transient_error()) |
| 53 | + |
| 54 | + def test_bad_gateway_carries_502_by_default(self): |
| 55 | + self.assertEqual(APIBadGateway().status_code, 502) |
| 56 | + |
| 57 | + def test_dedicated_4xx_classes_are_not_transient(self): |
| 58 | + self.assertFalse(APIAccessDenied("denied", status_code=401).is_transient_error()) |
| 59 | + self.assertFalse(APIInsufficientPermissions(status_code=403).is_transient_error()) |
| 60 | + self.assertFalse(APIOrganizationNotAllowed(status_code=403).is_transient_error()) |
| 61 | + self.assertFalse(APIInsufficientQuota(status_code=429).is_transient_error()) |
| 62 | + self.assertFalse(APIResourceNotFound(status_code=404).is_transient_error()) |
| 63 | + |
| 64 | + def test_subclass_with_transient_status_follows_the_status(self): |
| 65 | + # Classification is by recorded status, not class identity: if a transient status |
| 66 | + # ever gains a dedicated subclass, is_transient_error() keeps working unchanged. |
| 67 | + class APIServiceUnavailable(APIFailure): |
| 68 | + pass |
| 69 | + |
| 70 | + self.assertTrue(APIServiceUnavailable(status_code=503).is_transient_error()) |
| 71 | + |
| 72 | + def test_message_text_does_not_affect_classification(self): |
| 73 | + self.assertFalse( |
| 74 | + APIFailure("original_status_code:503 lookalike").is_transient_error() |
| 75 | + ) |
| 76 | + |
| 77 | + def test_single_message_arg_is_preserved(self): |
| 78 | + error = APIFailure("something broke", status_code=503) |
| 79 | + self.assertEqual(str(error), "something broke") |
| 80 | + |
| 81 | + |
| 82 | +def _mock_response(status_code, json_data=None, headers=None, text=""): |
| 83 | + response = MagicMock() |
| 84 | + response.status_code = status_code |
| 85 | + response.headers = headers if headers is not None else {} |
| 86 | + response.text = text |
| 87 | + if json_data is None: |
| 88 | + response.json.side_effect = ValueError("no json") |
| 89 | + else: |
| 90 | + response.json.return_value = json_data |
| 91 | + return response |
| 92 | + |
| 93 | + |
| 94 | +class TestDoRequestStatusCodes(unittest.TestCase): |
| 95 | + """do_request attaches the HTTP status to the exceptions it raises.""" |
| 96 | + |
| 97 | + def setUp(self): |
| 98 | + self.api = API() |
| 99 | + self.api.encode_key("test-token") |
| 100 | + |
| 101 | + def _do_request_raising(self, expected_class, response=None, side_effect=None): |
| 102 | + with patch("socketdev.core.api.requests.request") as mock_request: |
| 103 | + if side_effect is not None: |
| 104 | + mock_request.side_effect = side_effect |
| 105 | + else: |
| 106 | + mock_request.return_value = response |
| 107 | + with self.assertRaises(expected_class) as ctx: |
| 108 | + self.api.do_request("orgs/test/full-scans", method="POST") |
| 109 | + return ctx.exception |
| 110 | + |
| 111 | + def test_401_access_denied_is_not_transient(self): |
| 112 | + error = self._do_request_raising(APIAccessDenied, _mock_response(401)) |
| 113 | + self.assertEqual(error.status_code, 401) |
| 114 | + self.assertFalse(error.is_transient_error()) |
| 115 | + |
| 116 | + def test_403_insufficient_permissions_is_not_transient(self): |
| 117 | + response = _mock_response( |
| 118 | + 403, |
| 119 | + json_data={"error": {"message": "Insufficient permissions for API method"}}, |
| 120 | + ) |
| 121 | + error = self._do_request_raising(APIInsufficientPermissions, response) |
| 122 | + self.assertEqual(error.status_code, 403) |
| 123 | + self.assertFalse(error.is_transient_error()) |
| 124 | + |
| 125 | + def test_404_not_found_is_not_transient(self): |
| 126 | + error = self._do_request_raising(APIResourceNotFound, _mock_response(404)) |
| 127 | + self.assertEqual(error.status_code, 404) |
| 128 | + self.assertFalse(error.is_transient_error()) |
| 129 | + |
| 130 | + def test_429_quota_is_not_transient(self): |
| 131 | + error = self._do_request_raising(APIInsufficientQuota, _mock_response(429)) |
| 132 | + self.assertEqual(error.status_code, 429) |
| 133 | + self.assertFalse(error.is_transient_error()) |
| 134 | + |
| 135 | + def test_502_bad_gateway_is_transient(self): |
| 136 | + error = self._do_request_raising(APIBadGateway, _mock_response(502)) |
| 137 | + self.assertEqual(error.status_code, 502) |
| 138 | + self.assertTrue(error.is_transient_error()) |
| 139 | + |
| 140 | + def test_catch_all_transient_statuses(self): |
| 141 | + for status in (408, 503, 504): |
| 142 | + error = self._do_request_raising(APIFailure, _mock_response(status)) |
| 143 | + self.assertIs(type(error), APIFailure) |
| 144 | + self.assertEqual(error.status_code, status) |
| 145 | + self.assertTrue(error.is_transient_error()) |
| 146 | + |
| 147 | + def test_catch_all_deterministic_statuses(self): |
| 148 | + for status in (400, 500): |
| 149 | + error = self._do_request_raising(APIFailure, _mock_response(status)) |
| 150 | + self.assertIs(type(error), APIFailure) |
| 151 | + self.assertEqual(error.status_code, status) |
| 152 | + self.assertFalse(error.is_transient_error()) |
| 153 | + |
| 154 | + def test_timeout_is_transient(self): |
| 155 | + error = self._do_request_raising( |
| 156 | + APITimeout, side_effect=requests.exceptions.Timeout("timed out") |
| 157 | + ) |
| 158 | + self.assertIsNone(error.status_code) |
| 159 | + self.assertTrue(error.is_transient_error()) |
| 160 | + |
| 161 | + def test_connection_error_is_transient(self): |
| 162 | + error = self._do_request_raising( |
| 163 | + APIConnectionError, |
| 164 | + side_effect=requests.exceptions.ConnectionError("reset"), |
| 165 | + ) |
| 166 | + self.assertIsNone(error.status_code) |
| 167 | + self.assertTrue(error.is_transient_error()) |
| 168 | + |
| 169 | + def test_unexpected_error_wrapped_without_status_is_not_transient(self): |
| 170 | + error = self._do_request_raising(APIFailure, side_effect=RuntimeError("boom")) |
| 171 | + self.assertIsNone(error.status_code) |
| 172 | + self.assertFalse(error.is_transient_error()) |
| 173 | + |
| 174 | + |
| 175 | +if __name__ == "__main__": |
| 176 | + unittest.main() |
0 commit comments