Skip to content

Commit 0e40604

Browse files
authored
Enable toggling sticky routing - default false (#18)
* feat(sagemaker/sessions): add config to enable stateful sessions using env variable * feat(sagemaker/sessions): add utility functions for getting session_manager * chore(sagemaker/sessions): update tests * feat: update env config to use pydantic model SageMakerConfig and use SAGEMAKER_ as prefix. * feat(sagemaker/sessions): add validation layer so if session_manager is None but request is a session request, raise 400 error - change session api transform's transform_request to never return request field in output (previously used to pass session_manager - change handler code to only take raw_request as a parameter and use new utility function get_session_manager - add new integration tests for expected errors when sessions is disabled * Update way of setting sessions_path. * import logger to sessions/transform.py --------- Signed-off-by: Zuyi Zhao <[email protected]>
1 parent 3a14116 commit 0e40604

File tree

11 files changed

+381
-33
lines changed

11 files changed

+381
-33
lines changed

python/model_hosting_container_standards/sagemaker/config.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,99 @@
11
"""SageMaker-specific configuration constants."""
22

3+
import os
4+
from typing import Any, Dict, Optional
5+
6+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
7+
8+
SAGEMAKER_ENV_VAR_PREFIX = "SAGEMAKER_"
9+
10+
11+
class SageMakerConfig(BaseModel):
12+
"""Pydantic model for SageMaker configuration.
13+
14+
Automatically loads configuration from environment variables with SAGEMAKER_ prefix.
15+
Example: SAGEMAKER_ENABLE_STATEFUL_SESSIONS=true -> enable_stateful_sessions=True
16+
17+
Only fields defined in this model are loaded. Other SAGEMAKER_* env vars
18+
(like SAGEMAKER_MODEL_PATH) are ignored.
19+
20+
Usage:
21+
# Create from environment variables
22+
config = SagemakerConfig.from_env()
23+
24+
# Or just instantiate (automatically loads from env)
25+
config = SagemakerConfig()
26+
27+
# Override specific values
28+
config = SagemakerConfig(enable_stateful_sessions=True)
29+
"""
30+
31+
model_config = ConfigDict(extra="ignore")
32+
33+
# Stateful sessions configuration
34+
enable_stateful_sessions: bool = Field(
35+
default=False, description="Enable stateful sessions for the application"
36+
)
37+
sessions_expiration: int = Field(
38+
default=1200, # 20 minutes
39+
description="Session expiration time in seconds",
40+
gt=0,
41+
)
42+
sessions_path: Optional[str] = Field(
43+
default=None,
44+
description="Custom path for session storage (defaults to /dev/shm or temp)",
45+
)
46+
47+
@classmethod
48+
def from_env(cls) -> "SageMakerConfig":
49+
"""Create SagemakerConfig from environment variables.
50+
51+
Returns:
52+
SagemakerConfig instance with values loaded from SAGEMAKER_* env vars
53+
"""
54+
return cls()
55+
56+
@model_validator(mode="before")
57+
@classmethod
58+
def load_from_env_vars(cls, data: Any) -> Dict[str, Any]:
59+
"""Load configuration from environment variables.
60+
61+
Extracts SAGEMAKER_* environment variables and merges with any provided data.
62+
Provided data takes precedence over environment variables.
63+
Unknown SAGEMAKER_* variables are ignored (only defined fields are loaded).
64+
"""
65+
# Extract env vars with SAGEMAKER_ prefix
66+
env_config = {
67+
key[len(SAGEMAKER_ENV_VAR_PREFIX) :].lower(): val
68+
for key, val in os.environ.items()
69+
if key.startswith(SAGEMAKER_ENV_VAR_PREFIX)
70+
}
71+
72+
# If data is provided, merge with env config (data takes precedence)
73+
if isinstance(data, dict):
74+
return {**env_config, **data}
75+
return env_config
76+
77+
@field_validator("enable_stateful_sessions", mode="before")
78+
@classmethod
79+
def parse_bool_string(cls, v: Any) -> bool:
80+
"""Convert string values from env vars to boolean."""
81+
if isinstance(v, bool):
82+
return v
83+
if isinstance(v, str):
84+
return v.lower() in ("true", "1")
85+
return bool(v)
86+
87+
@field_validator("sessions_expiration", mode="before")
88+
@classmethod
89+
def parse_int_string(cls, v: Any) -> int:
90+
"""Convert string values from env vars to integer."""
91+
if isinstance(v, int):
92+
return v
93+
if isinstance(v, str):
94+
return int(v)
95+
return int(v)
96+
397

498
class SageMakerEnvVars:
599
"""SageMaker environment variable names."""

python/model_hosting_container_standards/sagemaker/sessions/handlers.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,13 @@
55
from fastapi.exceptions import HTTPException
66

77
from ...logging_config import logger
8-
from .manager import SessionManager
9-
from .models import SageMakerSessionHeader, SessionRequestType
8+
from .manager import get_session_manager
9+
from .models import (
10+
SESSION_DISABLED_ERROR_DETAIL,
11+
SESSION_DISABLED_LOG_MESSAGE,
12+
SageMakerSessionHeader,
13+
SessionRequestType,
14+
)
1015
from .utils import get_session_id_from_request
1116

1217

@@ -27,11 +32,10 @@ def get_handler_for_request_type(request_type: SessionRequestType):
2732
return None
2833

2934

30-
async def close_session(session_manager: SessionManager, raw_request: Request):
35+
async def close_session(raw_request: Request):
3136
"""Close an existing session and clean up its resources.
3237
3338
Args:
34-
session_manager: SessionManager instance to manage the session lifecycle
3539
raw_request: FastAPI Request object containing session ID in headers
3640
3741
Returns:
@@ -41,6 +45,13 @@ async def close_session(session_manager: SessionManager, raw_request: Request):
4145
HTTPException: If session closure fails with 424 FAILED_DEPENDENCY status
4246
"""
4347
session_id = get_session_id_from_request(raw_request)
48+
session_manager = get_session_manager()
49+
if session_manager is None:
50+
logger.error(SESSION_DISABLED_LOG_MESSAGE)
51+
raise HTTPException(
52+
status_code=HTTPStatus.BAD_REQUEST.value,
53+
detail=SESSION_DISABLED_ERROR_DETAIL,
54+
)
4455
try:
4556
session_manager.close_session(session_id)
4657
logger.info(f"Session {session_id} closed")
@@ -57,11 +68,10 @@ async def close_session(session_manager: SessionManager, raw_request: Request):
5768
)
5869

