Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
fa63931
feat(sagemaker/sessions): add config to enable stateful sessions usin…
zhaozuy Nov 10, 2025
0a99c84
feat(sagemaker/sessions): add utility functions for getting session_m…
zhaozuy Nov 10, 2025
2e57cd5
chore(sagemaker/sessions): update tests
zhaozuy Nov 10, 2025
85e1687
feat: update env config to use pydantic model SageMakerConfig and use…
zhaozuy Nov 10, 2025
9087e47
Merge branch 'main' of github.com:aws/model-hosting-container-standar…
zhaozuy Nov 10, 2025
e040f2d
feat(sagemaker/sessions): add validation layer so if session_manager …
zhaozuy Nov 11, 2025
9c9791a
Update way of setting sessions_path.
zhaozuy Nov 11, 2025
1d9f7fd
Merge branch 'main' of github.com:aws/model-hosting-container-standar…
zhaozuy Nov 18, 2025
413dfd2
feat(initial - sagemaker/sessions): support engines with their own cr…
zhaozuy Nov 17, 2025
aaf7774
feat(initial - sagemaker/sessions): refactor create/close api transfo…
zhaozuy Nov 18, 2025
24838c2
Merge branch 'main' of github.com:aws/model-hosting-container-standar…
zhaozuy Dec 1, 2025
1a31cf2
import logger to sessions/transform.py
zhaozuy Dec 1, 2025
f2aecfb
Merge branch 'toggle-sticky-routing' of github.com:aws/model-hosting-…
zhaozuy Dec 1, 2025
ce4ab40
Remove manual logger setups.
zhaozuy Dec 1, 2025
4f62b3f
Update README.md
zhaozuy Dec 2, 2025
2975b10
Fix linting.
zhaozuy Dec 2, 2025
dfc6b97
Merge branch 'main' of github.com:aws/model-hosting-container-standar…
zhaozuy Dec 3, 2025
39a1af2
wip - update stateful sessions manager to move sm id header to target
zhaozuy Dec 3, 2025
c77013e
fix(sessions): Fix session ID injection and update tests
zhaozuy Dec 4, 2025
454ad34
Add unit tests
zhaozuy Dec 4, 2025
259ea2f
Update tests, improve how check for use default is done.
zhaozuy Dec 4, 2025
5c01442
Remove unnecessary bootstrap in integ tests.
zhaozuy Dec 5, 2025
c94f0e3
chore(sessions): clarify parameter naming and improve documentation
zhaozuy Dec 6, 2025
ecb6184
Update docs.
zhaozuy Dec 6, 2025
89e2f1d
Merge branch 'main' of github.com:aws/model-hosting-container-standar…
zhaozuy Dec 8, 2025
8de6d32
Merge branch 'main' of github.com:aws/model-hosting-container-standar…
zhaozuy Dec 9, 2025
ec0ea84
refactor(sagemaker/sessions): simplify custom session handler registr…
zhaozuy Dec 9, 2025
b7fb991
refactor(sagemaker/sessions): improve parameter naming clarity for cu…
zhaozuy Dec 9, 2025
4036e22
fix docstring, add status_code to serialize_response dict output.
zhaozuy Dec 10, 2025
7303275
Merge branch 'main' of github.com:aws/model-hosting-container-standar…
zhaozuy Dec 10, 2025
d8f6635
Merge branch 'main' of github.com:aws/model-hosting-container-standar…
zhaozuy Dec 10, 2025
4bc21c7
update set_value call in SessionApiTransform _process_invocations_req…
zhaozuy Dec 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 90 additions & 90 deletions docs/INTEGRATION_RUNBOOK.md

Large diffs are not rendered by default.

30 changes: 29 additions & 1 deletion python/model_hosting_container_standards/common/fastapi/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
from typing import Any, Dict, Optional, Union

from fastapi import Request
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from pydantic import BaseModel


Expand Down Expand Up @@ -33,3 +35,29 @@ def serialize_request(
"query_params": raw_request.query_params,
"path_params": raw_request.path_params,
}


