Skip to content

Commit e040f2d

Browse files
committed
feat(sagemaker/sessions): add validation layer so if session_manager is None but request is a session request, raise 400 error
- change session api transform's transform_request to never return request field in output (previously used to pass session_manager - change handler code to only take raw_request as a parameter and use new utility function get_session_manager - add new integration tests for expected errors when sessions is disabled
1 parent 9087e47 commit e040f2d

File tree

7 files changed

+172
-23
lines changed

7 files changed

+172
-23
lines changed

python/model_hosting_container_standards/sagemaker/sessions/handlers.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,13 @@
55
from fastapi import Request, Response
66
from fastapi.exceptions import HTTPException
77

8-
from .manager import SessionManager
9-
from .models import SageMakerSessionHeader, SessionRequestType
8+
from .manager import get_session_manager
9+
from .models import (
10+
SESSION_DISABLED_ERROR_DETAIL,
11+
SESSION_DISABLED_LOG_MESSAGE,
12+
SageMakerSessionHeader,
13+
SessionRequestType,
14+
)
1015
from .utils import get_session_id_from_request
1116

1217
logger = logging.getLogger(__name__)
@@ -29,11 +34,10 @@ def get_handler_for_request_type(request_type: SessionRequestType):
2934
return None
3035

3136

32-
async def close_session(session_manager: SessionManager, raw_request: Request):
37+
async def close_session(raw_request: Request):
3338
"""Close an existing session and clean up its resources.
3439
3540
Args:
36-
session_manager: SessionManager instance to manage the session lifecycle
3741
raw_request: FastAPI Request object containing session ID in headers
3842
3943
Returns:
@@ -43,6 +47,13 @@ async def close_session(session_manager: SessionManager, raw_request: Request):
4347
HTTPException: If session closure fails with 424 FAILED_DEPENDENCY status
4448
"""
4549
session_id = get_session_id_from_request(raw_request)
50+
session_manager = get_session_manager()
51+
if session_manager is None:
52+
logger.error(SESSION_DISABLED_LOG_MESSAGE)
53+
raise HTTPException(
54+
status_code=HTTPStatus.BAD_REQUEST.value,
55+
detail=SESSION_DISABLED_ERROR_DETAIL,
56+
)
4657
try:
4758
session_manager.close_session(session_id)
4859
logger.info(f"Session {session_id} closed")
@@ -59,11 +70,10 @@ async def close_session(session_manager: SessionManager, raw_request: Request):
5970
)
6071

6172

62-
async def create_session(session_manager: SessionManager, raw_request: Request):
73+
async def create_session(raw_request: Request):
6374
"""Create a new stateful session with expiration tracking.
6475
6576
Args:
66-
session_manager: SessionManager instance to manage the session lifecycle
6777
raw_request: FastAPI Request object (unused but part of handler signature)
6878
6979
Returns:
@@ -72,6 +82,13 @@ async def create_session(session_manager: SessionManager, raw_request: Request):
7282
Raises:
7383
HTTPException: If session creation fails with 424 FAILED_DEPENDENCY status
7484
"""
85+
session_manager = get_session_manager()
86+
if session_manager is None:
87+
logger.error(SESSION_DISABLED_LOG_MESSAGE)
88+
raise HTTPException(
89+
status_code=HTTPStatus.BAD_REQUEST.value,
90+
detail=SESSION_DISABLED_ERROR_DETAIL,
91+
)
7592
try:
7693
session = session_manager.create_session()
7794
# expiration_ts is guaranteed to be set for newly created sessions

python/model_hosting_container_standards/sagemaker/sessions/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,11 @@ class SageMakerSessionHeader:
3737
SESSION_ID = "X-Amzn-SageMaker-Session-Id"
3838
NEW_SESSION_ID = "X-Amzn-SageMaker-New-Session-Id"
3939
CLOSED_SESSION_ID = "X-Amzn-SageMaker-Closed-Session-Id"
40+
41+
42+
# Error messages for session management
43+
SESSION_DISABLED_ERROR_DETAIL = "Invalid payload. stateful sessions not enabled"
44+
SESSION_DISABLED_LOG_MESSAGE = (
45+
f"Invalid payload. stateful sessions not enabled, "
46+
f"{SageMakerSessionHeader.SESSION_ID} header not supported"
47+
)

