Skip to content

Commit c0ac850

Browse files
committed
Add transient-error classification to APIFailure
API.do_request now records the HTTP status code on every exception it raises (status_code attribute), and APIFailure gains is_transient_error(): True for gateway/connection-level failures (HTTP 408/502/503/504, dropped or reset connections, client-side timeouts) where retrying the same request may succeed, False for deterministic errors (400/401/403/404/429, wrapped unexpected errors). Classification is based on the recorded status code rather than exception class identity or message text, so it stays correct if a status code gains a dedicated subclass later. Motivated by SocketDev/socket-python-cli#232: the CLI retries transient full-scan upload failures and previously had to parse the status code out of catch-all APIFailure message text.
1 parent 836936c commit c0ac850

6 files changed

Lines changed: 269 additions & 60 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "socketdev"
7-
version = "3.2.1"
7+
version = "3.3.0"
88
requires-python = ">= 3.9"
99
dependencies = [
1010
'requests',

socketdev/core/api.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,26 +76,26 @@ def format_headers(headers_dict):
7676
path_str = f"\nPath: {url}"
7777

7878
if response.status_code == 401:
79-
raise APIAccessDenied(f"Unauthorized{path_str}{headers_str}")
79+
raise APIAccessDenied(f"Unauthorized{path_str}{headers_str}", status_code=401)
8080
if response.status_code == 403:
8181
try:
8282
error_message = response.json().get("error", {}).get("message", "")
8383
if "Insufficient permissions for API method" in error_message:
8484
log.error(f"{error_message}{path_str}{headers_str}")
85-
raise APIInsufficientPermissions()
85+
raise APIInsufficientPermissions(status_code=403)
8686
elif "Organization not allowed" in error_message:
8787
log.error(f"{error_message}{path_str}{headers_str}")
88-
raise APIOrganizationNotAllowed()
88+
raise APIOrganizationNotAllowed(status_code=403)
8989
elif "Insufficient max quota" in error_message:
9090
log.error(f"{error_message}{path_str}{headers_str}")
91-
raise APIInsufficientQuota()
91+
raise APIInsufficientQuota(status_code=403)
9292
else:
93-
raise APIAccessDenied(f"{error_message or 'Access denied'}{path_str}{headers_str}")
93+
raise APIAccessDenied(f"{error_message or 'Access denied'}{path_str}{headers_str}", status_code=403)
9494
except ValueError:
95-
raise APIAccessDenied(f"Access denied{path_str}{headers_str}")
95+
raise APIAccessDenied(f"Access denied{path_str}{headers_str}", status_code=403)
9696
if response.status_code == 404:
9797
log.error(f"Path not found {path}{path_str}{headers_str}")
98-
raise APIResourceNotFound()
98+
raise APIResourceNotFound(status_code=404)
9999
if response.status_code == 429:
100100
retry_after = response.headers.get("retry-after")
101101
if retry_after:
@@ -109,10 +109,10 @@ def format_headers(headers_dict):
109109
else:
110110
time_msg = ""
111111
log.error(f"Insufficient quota for API route.{time_msg}{path_str}{headers_str}")
112-
raise APIInsufficientQuota()
112+
raise APIInsufficientQuota(status_code=429)
113113
if response.status_code == 502:
114114
log.error(f"Upstream server error{path_str}{headers_str}")
115-
raise APIBadGateway()
115+
raise APIBadGateway(status_code=502)
116116
if response.status_code >= 400:
117117
try:
118118
error_json = response.json()
@@ -124,7 +124,7 @@ def format_headers(headers_dict):
124124
f"Error message: {error_message}"
125125
)
126126
log.error(error)
127-
raise APIFailure(error)
127+
raise APIFailure(error, status_code=response.status_code)
128128

129129
return response
130130
except Timeout:

socketdev/exceptions.py