def serialize_response(response: Union[Response, JSONResponse]) -> Dict[str, Any]:
    """Create a structured data dictionary for JMESPath transformations.

    Decodes the response body with the response's charset and, when the decoded
    text is valid JSON, parses it so JMESPath expressions can address individual
    fields. A body that is not valid JSON (including an empty body) is kept as
    the decoded string.

    :param Union[Response, JSONResponse] response: Rendered response object; its
        ``body``, ``charset``, ``headers`` and ``status_code`` attributes are read
    :return Dict[str, Any]: Structured data with body, headers, and status_code
    """
    # NOTE(review): Starlette's ``Response.render`` is typed to return either
    # bytes or a memoryview; memoryview has no ``decode`` method, so normalise
    # to bytes before decoding. For a plain bytes body this is a no-op.
    raw_body = bytes(response.body)
    body: Any = raw_body.decode(response.charset)
    try:
        body = json.loads(body)
    except json.JSONDecodeError:
        # If body is not JSON, keep it as a string
        pass

    return {
        "body": body,
        "headers": response.headers,
        "status_code": response.status_code,
    }
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic import BaseModel, Field

from ...logging_config import logger
from ..fastapi.utils import serialize_request
from ..fastapi.utils import serialize_request, serialize_response
from .utils import _compile_jmespath_expressions