python/model_hosting_container_standards/sagemaker/sessions/transform.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
from ...common import BaseApiTransform, BaseTransformRequestOutput
1111
from .handlers import get_handler_for_request_type
1212
from .manager import SessionManager, get_session_manager
13-
from .models import SessionRequest
13+
from .models import (
14+
SESSION_DISABLED_ERROR_DETAIL,
15+
SESSION_DISABLED_LOG_MESSAGE,
16+
SessionRequest,
17+
)
1418
from .utils import get_session, get_session_id_from_request
1519

1620
logger = logging.getLogger(__name__)
@@ -63,7 +67,7 @@ def _validate_session_if_present(raw_request: Request, session_manager: SessionM
6367

6468

6569
def process_session_request(
66-
request_data: dict, raw_request: Request, session_manager: SessionManager
70+
request_data: dict, raw_request: Request, session_manager: Optional[SessionManager]
6771
):
6872
"""Process a potential session management request.
6973
@@ -92,16 +96,22 @@ def process_session_request(
9296
# Not a session request - pass through for normal processing
9397
if session_request is None:
9498
return BaseTransformRequestOutput(
95-
request=None,
9699
raw_request=raw_request,
97100
intercept_func=None,
98101
)
99102

103+
if session_manager is None:
104+
logger.error(SESSION_DISABLED_LOG_MESSAGE)
105+
raise HTTPException(
106+
status_code=HTTPStatus.BAD_REQUEST.value,
107+
detail=SESSION_DISABLED_ERROR_DETAIL,
108+
)
109+
100110
# Route to appropriate session management handler
101111
intercept_func = get_handler_for_request_type(session_request.requestType)
102112

103113
return BaseTransformRequestOutput(
104-
request=session_manager, raw_request=raw_request, intercept_func=intercept_func
114+
raw_request=raw_request, intercept_func=intercept_func
105115
)
106116

107117

python/tests/integration/test_sagemaker_sessions_integration.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
init_session_manager_from_env,
3434
)
3535
from model_hosting_container_standards.sagemaker.sessions.models import (
36+
SESSION_DISABLED_ERROR_DETAIL,
3637
SageMakerSessionHeader,
3738
)
3839

@@ -542,5 +543,97 @@ def test_interleaved_session_operations(self):
542543
assert response.status_code == 200
543544

544545

546+
class TestSessionsDisabled:
547+
"""Test behavior when stateful sessions are disabled.
548+
549+
These tests verify that session management requests fail gracefully
550+
when the SAGEMAKER_ENABLE_STATEFUL_SESSIONS flag is not set.
551+
"""
552+
553+
@pytest.fixture
554+
def app_with_sessions_disabled(self, monkeypatch):
555+
"""Create app with sessions disabled."""
556+
# Explicitly disable sessions
557+
monkeypatch.delenv("SAGEMAKER_ENABLE_STATEFUL_SESSIONS", raising=False)
558+
monkeypatch.delenv("SAGEMAKER_SESSIONS_PATH", raising=False)
559+
monkeypatch.delenv("SAGEMAKER_SESSIONS_EXPIRATION", raising=False)
560+
561+
# Reinitialize the global session manager (should be None)
562+
init_session_manager_from_env()
563+
564+
# Now create the app with sessions disabled
565+
app = FastAPI()
566+
router = APIRouter()
567+
568+
@router.post("/invocations")
569+
@sagemaker_standards.stateful_session_manager()
570+
async def invocations(request: Request):
571+
"""Stateful invocation handler."""
572+
body_bytes = await request.body()
573+
body = json.loads(body_bytes.decode())
574+
575+
return Response(
576+
status_code=200,
577+
content=json.dumps({"message": "success", "echo": body}),
578+
)
579+
580+
app.include_router(router)
581+
sagemaker_standards.bootstrap(app)
582+
583+
return TestClient(app)
584+
585+
def test_new_session_request_fails_when_disabled(self, app_with_sessions_disabled):
586+
"""Test that NEW_SESSION request fails when sessions are disabled."""
587+
response = app_with_sessions_disabled.post(
588+
"/invocations", json={"requestType": "NEW_SESSION"}
589+
)
590+
591+
# Should fail with 400 BAD_REQUEST since sessions are not enabled
592+
assert response.status_code == 400
593+
assert SESSION_DISABLED_ERROR_DETAIL in response.text
594+
595+
def test_close_session_request_fails_when_disabled(
596+
self, app_with_sessions_disabled
597+
):
598+
"""Test that CLOSE request fails when sessions are disabled."""
599+
response = app_with_sessions_disabled.post(
600+
"/invocations",
601+
json={"requestType": "CLOSE"},
602+
headers={SageMakerSessionHeader.SESSION_ID: "some-session-id"},
603+
)
604+
605+
# Should fail with 400 BAD_REQUEST due to session header when sessions disabled
606+
assert response.status_code == 400
607+
assert SESSION_DISABLED_ERROR_DETAIL in response.text
608+
609+
def test_regular_requests_work_when_sessions_disabled(
610+
self, app_with_sessions_disabled
611+
):
612+
"""Test that regular requests still work when sessions are disabled."""
613+
response = app_with_sessions_disabled.post(
614+
"/invocations", json={"prompt": "test request"}
615+
)
616+
617+
# Regular requests should still work
618+
assert response.status_code == 200
619+
data = json.loads(response.text)
620+
assert data["message"] == "success"
621+
assert data["echo"]["prompt"] == "test request"
622+
623+
def test_regular_requests_with_session_header_when_disabled(
624+
self, app_with_sessions_disabled
625+
):
626+
"""Test that requests with session headers fail validation when sessions disabled."""
627+
response = app_with_sessions_disabled.post(
628+
"/invocations",
629+
json={"prompt": "test"},
630+
headers={SageMakerSessionHeader.SESSION_ID: "invalid-session"},
631+
)
632+
633+
# Should fail with 400 BAD_REQUEST since sessions are not enabled
634+
assert response.status_code == 400
635+
assert SESSION_DISABLED_ERROR_DETAIL in response.text
636+
637+
545638
if __name__ == "__main__":
546639
pytest.main([__file__, "-v"])

python/tests/sagemaker/sessions/test_handlers.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import time
44
from http import HTTPStatus
5-
from unittest.mock import Mock
5+
from unittest.mock import Mock, patch
66

77
import pytest
88
from fastapi import Response
@@ -61,7 +61,11 @@ async def test_creates_session_successfully(
6161
"""Test successfully creates a session and returns response."""
6262
mock_session_manager.create_session.return_value = mock_session_with_expiration
6363

64-
response = await create_session(mock_session_manager, mock_request)
64+
with patch(
65+
"model_hosting_container_standards.sagemaker.sessions.handlers.get_session_manager",
66+
return_value=mock_session_manager,
67+
):
68+
response = await create_session(mock_request)
6569

6670
assert isinstance(response, Response)
6771
assert response.status_code == HTTPStatus.OK.value
@@ -80,7 +84,11 @@ async def test_calls_session_manager_create_session(
8084
"""Test calls session_manager.create_session method."""
8185
mock_session_manager.create_session.return_value = mock_session_with_expiration
8286

83-
await create_session(mock_session_manager, mock_request)
87+
with patch(
88+
"model_hosting_container_standards.sagemaker.sessions.handlers.get_session_manager",
89+
return_value=mock_session_manager,
90+
):
91+
await create_session(mock_request)
8492

8593
mock_session_manager.create_session.assert_called_once()
8694

@@ -91,8 +99,12 @@ async def test_raises_http_exception_on_session_creation_failure(
9199
"""Test raises HTTPException when session creation fails."""
92100
mock_session_manager.create_session.side_effect = Exception("Creation failed")
93101

94-
with pytest.raises(HTTPException) as exc_info:
95-
await create_session(mock_session_manager, mock_request)
102+
with patch(
103+
"model_hosting_container_standards.sagemaker.sessions.handlers.get_session_manager",
104+
return_value=mock_session_manager,
105+
):
106+
with pytest.raises(HTTPException) as exc_info:
107+
await create_session(mock_request)
96108

97109
assert exc_info.value.status_code == HTTPStatus.FAILED_DEPENDENCY.value
98110
assert "Failed to create session" in exc_info.value.detail
@@ -109,7 +121,11 @@ async def test_closes_session_successfully(
109121
session_id = "test-session-123"
110122
mock_session_manager.close_session.return_value = None
111123

112-
response = await close_session(mock_session_manager, mock_request_with_session)
124+
with patch(
125+
"model_hosting_container_standards.sagemaker.sessions.handlers.get_session_manager",
126+
return_value=mock_session_manager,
127+
):
128+
response = await close_session(mock_request_with_session)
113129

114130
assert isinstance(response, Response)
115131
assert response.status_code == HTTPStatus.OK.value
@@ -124,8 +140,12 @@ async def test_raises_http_exception_on_close_failure(
124140
"""Test raises HTTPException when session close fails."""
125141
mock_session_manager.close_session.side_effect = ValueError("Session not found")
126142

127-
with pytest.raises(HTTPException) as exc_info:
128-
await close_session(mock_session_manager, mock_request_with_session)
143+
with patch(
144+
"model_hosting_container_standards.sagemaker.sessions.handlers.get_session_manager",
145+
return_value=mock_session_manager,
146+
):
147+
with pytest.raises(HTTPException) as exc_info:
148+
await close_session(mock_request_with_session)
129149

130150
assert exc_info.value.status_code == HTTPStatus.FAILED_DEPENDENCY.value
131151
assert "Failed to close session" in exc_info.value.detail

python/tests/sagemaker/sessions/test_transform.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def test_returns_create_handler_for_new_session_request(
183183
)
184184

185185
assert isinstance(result, BaseTransformRequestOutput)
186-
assert result.request == mock_session_manager
186+
assert result.request is None
187187
assert result.raw_request == mock_request
188188
assert result.intercept_func == create_session
189189

@@ -198,7 +198,7 @@ def test_returns_close_handler_for_close_request(
198198
)
199199

200200
assert isinstance(result, BaseTransformRequestOutput)
201-
assert result.request == mock_session_manager
201+
assert result.request is None
202202
assert result.raw_request == mock_request
203203
assert result.intercept_func == close_session
204204

@@ -335,10 +335,10 @@ async def test_end_to_end_new_session_flow(self, transform):
335335

336336
# Verify we get an intercept function
337337
assert result.intercept_func == create_session
338-
assert result.request == transform._session_manager
338+
assert result.request is None
339339

340340
# Verify we can call the handler
341-
response = await result.intercept_func(result.request, mock_request)
341+
response = await result.intercept_func(mock_request)
342342
assert response.status_code == HTTPStatus.OK.value
343343
assert SageMakerSessionHeader.NEW_SESSION_ID in response.headers
344344

python/tests/sagemaker/sessions/test_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from fastapi.exceptions import HTTPException
99

1010
from model_hosting_container_standards.sagemaker.sessions.models import (
11+
SESSION_DISABLED_ERROR_DETAIL,
1112
SageMakerSessionHeader,
1213
)
1314
from model_hosting_container_standards.sagemaker.sessions.utils import (
@@ -92,7 +93,7 @@ def test_raises_http_exception_when_sessions_not_enabled_but_header_present(self
9293
get_session(None, raw_request)
9394

9495
assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value
95-
assert "stateful sessions not enabled" in exc_info.value.detail
96+
assert SESSION_DISABLED_ERROR_DETAIL in exc_info.value.detail
9697
assert SageMakerSessionHeader.SESSION_ID in exc_info.value.detail
9798

9899
def test_returns_none_when_sessions_not_enabled_and_no_header(self):

0 commit comments

Comments
 (0)