5970

60-
async def create_session(session_manager: SessionManager, raw_request: Request):
71+
async def create_session(raw_request: Request):
6172
"""Create a new stateful session with expiration tracking.
6273
6374
Args:
64-
session_manager: SessionManager instance to manage the session lifecycle
6575
raw_request: FastAPI Request object (unused but part of handler signature)
6676
6777
Returns:
@@ -70,6 +80,13 @@ async def create_session(session_manager: SessionManager, raw_request: Request):
7080
Raises:
7181
HTTPException: If session creation fails with 424 FAILED_DEPENDENCY status
7282
"""
83+
session_manager = get_session_manager()
84+
if session_manager is None:
85+
logger.error(SESSION_DISABLED_LOG_MESSAGE)
86+
raise HTTPException(
87+
status_code=HTTPStatus.BAD_REQUEST.value,
88+
detail=SESSION_DISABLED_ERROR_DETAIL,
89+
)
7390
try:
7491
session = session_manager.create_session()
7592
# expiration_ts is guaranteed to be set for newly created sessions

python/model_hosting_container_standards/sagemaker/sessions/manager.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from threading import RLock
1212
from typing import Optional
1313

14+
from ..config import SageMakerConfig
15+
1416

1517
class Session:
1618
"""Represents a single stateful session with file-based storage.
@@ -137,7 +139,7 @@ def __init__(self, properties: dict):
137139
else:
138140
session_dir = os.path.join(tempfile.gettempdir(), "sagemaker_sessions")
139141

140-
self.sessions_path = properties.get("sessions_path", session_dir)
142+
self.sessions_path = properties.get("sessions_path") or session_dir
141143
self.sessions: dict[str, Session] = {}
142144
self._lock = RLock() # Thread safety for concurrent session access
143145

@@ -248,5 +250,49 @@ def _clean_expired_session(self):
248250
self.close_session(session_id)
249251

250252

251-
# Global SessionManager instance
252-
session_manager = SessionManager({})
253+
def _init_session_manager(config: SageMakerConfig) -> SessionManager | None:
254+
"""Initialize a SessionManager if stateful sessions are enabled.
255+
256+
Args:
257+
config: SagemakerConfig instance with session settings
258+
259+
Returns:
260+
SessionManager instance if enabled, None otherwise
261+
"""
262+
if config.enable_stateful_sessions:
263+
# Convert config to dict for SessionManager
264+
config_dict = {
265+
"sessions_expiration": str(config.sessions_expiration),
266+
"sessions_path": config.sessions_path,
267+
}
268+
return SessionManager(config_dict)
269+
return None
270+
271+
272+
def get_session_manager() -> SessionManager | None:
273+
"""Get the global session manager instance.
274+
275+
Returns:
276+
The global SessionManager instance, or None if not initialized
277+
"""
278+
return session_manager
279+
280+
281+
def init_session_manager_from_env() -> SessionManager | None:
282+
"""Initialize the global session manager from environment variables.
283+
284+
This can be called to reinitialize the session manager after environment
285+
variables have been set.
286+
287+
Returns:
288+
The initialized SessionManager instance, or None if disabled
289+
"""
290+
global session_manager
291+
config = SageMakerConfig.from_env()
292+
session_manager = _init_session_manager(config)
293+
return session_manager
294+
295+
296+
# Global SessionManager instance - initialized from environment variables
297+
_config = SageMakerConfig.from_env()
298+
session_manager = _init_session_manager(_config)

python/model_hosting_container_standards/sagemaker/sessions/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,11 @@ class SageMakerSessionHeader:
3737
SESSION_ID = "X-Amzn-SageMaker-Session-Id"
3838
NEW_SESSION_ID = "X-Amzn-SageMaker-New-Session-Id"
3939
CLOSED_SESSION_ID = "X-Amzn-SageMaker-Closed-Session-Id"
40+
41+
42+
# Error messages for session management
43+
SESSION_DISABLED_ERROR_DETAIL = "Invalid payload. stateful sessions not enabled"
44+
SESSION_DISABLED_LOG_MESSAGE = (
45+
f"Invalid payload. stateful sessions not enabled, "
46+
f"{SageMakerSessionHeader.SESSION_ID} header not supported"
47+
)

python/model_hosting_container_standards/sagemaker/sessions/transform.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,14 @@
77
from pydantic import ValidationError
88

99
from ...common import BaseApiTransform, BaseTransformRequestOutput
10+
from ...logging_config import logger
1011
from .handlers import get_handler_for_request_type
11-
from .manager import SessionManager, session_manager
12-
from .models import SessionRequest
12+
from .manager import SessionManager, get_session_manager
13+
from .models import (
14+
SESSION_DISABLED_ERROR_DETAIL,
15+
SESSION_DISABLED_LOG_MESSAGE,
16+
SessionRequest,
17+
)
1318
from .utils import get_session, get_session_id_from_request
1419

1520

@@ -38,7 +43,9 @@ def _parse_session_request(request_data: dict) -> Optional[SessionRequest]:
3843
return None
3944

4045

41-
def _validate_session_if_present(raw_request: Request, session_manager: SessionManager):
46+
def _validate_session_if_present(
47+
raw_request: Request, session_manager: Optional[SessionManager]
48+
):
4249
"""Validate that the session ID in the request exists and is not expired.
4350
4451
Args:
@@ -60,7 +67,7 @@ def _validate_session_if_present(raw_request: Request, session_manager: SessionM
6067

6168

6269
def process_session_request(
63-
request_data: dict, raw_request: Request, session_manager: SessionManager
70+
request_data: dict, raw_request: Request, session_manager: Optional[SessionManager]
6471
):
6572
"""Process a potential session management request.
6673
@@ -89,16 +96,22 @@ def process_session_request(
8996
# Not a session request - pass through for normal processing
9097
if session_request is None:
9198
return BaseTransformRequestOutput(
92-
request=None,
9399
raw_request=raw_request,
94100
intercept_func=None,
95101
)
96102

103+
if session_manager is None:
104+
logger.error(SESSION_DISABLED_LOG_MESSAGE)
105+
raise HTTPException(
106+
status_code=HTTPStatus.BAD_REQUEST.value,
107+
detail=SESSION_DISABLED_ERROR_DETAIL,
108+
)
109+
97110
# Route to appropriate session management handler
98111
intercept_func = get_handler_for_request_type(session_request.requestType)
99112

100113
return BaseTransformRequestOutput(
101-
request=session_manager, raw_request=raw_request, intercept_func=intercept_func
114+
raw_request=raw_request, intercept_func=intercept_func
102115
)
103116

104117

@@ -121,7 +134,7 @@ def __init__(self, request_shape, response_shape={}):
121134
The request/response shapes are passed to the parent class but not used
122135
for validation in this transform, as session requests use their own validation.
123136
"""
124-
self._session_manager = session_manager
137+
self._session_manager = get_session_manager()
125138
super().__init__(request_shape, response_shape)
126139

127140
async def transform_request(self, raw_request):

python/model_hosting_container_standards/sagemaker/sessions/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from http import HTTPStatus
2+
from typing import Optional
23

34
from fastapi import Request
45
from fastapi.exceptions import HTTPException
@@ -22,7 +23,7 @@ def get_session_id_from_request(raw_request: Request):
2223
return raw_request.headers.get(SageMakerSessionHeader.SESSION_ID)
2324

2425

25-
def get_session(session_manager: SessionManager, raw_request: Request):
26+
def get_session(session_manager: Optional[SessionManager], raw_request: Request):
2627
"""Retrieve the session associated with the request.
2728
2829
Args:

0 commit comments

Comments
 (0)