Lines changed: 80 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,80 @@
1-
class APIFailure(Exception):
2-
"""Base exception for all Socket API errors"""
3-
pass
4-
5-
6-
class APIKeyMissing(APIFailure):
7-
"""Raised when the api key is not passed and the headers are empty"""
8-
9-
10-
class APIAccessDenied(APIFailure):
11-
"""Raised when access is denied to the API"""
12-
pass
13-
14-
15-
class APIInsufficientPermissions(APIFailure):
16-
"""Raised when the API token doesn't have required permissions"""
17-
pass
18-
19-
20-
class APIOrganizationNotAllowed(APIFailure):
21-
"""Raised when organization doesn't have access to the feature"""
22-
pass
23-
24-
25-
class APIInsufficientQuota(APIFailure):
26-
"""Raised when access is denied to the API due to quota limits"""
27-
pass
28-
29-
30-
class APIResourceNotFound(APIFailure):
31-
"""Raised when the requested resource is not found"""
32-
pass
33-
34-
35-
class APITimeout(APIFailure):
36-
"""Raised when a request times out"""
37-
pass
38-
39-
40-
class APIConnectionError(APIFailure):
41-
"""Raised when there's a connection error"""
42-
pass
43-
44-
45-
class APIBadGateway(APIFailure):
46-
"""Raised when the upstream server returns a 502 Bad Gateway error"""
47-
pass
1+
from typing import Optional
2+
3+
# HTTP statuses classified as transient by APIFailure.is_transient_error(): gateway /
4+
# availability failures where the request was dropped before the application produced a
5+
# definitive response, so retrying the same request may succeed (408 Request Timeout,
6+
# 502 Bad Gateway, 503 Service Unavailable, 504 Gateway Timeout).
7+
TRANSIENT_HTTP_STATUS_CODES = frozenset({408, 502, 503, 504})
8+
9+
10+
class APIFailure(Exception):
11+
"""Base exception for all Socket API errors"""
12+
13+
def __init__(self, *args, status_code: Optional[int] = None):
14+
super().__init__(*args)
15+
self.status_code = status_code
16+
17+
def is_transient_error(self) -> bool:
18+
"""Whether this failure is transient, i.e. retrying the same request may succeed.
19+
20+
Transient failures happen at the gateway/connection level - HTTP 408/502/503/504,
21+
dropped or reset connections, and client-side timeouts - before the server produced
22+
a definitive response. Deterministic errors (e.g. 400/401/403/404/429) are not
23+
transient: retrying the same request fails the same way. Classification is based on
24+
the HTTP status code recorded when the exception was raised (or overridden by
25+
subclasses without an HTTP status, like timeouts), so it stays correct even if a
26+
status code gains a dedicated exception subclass later.
27+
"""
28+
return self.status_code in TRANSIENT_HTTP_STATUS_CODES
29+
30+
31+
class APIKeyMissing(APIFailure):
32+
"""Raised when the api key is not passed and the headers are empty"""
33+
34+
35+
class APIAccessDenied(APIFailure):
36+
"""Raised when access is denied to the API"""
37+
pass
38+
39+
40+
class APIInsufficientPermissions(APIFailure):
41+
"""Raised when the API token doesn't have required permissions"""
42+
pass
43+
44+
45+
class APIOrganizationNotAllowed(APIFailure):
46+
"""Raised when organization doesn't have access to the feature"""
47+
pass
48+
49+
50+
class APIInsufficientQuota(APIFailure):
51+
"""Raised when access is denied to the API due to quota limits"""
52+
pass
53+
54+
55+
class APIResourceNotFound(APIFailure):
56+
"""Raised when the requested resource is not found"""
57+
pass
58+
59+
60+
class APITimeout(APIFailure):
61+
"""Raised when a request times out"""
62+
63+
def is_transient_error(self) -> bool:
64+
# No HTTP status: the request timed out client-side, so a retry may succeed.
65+
return True
66+
67+
68+
class APIConnectionError(APIFailure):
69+
"""Raised when there's a connection error"""
70+
71+
def is_transient_error(self) -> bool:
72+
# No HTTP status: the connection was dropped/reset mid-request, so a retry may succeed.
73+
return True
74+
75+
76+
class APIBadGateway(APIFailure):
77+
"""Raised when the upstream server returns a 502 Bad Gateway error"""
78+
79+
def __init__(self, *args, status_code: Optional[int] = 502):
80+
super().__init__(*args, status_code=status_code)