Expand Down Expand Up @@ -103,6 +103,19 @@ async def transform_request(self, raw_request: Request):
"""
raise NotImplementedError()

def _transform_response(self, response: Response):
    """Transform a response using the configured response shape.

    Serializes the response into a structured dictionary with
    ``serialize_response`` and hands that dictionary, together with
    ``self._response_shape``, to ``self._transform``.

    This is a concrete implementation: it does not raise
    ``NotImplementedError`` and takes no argument other than the response.

    :param Response response: The response to transform
    :return: Whatever ``self._transform`` produces for the serialized response
    """
    response_data = serialize_response(response)
    return self._transform(response_data, self._response_shape)

def _transform_request(
self, request: Optional[BaseModel], raw_request: Request
) -> Dict[str, Any]:
Expand Down
157 changes: 151 additions & 6 deletions python/model_hosting_container_standards/sagemaker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
from .lora.models import AppendOperation
from .sagemaker_loader import SageMakerFunctionLoader
from .sagemaker_router import create_sagemaker_router
from .sessions import create_session_transform_decorator
from .sessions import (
build_session_request_shape,
create_session_transform_decorator,
register_engine_session_handler,
)
from .sessions.models import SageMakerSessionHeader

# SageMaker decorator instances - created using utility functions

Expand Down Expand Up @@ -118,17 +123,157 @@ def inject_adapter_id(
)


def stateful_session_manager(engine_request_session_id_path: Optional[str] = None):
    """Build the decorator that enables session-based sticky routing.

    The returned decorator turns on stateful session handling for regular
    invocation requests. When a target path is supplied, the request shape maps
    that path to the SageMaker session-ID header so the session ID can be
    injected into the request body sent to the engine.

    Args:
        engine_request_session_id_path: Optional target path in the request body
                                        that should receive the session ID taken
                                        from the SageMaker session header.

                                        Examples: "session_id", "metadata.session_id"

                                        When None (or empty), an empty request
                                        shape is used: session management is
                                        still enabled, but no session ID is
                                        injected into the request body.

    Returns:
        A decorator that can be applied to route handlers to enable session management
    """
    # Only build the header expression when there is a target path to map it to.
    request_shape = (
        {
            engine_request_session_id_path: (
                f'headers."{SageMakerSessionHeader.SESSION_ID}"'
            )
        }
        if engine_request_session_id_path
        else {}
    )
    apply_session_transform = create_session_transform_decorator()
    return apply_session_transform(request_shape=request_shape, response_shape={})


def register_create_session_handler(
    engine_response_session_id_path: str,
    engine_request_session_id_path: Optional[str] = None,
    additional_request_shape: Optional[Dict[str, str]] = None,
    content_path: str = "`successfully created session.`",
):
    """Register an engine-specific ``create_session`` handler.

    Builds the request shape with ``build_session_request_shape`` and registers
    the handler under the ``"create_session"`` type through
    ``register_engine_session_handler``, forwarding the response session-ID path
    and the content path unchanged.

    Args:
        engine_response_session_id_path: JMESPath expression locating the session
                                         ID in the engine's response. It must
                                         carry a prefix naming the source:

                                         - "body.session_id" - read from the response body
                                         - "headers.X-Session-Id" - read from the response headers

                                         The extracted value is what gets
                                         returned to the client as the session ID.

        engine_request_session_id_path: Optional target path in the engine request
                                        body that should receive the session ID
                                        taken from the SageMaker session header.

                                        Examples: "session_id", "metadata.session_id"

                                        Leave as None when the engine manages
                                        session IDs itself and does not need
                                        one in the request.

                                        Limitation: only request-body injection
                                        is supported, not headers.

        additional_request_shape: Optional extra JMESPath transformations for the
                                  request; keys are target paths in the request
                                  body, values are source expressions.

        content_path: JMESPath expression for the success message in the response.
                      Defaults to a literal success message.

    Returns:
        A decorator that can be applied to engine-specific session creation handlers.

    Note:
        An entry in additional_request_shape keyed by
        engine_request_session_id_path is overwritten so that the session ID
        mapping takes precedence.
    """
    handler_options = {
        "request_shape": build_session_request_shape(
            engine_request_session_id_path, additional_request_shape
        ),
        "response_session_id_path": engine_response_session_id_path,
        "content_path": content_path,
    }
    return register_engine_session_handler("create_session", **handler_options)


def register_close_session_handler(
    engine_request_session_id_path: str,
    additional_request_shape: Optional[Dict[str, str]] = None,
    content_path: str = "`successfully closed session.`",
):
    """Register an engine-specific ``close_session`` handler.

    Validates that a session-ID target path was supplied, builds the request
    shape with ``build_session_request_shape`` and registers the handler under
    the ``"close_session"`` type through ``register_engine_session_handler``.

    Args:
        engine_request_session_id_path: Required. Target path in the engine request
                                        body that should receive the session ID
                                        taken from the SageMaker session header.

                                        Examples: "session_id", "metadata.session_id"

                                        Required because the engine has to be
                                        told which session to close.

                                        Limitation: only request-body injection
                                        is supported, not headers.

        additional_request_shape: Optional extra JMESPath transformations for the
                                  request; keys are target paths in the request
                                  body, values are source expressions.

        content_path: JMESPath expression for the success message in the response.
                      Defaults to a literal success message.

    Returns:
        A decorator that can be applied to engine-specific session closure handlers.

    Raises:
        ValueError: If engine_request_session_id_path is None or empty.

    Note:
        An entry in additional_request_shape keyed by
        engine_request_session_id_path is overwritten so that the session ID
        mapping takes precedence.
    """
    # Fail fast, before any helper is called, when no target path was given.
    if not engine_request_session_id_path:
        missing_path_message = (
            "engine_request_session_id_path is required for close_session handler. "
            "The engine needs to know which session to close."
        )
        raise ValueError(missing_path_message)

    return register_engine_session_handler(
        "close_session",
        request_shape=build_session_request_shape(
            engine_request_session_id_path, additional_request_shape
        ),
        content_path=content_path,
    )


def bootstrap(app: FastAPI) -> FastAPI:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def get_sagemaker_route_config(handler_type: str) -> Optional[RouteConfig]:
summary="Model inference endpoint",
)

if handler_type in ["create_session", "close_session"]:
# It's a request transformer, not a standalone API endpoint
# It modifies requests in-flight but doesn't expose its own route
return None

# Delegate to LoRA route resolver for LoRA-specific handlers
return get_lora_route_config(handler_type)

Expand Down
Loading