socketdev/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.2.1"
1+
__version__ = "3.3.0"

tests/unit/test_exceptions.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
"""
2+
Unit tests for the SDK exception hierarchy and transient-error classification.
3+
4+
`APIFailure.is_transient_error()` tells consumers whether retrying the same request may
5+
succeed (gateway/connection-level failures: HTTP 408/502/503/504, dropped or reset
6+
connections, client-side timeouts) or whether the failure is deterministic (400/401/403/
7+
404/429 and similar). Classification is based on the `status_code` recorded at raise time
8+
inside `API.do_request`, so these tests cover both the exception classes themselves and
9+
the status codes `do_request` attaches when raising them.
10+
11+
Run with: python -m pytest tests/unit/ -v
12+
"""
13+
14+
import unittest
15+
from unittest.mock import MagicMock, patch
16+
17+
import requests
18+
19+
from socketdev.core.api import API
20+
from socketdev.exceptions import (
21+
APIAccessDenied,
22+
APIBadGateway,
23+
APIConnectionError,
24+
APIFailure,
25+
APIInsufficientPermissions,
26+
APIInsufficientQuota,
27+
APIOrganizationNotAllowed,
28+
APIResourceNotFound,
29+
APITimeout,
30+
)
31+
32+
33+
class TestIsTransientError(unittest.TestCase):
34+
"""Classification of exceptions constructed directly."""
35+
36+
def test_transient_statuses_on_catch_all_failure(self):
37+
for status in (408, 502, 503, 504):
38+
self.assertTrue(APIFailure("boom", status_code=status).is_transient_error())
39+
40+
def test_deterministic_statuses_on_catch_all_failure(self):
41+
for status in (400, 401, 403, 404, 422, 429, 500):
42+
self.assertFalse(APIFailure("boom", status_code=status).is_transient_error())
43+
44+
def test_no_status_code_is_not_transient(self):
45+
# The wrapped-unexpected-error case: do_request raises a bare APIFailure().
46+
self.assertFalse(APIFailure().is_transient_error())
47+
self.assertFalse(APIFailure("boom").is_transient_error())
48+
49+
def test_connection_level_classes_are_transient(self):
50+
self.assertTrue(APITimeout().is_transient_error())
51+
self.assertTrue(APIConnectionError().is_transient_error())
52+
self.assertTrue(APIBadGateway().is_transient_error())
53+
54+
def test_bad_gateway_carries_502_by_default(self):
55+
self.assertEqual(APIBadGateway().status_code, 502)
56+
57+
def test_dedicated_4xx_classes_are_not_transient(self):
58+
self.assertFalse(APIAccessDenied("denied", status_code=401).is_transient_error())
59+
self.assertFalse(APIInsufficientPermissions(status_code=403).is_transient_error())
60+
self.assertFalse(APIOrganizationNotAllowed(status_code=403).is_transient_error())
61+
self.assertFalse(APIInsufficientQuota(status_code=429).is_transient_error())
62+
self.assertFalse(APIResourceNotFound(status_code=404).is_transient_error())
63+
64+
def test_subclass_with_transient_status_follows_the_status(self):
65+
# Classification is by recorded status, not class identity: if a transient status
66+
# ever gains a dedicated subclass, is_transient_error() keeps working unchanged.
67+
class APIServiceUnavailable(APIFailure):
68+
pass
69+
70+
self.assertTrue(APIServiceUnavailable(status_code=503).is_transient_error())
71+
72+
def test_message_text_does_not_affect_classification(self):
73+
self.assertFalse(
74+
APIFailure("original_status_code:503 lookalike").is_transient_error()
75+
)
76+
77+
def test_single_message_arg_is_preserved(self):
78+
error = APIFailure("something broke", status_code=503)
79+
self.assertEqual(str(error), "something broke")
80+
81+
82+
def _mock_response(status_code, json_data=None, headers=None, text=""):
83+
response = MagicMock()
84+
response.status_code = status_code
85+
response.headers = headers if headers is not None else {}
86+
response.text = text
87+
if json_data is None:
88+
response.json.side_effect = ValueError("no json")
89+
else:
90+
response.json.return_value = json_data
91+
return response
92+
93+
94+
class TestDoRequestStatusCodes(unittest.TestCase):
95+
"""do_request attaches the HTTP status to the exceptions it raises."""
96+
97+
def setUp(self):
98+
self.api = API()
99+
self.api.encode_key("test-token")
100+
101+
def _do_request_raising(self, expected_class, response=None, side_effect=None):
102+
with patch("socketdev.core.api.requests.request") as mock_request:
103+
if side_effect is not None:
104+
mock_request.side_effect = side_effect
105+
else:
106+
mock_request.return_value = response
107+
with self.assertRaises(expected_class) as ctx:
108+
self.api.do_request("orgs/test/full-scans", method="POST")
109+
return ctx.exception
110+
111+
def test_401_access_denied_is_not_transient(self):
112+
error = self._do_request_raising(APIAccessDenied, _mock_response(401))
113+
self.assertEqual(error.status_code, 401)
114+
self.assertFalse(error.is_transient_error())
115+
116+
def test_403_insufficient_permissions_is_not_transient(self):
117+
response = _mock_response(
118+
403,
119+
json_data={"error": {"message": "Insufficient permissions for API method"}},
120+
)
121+
error = self._do_request_raising(APIInsufficientPermissions, response)
122+
self.assertEqual(error.status_code, 403)
123+
self.assertFalse(error.is_transient_error())
124+
125+
def test_404_not_found_is_not_transient(self):
126+
error = self._do_request_raising(APIResourceNotFound, _mock_response(404))
127+
self.assertEqual(error.status_code, 404)
128+
self.assertFalse(error.is_transient_error())
129+
130+
def test_429_quota_is_not_transient(self):
131+
error = self._do_request_raising(APIInsufficientQuota, _mock_response(429))
132+
self.assertEqual(error.status_code, 429)
133+
self.assertFalse(error.is_transient_error())
134+
135+
def test_502_bad_gateway_is_transient(self):
136+
error = self._do_request_raising(APIBadGateway, _mock_response(502))
137+
self.assertEqual(error.status_code, 502)
138+
self.assertTrue(error.is_transient_error())
139+
140+
def test_catch_all_transient_statuses(self):
141+
for status in (408, 503, 504):
142+
error = self._do_request_raising(APIFailure, _mock_response(status))
143+
self.assertIs(type(error), APIFailure)
144+
self.assertEqual(error.status_code, status)
145+
self.assertTrue(error.is_transient_error())
146+
147+
def test_catch_all_deterministic_statuses(self):
148+
for status in (400, 500):
149+
error = self._do_request_raising(APIFailure, _mock_response(status))
150+
self.assertIs(type(error), APIFailure)
151+
self.assertEqual(error.status_code, status)
152+
self.assertFalse(error.is_transient_error())
153+
154+
def test_timeout_is_transient(self):
155+
error = self._do_request_raising(
156+
APITimeout, side_effect=requests.exceptions.Timeout("timed out")
157+
)
158+
self.assertIsNone(error.status_code)
159+
self.assertTrue(error.is_transient_error())
160+
161+
def test_connection_error_is_transient(self):
162+
error = self._do_request_raising(
163+
APIConnectionError,
164+
side_effect=requests.exceptions.ConnectionError("reset"),
165+
)
166+
self.assertIsNone(error.status_code)
167+
self.assertTrue(error.is_transient_error())
168+
169+
def test_unexpected_error_wrapped_without_status_is_not_transient(self):
170+
error = self._do_request_raising(APIFailure, side_effect=RuntimeError("boom"))
171+
self.assertIsNone(error.status_code)
172+
self.assertFalse(error.is_transient_error())
173+
174+
175+
if __name__ == "__main__":
176+
unittest.main()

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)