diff --git a/docs/INTEGRATION_RUNBOOK.md b/docs/INTEGRATION_RUNBOOK.md index 60fe52d..bbaf926 100644 --- a/docs/INTEGRATION_RUNBOOK.md +++ b/docs/INTEGRATION_RUNBOOK.md @@ -1,7 +1,7 @@ # MHCS Integration Runbook **Version**: 1.0 container -**Last Updated**: November 16, 2025 +**Last Updated**: November 16, 2025 **Target Audience**: ML framework developers integrating with Amazon SageMaker --- @@ -94,7 +94,7 @@ ## 1. Introduction -### 1.1 What is MHCS? +### 1.1 What is MHCS? Model Hosting Container Standards (MHCS) is a Python library that acts as a bridge between model hosting platforms and ML inference engines with rapidly evolving APIs. It standardizes how ML frameworks integrate with hosting platforms while maintaining backwards compatibility and adapting to changing engine interfaces. @@ -276,10 +276,10 @@ async def invocations(request: Request) -> dict: """Model inference endpoint for SageMaker.""" body = await request.json() prompt = body.get("prompt", "") - + # Your framework's inference logic here result = f"Processed: {prompt}" - + return {"predictions": [result]} # Bootstrap MHCS - must be called after handler definitions @@ -354,10 +354,10 @@ async def invocations(request: Request) -> dict: body = await request.json() prompt = body.get("prompt", "") adapter_id = body.get("model", "base-model") # Injected by decorator - + # Your framework's inference logic with adapter result = f"[{adapter_id}] Processed: {prompt}" - + return {"predictions": [result], "adapter_used": adapter_id} @sagemaker_standards.register_load_adapter_handler( @@ -369,10 +369,10 @@ async def load_adapter(request: Request) -> dict: body = await request.json() adapter_name = body["adapter_name"] adapter_path = body.get("adapter_path", "") - + # Your framework's adapter loading logic loaded_adapters[adapter_name] = {"path": adapter_path, "loaded": True} - + return {"status": "success", "adapter_name": adapter_name} @sagemaker_standards.register_unload_adapter_handler( @@ -382,12 +382,12 
@@ async def load_adapter(request: Request) -> dict: async def unload_adapter(request: Request) -> dict: """Unload a LoRA adapter.""" adapter_name = request.path_params.get("adapter_name") - + # Your framework's adapter unloading logic if adapter_name in loaded_adapters: del loaded_adapters[adapter_name] return {"status": "success", "adapter_name": adapter_name} - + return {"status": "not_found", "adapter_name": adapter_name} sagemaker_standards.bootstrap(app) @@ -399,7 +399,7 @@ if __name__ == "__main__": **How it works**: - The LoRA decorators use the transform decorator system under the hood (see [Section 4: Transform Decorators](#4-transform-decorators) for details). -- `@inject_adapter_id("model")` - Extracts adapter ID from `X-Amzn-SageMaker-Adapter-Identifier` header and injects it into the `model` field of the request body. +- `@inject_adapter_id("model")` - Extracts adapter ID from `X-Amzn-SageMaker-Adapter-Identifier` header and injects it into the `model` field of the request body. - `@register_load_adapter_handler` - Creates `POST /adapters` endpoint for loading adapters - `@register_unload_adapter_handler` - Creates `DELETE /adapters/{adapter_name}` endpoint for unloading adapters @@ -471,16 +471,16 @@ async def invocations(request: Request) -> dict: """Inference with session management.""" body = await request.json() prompt = body.get("prompt", "") - + # Access session data if available session_id = request.headers.get("X-Amzn-SageMaker-Session-Id") - + # Your framework's inference logic with session context if session_id: result = f"[Session {session_id}] Processed: {prompt}" else: result = f"Processed: {prompt}" - + return {"predictions": [result]} sagemaker_standards.bootstrap(app) @@ -726,10 +726,10 @@ The `bootstrap(app)` function is the central integration point that connects you ```python def bootstrap(app: FastAPI) -> FastAPI: """Configure a FastAPI application with SageMaker functionality. 
- + Args: app: The FastAPI application instance to configure - + Returns: The configured FastAPI app """ @@ -832,7 +832,7 @@ sequenceDiagram participant SageMaker Router participant Handler Registry participant Your Handler - + Client->>FastAPI App: GET /ping FastAPI App->>Middleware: Process request Middleware->>SageMaker Router: Route to /ping @@ -893,10 +893,10 @@ from fastapi import Request, Response @sagemaker_standards.register_ping_handler async def ping(request: Request) -> Response: """Health check handler for SageMaker. - + Args: request: FastAPI Request object containing headers, body, etc. - + Returns: Response: FastAPI Response object with status code and content """ @@ -912,19 +912,19 @@ from typing import Dict, Any @sagemaker_standards.register_invocation_handler async def invocations(request: Request) -> Dict[str, Any]: """Model inference handler for SageMaker. - + Args: request: FastAPI Request object containing the inference request - + Returns: Dict: JSON-serializable dictionary with predictions """ body = await request.json() prompt = body.get("prompt", "") - + # Your framework's inference logic result = your_model.generate(prompt) - + return {"predictions": [result]} ``` @@ -948,11 +948,11 @@ sequenceDiagram participant MHCS Router participant Handler Registry participant Your Handler - + Note over Client,Your Handler: Registration Phase (at startup) Your Handler->>Handler Registry: @register_ping_handler decorator Handler Registry->>Handler Registry: Store handler as "ping" type - + Note over Client,Your Handler: Request Phase (at runtime) Client->>FastAPI: GET /ping FastAPI->>MHCS Router: Route to /ping endpoint @@ -1028,13 +1028,13 @@ graph TD F -->|No| H{Register Decorator?} H -->|Yes| I[Use Framework Default Handler] H -->|No| J[No Handler Found] - + C --> K[Handler Resolved] E --> K G --> K I --> K J --> L[Skip Route Creation] - + style C fill:#ff6b6b style E fill:#ffa500 style G fill:#ffd93d @@ -1271,12 +1271,12 @@ This step-by-step 
checklist guides you through a complete MHCS integration. Foll @sagemaker_standards.register_ping_handler async def ping(request: Request) -> Response: return Response(status_code=200, content="Healthy") - + @sagemaker_standards.register_invocation_handler async def invocations(request: Request) -> dict: # Your inference logic here ... - + # Call bootstrap() last sagemaker_standards.bootstrap(app) ``` @@ -1364,13 +1364,13 @@ If your framework supports LoRA adapters, add these handlers: curl -X POST http://localhost:8000/adapters \ -H "Content-Type: application/json" \ -d '{"name": "my-adapter", "src": "/path/to/adapter"}' - + # Use adapter curl -X POST http://localhost:8000/invocations \ -H "Content-Type: application/json" \ -H "X-Amzn-SageMaker-Adapter-Identifier: my-adapter" \ -d '{"prompt": "test"}' - + # Unload adapter curl -X DELETE http://localhost:8000/adapters/my-adapter ``` @@ -1411,13 +1411,13 @@ If your framework needs stateful sessions: curl -X POST http://localhost:8000/invocations \ -H "Content-Type: application/json" \ -d '{"requestType": "NEW_SESSION"}' - + # Use session (replace with actual ID) curl -X POST http://localhost:8000/invocations \ -H "Content-Type: application/json" \ -H "X-Amzn-SageMaker-Session-Id: " \ -d '{"prompt": "test"}' - + # Close session curl -X POST http://localhost:8000/invocations \ -H "Content-Type: application/json" \ @@ -1568,15 +1568,15 @@ graph TD B -->|used by| C[create_transform_decorator Factory] C -->|creates| D[Decorator Functions] D -->|applied to| E[Your Handler Functions] - + B1[RegisterLoRAApiTransform] -.->|example| B B2[InjectToBodyApiTransform] -.->|example| B B3[SessionApiTransform] -.->|example| B - + D1[inject_adapter_id decorator] -.->|example| D D2[register_load_adapter_handler decorator] -.->|example| D D3[stateful_session_manager decorator] -.->|example| D - + style A fill:#e1f5ff style C fill:#fff4e1 style D fill:#e8f5e9 @@ -1592,15 +1592,15 @@ class BaseApiTransform: # Compiles JMESPath expressions 
for efficient execution self._request_shape = _compile_jmespath_expressions(request_shape) self._response_shape = _compile_jmespath_expressions(response_shape) - + def _transform(self, source_data: Dict, target_shape: Dict) -> Dict: # Applies JMESPath expressions to extract and transform data pass - + async def transform_request(self, raw_request: Request): # Subclasses implement specific request transformation logic raise NotImplementedError() - + def transform_response(self, response: Response, transform_request_output): # Subclasses implement specific response transformation logic raise NotImplementedError() @@ -1624,26 +1624,26 @@ A factory function that creates decorators dynamically: ```python def create_transform_decorator(handler_type: str, transform_resolver: Callable): """Creates a decorator factory for a specific handler type.""" - + def decorator_with_params(request_shape: Dict = None, response_shape: Dict = None): """Configures the transformation shapes.""" - + def decorator(func: Callable): """The actual decorator that wraps your handler.""" # Resolves the appropriate transform class - transformer = _resolve_transforms(handler_type, transform_resolver, + transformer = _resolve_transforms(handler_type, transform_resolver, request_shape, response_shape) - + async def decorated_func(raw_request: Request): # Apply request transformation transform_output = await transformer.transform_request(raw_request) - + # Call your handler with transformed data response = await transformer.intercept(func, transform_output) - + # Apply response transformation return transformer.transform_response(response, transform_output) - + return decorated_func return decorator return decorator_with_params @@ -1705,7 +1705,7 @@ sequenceDiagram participant FastAPI participant Transform participant Handler - + Client->>FastAPI: POST /invocations
Header: X-Amzn-SageMaker-Adapter-Identifier: my-adapter
Body: {"prompt": "..."} FastAPI->>Transform: Raw Request Transform->>Transform: Extract adapter ID from header @@ -1738,7 +1738,7 @@ sequenceDiagram } ``` -4. **Transform vs Passthrough**: +4. **Transform vs Passthrough**: - Pass `request_shape=None` for no transformation (passthrough mode) - Pass `request_shape={}` for transform infrastructure without JMESPath - Pass `request_shape={...}` for full transformation @@ -1937,7 +1937,7 @@ def create_transform_decorator( Args: handler_type: Identifier for the handler (e.g., 'register_adapter') transform_resolver: Function that maps handler_type to transform class - + Returns: Decorator factory that accepts request_shape and response_shape """ @@ -1957,15 +1957,15 @@ def decorator(func): async def wrapped_func(raw_request: Request): # 1. Transform request transform_output = await transformer.transform_request(raw_request) - + # 2. Call your handler response = await transformer.intercept(func, transform_output) - + # 3. Transform response final_response = transformer.transform_response(response, transform_output) - + return final_response - + return wrapped_func ``` @@ -1997,7 +1997,7 @@ def get_sagemaker_route_config(handler_type: str) -> Optional[RouteConfig]: return RouteConfig(path="/ping", method="GET", ...) elif handler_type == "invoke": return RouteConfig(path="/invocations", method="POST", ...) - + # Delegate to LoRA routes for adapter handlers return get_lora_route_config(handler_type) ``` @@ -2056,28 +2056,28 @@ from model_hosting_container_standards.common import BaseApiTransform, BaseTrans class MyCustomTransform(BaseApiTransform): """Custom transform for my specific use case.""" - + def __init__(self, request_shape, response_shape={}): """Initialize with request and response shapes. 
- + Args: request_shape: JMESPath expressions for extracting request data response_shape: JMESPath expressions for transforming responses """ super().__init__(request_shape, response_shape) # Add any custom initialization here - + async def transform_request(self, raw_request: Request): """Transform incoming request. - + This method is called before your handler executes. Extract and validate data, then return a BaseTransformRequestOutput. """ raise NotImplementedError() - + def transform_response(self, response: Response, transform_request_output): """Transform outgoing response. - + This method is called after your handler executes. Modify the response based on the request transformation output. """ @@ -2117,7 +2117,7 @@ class MyCustomTransform(BaseApiTransform): status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {e}", ) - + # Step 2: Validate using Pydantic model try: validated_request = MyRequestModel.model_validate(request_data) @@ -2126,10 +2126,10 @@ class MyCustomTransform(BaseApiTransform): status_code=HTTPStatus.BAD_REQUEST.value, detail=e.json(include_url=False), ) - + # Step 3: Apply JMESPath transformations (if request_shape provided) transformed_data = self._transform_request(validated_request, raw_request) - + # Step 4: Return BaseTransformRequestOutput return BaseTransformRequestOutput( request=transformed_data, # Transformed data passed to handler @@ -2163,24 +2163,24 @@ class MyCustomTransform(BaseApiTransform): """Transform the response based on request processing.""" # Option 1: Simple passthrough (no transformation) return response - + # Option 2: Route based on status code if response.status_code == HTTPStatus.OK: return self._transform_ok_response(response, transform_request_output) else: return self._transform_error_response(response, transform_request_output) - + def _transform_ok_response(self, response: Response, transform_request_output): """Transform successful responses.""" # Extract data from request transformation 
adapter_name = transform_request_output.request.get("adapter_name") - + # Create custom response return Response( status_code=HTTPStatus.OK.value, content=f"Success: Processed {adapter_name}", ) - + def _transform_error_response(self, response: Response, transform_request_output): """Transform error responses.""" # Pass through or customize error responses @@ -2249,7 +2249,7 @@ from fastapi import Request ) async def my_handler(transformed_data, raw_request: Request): """Handler receives transformed data as first argument. - + Args: transformed_data: SimpleNamespace with attributes from request_shape raw_request: Original FastAPI Request for additional context @@ -2258,10 +2258,10 @@ async def my_handler(transformed_data, raw_request: Request): adapter_name = transformed_data.adapter_name adapter_path = transformed_data.adapter_path custom_header = transformed_data.custom_header - + # Your handler logic here result = f"Processing {adapter_name} from {adapter_path}" - + return {"status": "success", "message": result} ``` @@ -2295,13 +2295,13 @@ If your custom transform needs its own HTTP endpoint (not just transforming exis ```python def get_sagemaker_route_config(handler_type: str) -> Optional[RouteConfig]: """Map handler types to their route configurations.""" - + # Core SageMaker routes if handler_type == "ping": return RouteConfig(path="/ping", method="GET", ...) elif handler_type == "invoke": return RouteConfig(path="/invocations", method="POST", ...) 
- + # Your custom route elif handler_type == "my_custom_operation": return RouteConfig( @@ -2310,7 +2310,7 @@ def get_sagemaker_route_config(handler_type: str) -> Optional[RouteConfig]: response_model=None, status_code=200 ) - + # Delegate to LoRA routes for adapter handlers return get_lora_route_config(handler_type) ``` @@ -2418,14 +2418,14 @@ async def invocations(request: Request): body = await request.json() adapter_id = body.get("model") # Automatically injected from header # Your inference logic with adapter_id - + # 2. Implement adapter loading @register_load_adapter_handler( request_shape={"adapter_name": "body.name", "adapter_path": "body.src"} ) async def load_adapter(data, request): # Your framework's adapter loading logic - + # 3. Implement adapter unloading @register_unload_adapter_handler( request_shape={"adapter_name": "path_params.adapter_name"} @@ -2515,7 +2515,7 @@ import model_hosting_container_standards.sagemaker as sagemaker_standards async def invocations(request: Request) -> dict: body = await request.json() adapter_id = body.get("model") # Adapter ID is now in body["model"] - + # Your framework's inference logic result = f"Using adapter: {adapter_id}" return {"predictions": [result]} @@ -2557,10 +2557,10 @@ Append mode concatenates the adapter ID to an existing value using a separator. 
async def invocations(request: Request) -> dict: body = await request.json() model_with_adapter = body.get("model") # "base-model:my-adapter" - + # Parse the composite identifier base_model, adapter_id = model_with_adapter.split(":", 1) - + return {"predictions": [f"Base: {base_model}, Adapter: {adapter_id}"]} ``` @@ -2570,7 +2570,7 @@ async def invocations(request: Request) -> dict: # Incoming request body: {"prompt": "Hello", "model": "base-model"} -# After @inject_adapter_id("model", append=True, separator=":") +# After @inject_adapter_id("model", append=True, separator=":") # with header X-Amzn-SageMaker-Adapter-Identifier: my-adapter # Transformed request body: {"prompt": "Hello", "model": "base-model:my-adapter"} @@ -2832,11 +2832,11 @@ class MyFramework: def load_adapter(self, name: str, path: str, **kwargs): """Load adapter from path with given name.""" pass - + def unload_adapter(self, name: str): """Unload adapter by name.""" pass - + def has_adapter(self, name: str) -> bool: """Check if adapter is loaded.""" pass @@ -2962,7 +2962,7 @@ graph LR D -->|Response| A A -->|CLOSE + Session ID| E[Close Session] E -->|Delete Data| F[Session Removed] - + C -->|TTL Expired| G[Auto Cleanup] G -->|Delete Data| F ``` @@ -3302,13 +3302,13 @@ async def invocations(request: Request) -> dict: """Inference handler with session support.""" body = await request.json() prompt = body.get("prompt", "") - + # Access session ID if present session_id = request.headers.get("X-Amzn-SageMaker-Session-Id") - + # Your inference logic here result = f"Processed: {prompt}" - + return {"predictions": [result]} ``` @@ -3333,7 +3333,7 @@ The decorator order follows Python's decorator application rules: decorators are **What the Decorator Does:** 1. **Intercepts Session Requests**: Detects `requestType` field in request body -2. **Routes to Default Session Handlers**: +2. 
def serialize_response(response: Union[Response, JSONResponse]) -> Dict[str, Any]:
    """Create a structured data dictionary for JMESPath transformations.

    Extracts and organizes response data into a standardized format that can be used
    with JMESPath expressions to transform and extract specific data elements.

    :param Union[Response, JSONResponse] response: Response object whose ``body``,
        ``charset``, ``headers`` and ``status_code`` attributes are read
    :return Dict[str, Any]: Structured data with ``body``, ``headers`` and
        ``status_code``. ``body`` is the parsed JSON value when the payload is
        valid JSON, otherwise the decoded text (an empty body yields ``""``).
    """
    # bytes(...) accepts both bytes and memoryview payloads; a memoryview has no
    # .decode() method, so decoding response.body directly would fail for it.
    text = bytes(response.body).decode(response.charset)
    try:
        body = json.loads(text)
    except json.JSONDecodeError:
        # Not JSON (plain text, empty body, ...): expose the decoded string as-is
        body = text

    return {
        "body": body,
        "headers": response.headers,
        "status_code": response.status_code,
    }
def stateful_session_manager(engine_request_session_id_path: Optional[str] = None):
    """Build the session-management decorator for invocation handlers.

    The decorator is obtained from ``create_session_transform_decorator`` and is
    always configured with an empty ``response_shape``. When a target path is
    supplied, the ``request_shape`` maps that path to the SageMaker session-ID
    header so the session ID can be placed into the request sent to the engine.

    Args:
        engine_request_session_id_path: Optional target path in the request body
                                        that should receive the session ID.

                                        Examples: "session_id", "metadata.session_id"

                                        If None (or empty), the request shape stays
                                        empty and the session ID is not injected
                                        into the request body.

    Returns:
        A decorator that can be applied to route handlers to enable session management
    """
    if not engine_request_session_id_path:
        # No injection target: enable the transform with an empty request shape
        shape = {}
    else:
        # JMESPath quoted identifier selecting the session-ID header value
        header_expr = f'headers."{SageMakerSessionHeader.SESSION_ID}"'
        shape = {engine_request_session_id_path: header_expr}

    decorator_factory = create_session_transform_decorator()
    return decorator_factory(request_shape=shape, response_shape={})
def register_close_session_handler(
    engine_request_session_id_path: str,
    additional_request_shape: Optional[Dict[str, str]] = None,
    content_path: str = "`successfully closed session.`",
):
    """Register an engine-specific handler for closing a session.

    Builds the request shape with ``build_session_request_shape`` and registers
    the handler under the ``"close_session"`` type through
    ``register_engine_session_handler``.

    Args:
        engine_request_session_id_path: Required. Target path in the engine request
                                        body where the session ID will be injected,
                                        so the engine knows which session to close.

                                        Examples: "session_id", "metadata.session_id"

        additional_request_shape: Optional dict of additional JMESPath
                                  transformations (target path -> source expression)
                                  forwarded to ``build_session_request_shape``.
                                  Defaults to None.

        content_path: JMESPath expression forwarded unchanged as ``content_path``.
                      Defaults to a literal success message.

    Returns:
        A decorator that can be applied to engine-specific session closure handlers.

    Raises:
        ValueError: If engine_request_session_id_path is None or empty.

    Note:
        How a clash between engine_request_session_id_path and a key in
        additional_request_shape is resolved is decided by
        ``build_session_request_shape`` (presumably the session ID path wins;
        verify against that helper).
    """
    if engine_request_session_id_path:
        close_shape = build_session_request_shape(
            engine_request_session_id_path, additional_request_shape
        )
        return register_engine_session_handler(
            "close_session", request_shape=close_shape, content_path=content_path
        )

    # Falsy path (None or ""): there is nowhere to put the session ID
    raise ValueError(
        "engine_request_session_id_path is required for close_session handler. "
        "The engine needs to know which session to close."
    )
+ +### When to Use Custom Handlers + +Use custom handlers when: +- Your engine has native session management capabilities +- You want to leverage engine-specific session features +- Session state needs to be managed within the engine's memory space +- You need custom session initialization or cleanup logic + +### Architecture + +``` +Client Request (NEW_SESSION or CLOSE) + ↓ +SessionApiTransform (detects session request) + ↓ +Handler Registry Check + ↓ + ├─→ Custom Handler (if registered) + │ └─→ Your Engine's Session API + │ + └─→ Default Handler (if not registered) + └─→ SageMaker SessionManager +``` + +## Registration API + +Use the `@register_create_session_handler` and `@register_close_session_handler` decorators to register custom handlers: + +```python +from fastapi import FastAPI, Request +from pydantic import BaseModel +from model_hosting_container_standards.sagemaker import ( + register_create_session_handler, + register_close_session_handler, + stateful_session_manager, + bootstrap +) + +app = FastAPI() + +# Define your engine's request/response models +class CreateSessionRequest(BaseModel): + capacity: int + +class CreateSessionResponse(BaseModel): + session_id: str + message: str + +# Register custom create session handler +@register_create_session_handler( + engine_response_session_id_path="body.session_id", # Extract session ID from response + engine_request_session_id_path="session_id", # Where to inject session ID in engine request + additional_request_shape={ + "capacity": "`1024`" # Additional fields to include + }, + content_path="body.message" # Extract content for logging +) +@app.post("/engine/create_session") +async def create_session(obj: CreateSessionRequest, request: Request): + # Call your engine's session creation API + session_id = await my_engine.create_session(capacity=obj.capacity) + return CreateSessionResponse(session_id=session_id, message="Session created") + +# Alternative: If your engine manages session IDs internally 
+@register_create_session_handler( + engine_response_session_id_path="body.session_id", # Extract session ID from response + # No engine_request_session_id_path - engine generates its own session ID + additional_request_shape={ + "capacity": "`1024`" + } +) +@app.post("/engine/create_session") +async def create_session_auto(obj: CreateSessionRequest, request: Request): + # Engine generates and returns its own session ID + session_id = await my_engine.create_session_auto(capacity=obj.capacity) + return CreateSessionResponse(session_id=session_id, message="Session created") + +# Register custom close session handler +@register_close_session_handler( + engine_request_session_id_path="session_id", # Where to inject session ID in engine request + content_path="`Session closed successfully`" # Static message +) +@app.post("/engine/close_session") +async def close_session(session_id: str, request: Request): + # Call your engine's session closure API + await my_engine.close_session(session_id) + return {"status": "closed"} + +# Your main invocations endpoint +@app.post("/invocations") +@stateful_session_manager() +async def invocations(request: Request): + # Handle regular inference requests + pass + +bootstrap(app) +``` + +## Decorator Parameters + +### `@register_create_session_handler` + +```python +@register_create_session_handler( + engine_response_session_id_path: str, # Required: Where to extract session ID from engine response + engine_request_session_id_path: str = None, # Optional: Where to inject session ID in engine request + additional_request_shape: dict = None, # Optional: Additional JMESPath mappings + content_path: str = None # Optional: JMESPath to extract content for logging +) +``` + +- **`engine_response_session_id_path`**: JMESPath expression to extract the session ID from your engine's response. Must include prefix (`body.` or `headers.`). This is **required** because the framework needs to return the session ID to the client. 
+- **`engine_request_session_id_path`**: Optional target path in the engine request body where the session ID will be injected. The session ID is extracted from the SageMaker session header and placed at this path. Example: `"session_id"` or `"metadata.session_id"`. If `None`, the session ID is not injected (useful when the engine manages sessions internally). +- **`additional_request_shape`**: Optional dict mapping target keys to source JMESPath expressions for additional fields to include in the engine request. +- **`content_path`**: Optional JMESPath expression to extract a message for logging. Defaults to a generic success message. + +### `@register_close_session_handler` + +```python +@register_close_session_handler( + engine_request_session_id_path: str, # Required: Where to inject session ID in engine request + additional_request_shape: dict = None, # Optional: Additional JMESPath mappings + content_path: str = None # Optional: JMESPath to extract content for logging +) +``` + +- **`engine_request_session_id_path`**: **Required.** Target path in the engine request body where the session ID will be injected. The session ID is extracted from the SageMaker session header and placed at this path. This is required because the engine needs to know which session to close. Example: `"session_id"` or `"metadata.session_id"`. +- **`additional_request_shape`**: Optional dict mapping target keys to source JMESPath expressions for additional fields to include in the engine request. +- **`content_path`**: Optional JMESPath expression to extract a message for logging. Defaults to a generic success message. + +**Note**: `engine_response_session_id_path` is not needed for close handlers because the session ID comes from the request header, not the response. + +## How It Works + +When you register custom handlers: + +1. **Client sends session request** to `/invocations` with `{"requestType": "NEW_SESSION"}` or `{"requestType": "CLOSE"}` +2.
**SessionApiTransform intercepts** the request and checks the handler registry +3. **If custom handler registered**: Request is routed to your custom endpoint (e.g., `/engine/create_session`) +4. **Transform applies**: Request/response shapes are transformed using JMESPath +5. **Response returned**: With appropriate SageMaker session headers (`X-Amzn-SageMaker-New-Session-Id` or `X-Amzn-SageMaker-Closed-Session-Id`) + +The key benefit: Your `/invocations` endpoint stays clean, and session management is handled transparently. + +## JMESPath Expressions + +The parameters use JMESPath expressions to transform data: + +### Request Transformation + +The `engine_request_session_id_path` specifies where to inject the session ID (always relative to request body): + +```python +engine_request_session_id_path="session_id" # Inject at root level +engine_request_session_id_path="metadata.session_id" # Inject in nested path +``` + +The `additional_request_shape` maps target keys to source expressions: + +```python +additional_request_shape={ + "capacity": "`1024`", # Literal value +} +``` + +### Response Extraction + +For **create session**, you must specify: +- `engine_response_session_id_path`: Where to extract the session ID from the engine's response +- `content_path`: Where to extract content for logging (optional) + +```python +engine_response_session_id_path="body.session_id" # Extract from {"session_id": "..."} +engine_response_session_id_path="body" # If response is just the session ID string +engine_response_session_id_path="headers.X-Session-Id" # Extract from response header +content_path="body.message" # Extract message from response +content_path="`Session created`" # Use literal string +``` + +For **close session**, you only need: +- `content_path`: Where to extract content for logging (optional) + +## Response Formats + +Your custom handlers can return different response formats: + +### Dictionary Response +```python +async def create_session(obj: 
CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + return {"session_id": session_id, "metadata": {"engine": "custom"}} +``` + +### String Response +```python +async def create_session(obj: CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + return session_id # Just return the session ID +``` + +### FastAPI Response Object +```python +from fastapi import Response + +async def create_session(obj: CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + return Response( + status_code=201, + content=json.dumps({"session_id": session_id}), + media_type="application/json" + ) +``` + +## Error Handling + +Raise `HTTPException` to return errors to the client: + +```python +from fastapi.exceptions import HTTPException + +@register_create_session_handler(...) +async def create_session(obj: CreateSessionRequest, request: Request): + try: + session_id = await my_engine.create_session() + return {"session_id": session_id} + except EngineError as e: + raise HTTPException(status_code=500, detail=f"Engine error: {e}") +``` + +## Complete Example + +Here's a complete example with error handling and session tracking: + +```python +from fastapi import FastAPI, Request, HTTPException, Response +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from typing import Optional +import uuid +import json + +from model_hosting_container_standards.sagemaker import ( + register_create_session_handler, + register_close_session_handler, + stateful_session_manager, + bootstrap +) +from model_hosting_container_standards.sagemaker.sessions.models import ( + SageMakerSessionHeader +) + +app = FastAPI() + +# Track sessions in memory (for demo purposes) +active_sessions = {} + +class CreateSessionRequest(BaseModel): + capacity: int + session_id: Optional[str] = None + +@register_create_session_handler( + engine_request_session_id_path="session_id", + engine_response_session_id_path="body.session_id", + 
additional_request_shape={ + "capacity": "`1024`" + }, + content_path="body.message" +) +@app.post("/engine/create_session") +async def create_session(obj: CreateSessionRequest, request: Request): + # Generate or use provided session ID + session_id = obj.session_id or str(uuid.uuid4()) + + # Check if session already exists + if session_id in active_sessions: + raise HTTPException(status_code=400, detail="Session already exists") + + # Create session in your engine + active_sessions[session_id] = {"capacity": obj.capacity} + + return { + "session_id": session_id, + "message": f"Session created with capacity {obj.capacity}" + } + +@register_close_session_handler( + engine_request_session_id_path="session_id", + content_path="`Session closed successfully`" +) +@app.post("/engine/close_session") +async def close_session(session_id: str, request: Request): + if session_id not in active_sessions: + raise HTTPException(status_code=404, detail="Session not found") + + # Close session in your engine + del active_sessions[session_id] + + return Response(status_code=200, content="Session closed") + +@app.post("/invocations") +@stateful_session_manager(engine_request_session_id_path="session_id") +async def invocations(request: Request): + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + session_id = body.get("session_id") + + if session_id and session_id not in active_sessions: + raise HTTPException(status_code=400, detail="Invalid session") + + # Process inference request with session context + return JSONResponse({ + "result": "success", + "session_id": session_id or "no-session", + "echo": body + }) + +bootstrap(app) +``` + +## Session Validation Behavior + +When custom handlers are registered, the framework **does not** validate session IDs against the default `SessionManager`. This means: + +- **With custom handlers**: Session validation is your responsibility. The framework only routes requests to your handlers. 
+- **Without custom handlers** (default mode): The framework validates session IDs against the `SessionManager` automatically. + +This design allows your engine to manage sessions independently without interference from the default session manager. + +## Best Practices + +1. **Validate session IDs**: Check that the engine returns valid session IDs in create handlers +2. **Handle errors gracefully**: Use HTTPException for clear error messages +3. **Log operations**: Log session creation/closure for debugging +4. **Test thoroughly**: Test both success and failure scenarios +5. **Idempotency**: Handle duplicate close requests gracefully (return 404 or succeed silently) +6. **Session isolation**: Ensure different sessions maintain independent state +7. **Thread safety**: If your engine stores session state, ensure thread-safe access for concurrent requests + +## Troubleshooting + +### Session ID not extracted from response + +**Problem**: Getting "Engine failed to return a valid session ID" error. + +**Solution**: Check that your `engine_response_session_id_path` matches your response structure: +```python +# If your handler returns: {"session_id": "abc123"} +engine_response_session_id_path="body.session_id" + +# If your handler returns: "abc123" +engine_response_session_id_path="body" + +# If session ID is in response header +engine_response_session_id_path="headers.X-Session-Id" +``` + +### Request not reaching custom handler + +**Problem**: Custom handler not being called. + +**Solution**: Ensure you call `bootstrap(app)` **after** registering your handlers: +```python +@register_create_session_handler(...) +async def create_session(...): + pass + +bootstrap(app) # Must be after handler registration +``` + +### Session ID not injected into engine request + +**Problem**: Engine receives request without session ID. 
+ +**Solution**: Ensure your `engine_request_session_id_path` specifies where to inject the session ID: +```python +engine_request_session_id_path="session_id" # Injects at root level of request body +``` + +## See Also + +- [README.md](./README.md) - Main sessions documentation +- [Integration tests](../../../tests/integration/test_custom_session_handlers_integration.py) - Complete working examples diff --git a/python/model_hosting_container_standards/sagemaker/sessions/README.md b/python/model_hosting_container_standards/sagemaker/sessions/README.md index 4720d91..c067e08 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/README.md +++ b/python/model_hosting_container_standards/sagemaker/sessions/README.md @@ -2,6 +2,17 @@ This module provides stateful session management for SageMaker model hosting containers, enabling multi-turn conversations and persistent state across requests. +## Table of Contents + +- [Overview](#overview) +- [Architecture](#architecture) +- [Quick Start](#quick-start) +- [Configuration](#configuration) +- [Session Storage](#session-storage) +- [Expiration and Cleanup](#expiration-and-cleanup) +- [Advanced Usage](#advanced-usage) + - [Custom Session Handlers](./CUSTOM_HANDLERS.md) + ## Overview Stateful sessions allow clients to maintain context across multiple inference requests without passing all state in every request. Each session has: @@ -10,40 +21,73 @@ Stateful sessions allow clients to maintain context across multiple inference re - **Automatic expiration**: Configurable TTL (default: 20 minutes) - **Thread-safe access**: Concurrent request handling +### Session Management Modes + +The framework supports two modes of session management: + +1. **SageMaker-Managed Sessions** (Default) + - Sessions managed by the built-in `SessionManager` + - File-based key-value storage in `/dev/shm` + - Automatic expiration and cleanup + - Best for general-purpose session state + +2. 
**Engine-Managed Sessions** (Custom Handlers) + - Sessions delegated to your inference engine's native API + - Leverages engine-specific session features + - Requires custom handler registration + - Best when engine has built-in session support + - See [CUSTOM_HANDLERS.md](./CUSTOM_HANDLERS.md) for details + ## Architecture ``` -SessionApiTransform (transform.py) +Client Request to /invocations ↓ - ├─→ Session Management Request - │ ├─→ create_session (handlers.py) - │ └─→ close_session (handlers.py) +SessionApiTransform (intercepts and inspects) + ↓ + ├─→ Session Management Request (NEW_SESSION or CLOSE) + │ ├─→ Check Handler Registry + │ │ ├─→ Custom Handler (if registered) + │ │ │ └─→ Your engine's session API + │ │ └─→ Default Handler (if not registered) + │ │ └─→ SageMaker SessionManager + │ └─→ Return with session headers │ └─→ Regular Inference Request - └─→ Pass through with session context + ├─→ Validate session ID (if present) + ├─→ Inject session ID into body (if configured) + └─→ Pass to your handler ``` ### Key Components -- **`SessionManager`** (`manager.py`): Manages session lifecycle, expiration, and cleanup +- **`SessionManager`** (`manager.py`): Manages session lifecycle, expiration, and cleanup (default mode) - **`Session`** (`manager.py`): Individual session with file-based key-value storage -- **`SessionApiTransform`** (`transform.py`): API transform that intercepts session requests -- **Session Handlers** (`handlers.py`): Functions to create and close sessions +- **`SessionApiTransform`** (`transform.py`): API transform that intercepts and routes session requests +- **Handler Registry**: Routes session requests to custom or default handlers +- **Session Handlers** (`handlers.py`): Default functions to create and close sessions +- **Engine Session Transforms** (`transforms/`): Transform classes for custom engine integration - **Utilities** (`utils.py`): Helper functions for session ID extraction and retrieval -## Usage +## Quick Start ### 
Enabling Sessions in Your Handler -Use the `stateful_session_manager()` convenience decorator: +Use the `stateful_session_manager()` decorator on your `/invocations` endpoint: ```python -from model_hosting_container_standards.sagemaker import stateful_session_manager +from fastapi import FastAPI, Request +from model_hosting_container_standards.sagemaker import stateful_session_manager, bootstrap + +app = FastAPI() +@app.post("/invocations") @stateful_session_manager() -def my_handler(request): +async def invocations(request: Request): # Handler logic with session support pass + +bootstrap(app) ``` ### Creating a Session @@ -90,24 +134,38 @@ X-Amzn-SageMaker-Closed-Session-Id: ## Configuration -Configure via `SessionManager` properties: +Configure via environment variables: -```python -session_manager = SessionManager({ - "sessions_expiration": "1200", # TTL in seconds (default: 1200) - "sessions_path": "/dev/shm/sagemaker_sessions" # Storage path -}) +```bash +export SAGEMAKER_ENABLE_STATEFUL_SESSIONS=true +export SAGEMAKER_SESSIONS_EXPIRATION=1200 # TTL in seconds (default: 1200) +export SAGEMAKER_SESSIONS_PATH=/dev/shm/sagemaker_sessions # Storage path (optional) ``` +The session manager is automatically initialized from these environment variables when you call `bootstrap(app)`. + +**Important**: If `SAGEMAKER_ENABLE_STATEFUL_SESSIONS` is not set to `true`, session management requests will fail with a 400 error. Regular inference requests without session headers will continue to work normally. + ### Storage Location Sessions are stored in memory-backed filesystem when available: -- **Preferred**: `/dev/shm/sagemaker_sessions` (tmpfs - fast) +- **Preferred**: `/dev/shm/sagemaker_sessions` (tmpfs - fast, in-memory) - **Fallback**: `{tempdir}/sagemaker_sessions` (disk-backed) +**Note**: Session data is not persistent across container restarts. + ## Session Storage -Each session maintains its own directory with JSON files for key-value pairs. 
+Each session maintains its own directory with JSON files for key-value pairs: + +``` +/dev/shm/sagemaker_sessions/ +├── {session_id_1}/ +│ ├── key1.json +│ └── key2.json +└── {session_id_2}/ + └── key1.json +``` ## Expiration and Cleanup @@ -119,16 +177,32 @@ Each session maintains its own directory with JSON files for key-value pairs. ## Advanced Usage -For more control, use `create_session_transform_decorator()` directly: +### Injecting Session ID into Request Body + +If your handler needs the session ID in the request body (not just headers), use the `request_session_id_path` parameter: ```python -from model_hosting_container_standards.sagemaker.sessions import create_session_transform_decorator +@app.post("/invocations") +@stateful_session_manager(request_session_id_path="session_id") +async def invocations(request: Request): + body = await request.json() + session_id = body.get("session_id") # Automatically injected from header + # Handler logic +``` -session_transform = create_session_transform_decorator() +For nested paths, use dot notation: -@session_transform(request_shape={}, response_shape={}) -def my_handler(request, context): - pass +```python +@stateful_session_manager(request_session_id_path="metadata.session_id") +async def invocations(request: Request): + body = await request.json() + session_id = body["metadata"]["session_id"] # Injected at nested path ``` -**Note**: `SessionApiTransform` ignores the `request_shape` and `response_shape` parameters. These are passed to the parent `BaseApiTransform` class for interface compatibility, but session requests use their own validation via `SessionRequest` model instead of JMESPath transformations. +**Note**: The session ID is only injected when the `X-Amzn-SageMaker-Session-Id` header is present in the request.
+ +### Custom Session Handlers + +If your inference engine has its own session management API, you can register custom handlers to delegate session creation and closure to the engine instead of using SageMaker's built-in session management. + +See [CUSTOM_HANDLERS.md](./CUSTOM_HANDLERS.md) for detailed documentation on implementing custom create/close session handlers. diff --git a/python/model_hosting_container_standards/sagemaker/sessions/__init__.py b/python/model_hosting_container_standards/sagemaker/sessions/__init__.py index 9d7978d..2f01312 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/__init__.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/__init__.py @@ -1,5 +1,11 @@ +from typing import Dict, Optional + from ...common.transforms.base_factory import create_transform_decorator +from ...logging_config import logger +from .models import SageMakerSessionHeader from .transform import SessionApiTransform +from .transforms import resolve_engine_session_transform +from .transforms.constants import RESPONSE_CONTENT_KEY def resolve_session_transform(handler_type: str) -> type: @@ -20,3 +26,90 @@ def create_session_transform_decorator(): return create_transform_decorator( "stateful_session_manager", resolve_session_transform ) + + +def _create_engine_session_transform_decorator(handler_type: str): + return create_transform_decorator(handler_type, resolve_engine_session_transform) + + +def register_engine_session_handler( + handler_type: str, + request_shape, + response_session_id_path: Optional[str] = None, + content_path: Optional[str] = None, +): + """Register a handler for engine-specific session management. 
+ + Args: + handler_type: Type of session handler ('create_session' or 'close_session') + request_shape: JMESPath expressions for transforming request data + response_session_id_path: JMESPath expression for extracting session ID FROM + the engine's response (required for 'create_session', + ignored for 'close_session') + content_path: JMESPath expression for extracting content from response + + Returns: + Decorator function for registering the session handler + + Raises: + ValueError: If handler_type is invalid or required parameters are missing + """ + # Validate handler_type + if handler_type not in ("create_session", "close_session"): + raise ValueError( + f"Invalid handler_type '{handler_type}'. " + f"Must be 'create_session' or 'close_session'" + ) + + response_shape = { + RESPONSE_CONTENT_KEY: content_path, + } + + if handler_type == "create_session": + if not response_session_id_path: + raise ValueError("response_session_id_path is required for create_session") + response_shape[SageMakerSessionHeader.NEW_SESSION_ID] = response_session_id_path + + return _create_engine_session_transform_decorator(handler_type)( + request_shape, response_shape + ) + + +def build_session_request_shape( + session_id_path: Optional[str], + additional_shape: Optional[Dict[str, str]] = None, +) -> Dict[str, str]: + """Build the request shape for session handlers with proper session ID injection. + + This helper consolidates the logic for constructing request shapes, ensuring + the session ID is always properly mapped and warning about any conflicts. + + Args: + session_id_path: Optional target path for the session ID in the request. + If None, session ID is not injected into the request. + additional_shape: Optional additional transformations to merge. + + Returns: + A complete request shape dict with session ID and any additional mappings. 
+ """ + request_shape: Dict[str, str] = {} + + if additional_shape: + request_shape.update(additional_shape) + + # Only inject session ID if a path is specified + if session_id_path: + # Warn if session_id_path would be overwritten + if session_id_path in request_shape: + existing_value = request_shape[session_id_path] + logger.warning( + f"Session ID path '{session_id_path}' found in additional_request_shape " + f"with value '{existing_value}'. This will be overwritten with the " + f"SageMaker session header value." + ) + + request_shape[session_id_path] = ( + f'headers."{SageMakerSessionHeader.SESSION_ID}"' + ) + + return request_shape diff --git a/python/model_hosting_container_standards/sagemaker/sessions/handlers.py b/python/model_hosting_container_standards/sagemaker/sessions/handlers.py index 7ba81ac..0ff9d4a 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/handlers.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/handlers.py @@ -4,6 +4,7 @@ from fastapi import Request, Response from fastapi.exceptions import HTTPException +from ...common.handler import handler_registry from ...logging_config import logger from .manager import get_session_manager from .models import ( @@ -25,10 +26,21 @@ def get_handler_for_request_type(request_type: SessionRequestType): Handler function for the request type, or None if no handler """ if request_type == SessionRequestType.NEW_SESSION: - return create_session + registered_handler = handler_registry.get_handler("create_session") + logger.info(f"Handler for {request_type} request: {registered_handler}") + if not registered_handler: + logger.debug(f"No handler found for {request_type} request, using default") + registered_handler = create_session # Default use SageMaker system + return registered_handler elif request_type == SessionRequestType.CLOSE: - return close_session + registered_handler = handler_registry.get_handler("close_session") + logger.info(f"Handler for {request_type} 
request: {registered_handler}") + if not registered_handler: + logger.debug(f"No handler found for {request_type} request, using default") + registered_handler = close_session # Default use SageMaker system + return registered_handler else: + logger.warning(f"No handler found for {request_type} request") return None diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transform.py index c890e13..aff9d7e 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transform.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transform.py @@ -7,13 +7,17 @@ from pydantic import ValidationError from ...common import BaseApiTransform, BaseTransformRequestOutput +from ...common.handler import handler_registry +from ...common.transforms.utils import set_value from ...logging_config import logger from .handlers import get_handler_for_request_type from .manager import SessionManager, get_session_manager from .models import ( SESSION_DISABLED_ERROR_DETAIL, SESSION_DISABLED_LOG_MESSAGE, + SageMakerSessionHeader, SessionRequest, + SessionRequestType, ) from .utils import get_session, get_session_id_from_request @@ -43,100 +47,62 @@ def _parse_session_request(request_data: dict) -> Optional[SessionRequest]: return None -def _validate_session_if_present( - raw_request: Request, session_manager: Optional[SessionManager] -): - """Validate that the session ID in the request exists and is not expired. 
- - Args: - raw_request: FastAPI Request object - session_manager: SessionManager instance - - Raises: - HTTPException: If session validation fails - """ - session_id = get_session_id_from_request(raw_request) - if session_id: - try: - get_session(session_manager, raw_request) - except ValueError as e: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"Bad request: {str(e)}", - ) - - -def process_session_request( - request_data: dict, raw_request: Request, session_manager: Optional[SessionManager] -): - """Process a potential session management request. - - Determines if the request is a session management operation (NEW_SESSION or CLOSE) - and routes it to the appropriate handler, or passes through for normal processing. - - Args: - request_data: Parsed JSON request body - raw_request: FastAPI Request object - session_manager: SessionManager instance - - Returns: - BaseTransformRequestOutput with either: - - intercept_func set if this is a session management request - - None/passthrough if this is a regular request - - Raises: - HTTPException: If request is malformed or session validation fails - """ - session_request = _parse_session_request(request_data) - - # Validate session if session ID is present in headers - # and raise error if session ID is invalid - _validate_session_if_present(raw_request, session_manager) - - # Not a session request - pass through for normal processing - if session_request is None: - return BaseTransformRequestOutput( - raw_request=raw_request, - intercept_func=None, - ) - - if session_manager is None: - logger.error(SESSION_DISABLED_LOG_MESSAGE) - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail=SESSION_DISABLED_ERROR_DETAIL, - ) - - # Route to appropriate session management handler - intercept_func = get_handler_for_request_type(session_request.requestType) - - return BaseTransformRequestOutput( - raw_request=raw_request, intercept_func=intercept_func - ) - - class 
SessionApiTransform(BaseApiTransform): - """API transform that intercepts and processes stateful session management requests. - - This transform extends BaseApiTransform to add session management capabilities. - It parses incoming requests to detect session management operations (NEW_SESSION, CLOSE) - and routes them to appropriate handlers, while passing through regular API requests. - """ - def __init__(self, request_shape, response_shape={}): """Initialize the SessionApiTransform. Args: - request_shape: Passed to parent BaseApiTransform (unused in session logic) - response_shape: Passed to parent BaseApiTransform (unused in session logic) + request_shape: Passed to parent BaseApiTransform + response_shape: Passed to parent BaseApiTransform Note: The request/response shapes are passed to the parent class but not used for validation in this transform, as session requests use their own validation. """ self._session_manager = get_session_manager() + + # Hybrid caching strategy for _use_default_manager: + # - If custom handlers exist at init → cache False (fast path on every request) + # - If no custom handlers at init → cache True but check dynamically (allows late registration) + # This optimizes the common case while maintaining flexibility + self._use_default_manager_cached = not handler_registry.has_handler( + "create_session" + ) and not handler_registry.has_handler("close_session") + + # Extract session_id_target_key before compiling JMESPath expressions + self._session_id_target_key = self._get_session_id_target_key(request_shape) super().__init__(request_shape, response_shape) + def _use_default_manager(self) -> bool: + """Check if default session manager should be used. 
+ + Hybrid approach for performance: + - If custom handlers existed at init time (cached=False), return False immediately + - If no custom handlers at init (cached=True), check dynamically in case they were registered later + + This optimizes the common case (custom handlers registered before transform creation) + while still supporting late registration for flexibility. + + Returns: + bool: True if default manager should be used, False if custom handlers exist + """ + # Fast path: if custom handlers existed at init, they still exist + if not self._use_default_manager_cached: + return False + + # Slow path: no custom handlers at init, check if any were registered since + return not handler_registry.has_handler( + "create_session" + ) and not handler_registry.has_handler("close_session") + + def _get_session_id_target_key(self, request_shape: dict) -> Optional[str]: + if not request_shape: + return None + for target_key, source_path in request_shape.items(): + if source_path == f'headers."{SageMakerSessionHeader.SESSION_ID}"': + return target_key + return None + async def transform_request(self, raw_request): """Transform incoming request, intercepting session management operations. @@ -155,7 +121,7 @@ async def transform_request(self, raw_request): """ try: request_data = await raw_request.json() - return process_session_request( + return self._process_request( request_data, raw_request, self._session_manager ) except json.JSONDecodeError as e: @@ -175,3 +141,80 @@ def transform_response(self, response, transform_request_output): The unmodified response object """ return response + + def _validate_session_id(self, session_id: Optional[str], raw_request: Request): + """Validate that the session ID in the request exists and is not expired. 
+ + Raises: + HTTPException: If session validation fails + """ + try: + get_session(self._session_manager, raw_request) + return session_id + except ValueError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"Bad request: {str(e)}", + ) + + def _process_invocations_request( + self, session_id: Optional[str], request_data: dict, raw_request: Request + ): + # If not a session request + if session_id and self._use_default_manager(): + # but it has a session id header and we are using the default session manager, + # then we need to validate that the session id exists in the session manager + self._validate_session_id(session_id, raw_request) + + # Inject session ID into request body if target key is specified + if session_id and self._session_id_target_key: + request_data = set_value( + obj=request_data, + path=self._session_id_target_key, + value=session_id, + create_parent=True, + ) + logger.debug(f"Updated request body: {request_data}") + raw_request._body = json.dumps(request_data).encode("utf-8") + + return BaseTransformRequestOutput( + raw_request=raw_request, + intercept_func=None, + ) + + def _process_session_request(self, session_request, session_id, raw_request): + # Validation + if self._use_default_manager() and not self._session_manager: + # if no custom handlers are registered, but default session manager + # does not exist -> then raise error that session management is disabled + logger.error(SESSION_DISABLED_LOG_MESSAGE) + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=SESSION_DISABLED_ERROR_DETAIL, + ) + elif self._use_default_manager() and self._session_manager: + if session_request.requestType == SessionRequestType.NEW_SESSION: + # Ignores any session id header in create session request + session_id = SessionRequestType.NEW_SESSION + session_id = self._validate_session_id(session_id, raw_request) + + # Route to appropriate session management handler + intercept_func = 
get_handler_for_request_type(session_request.requestType) + + return BaseTransformRequestOutput( + raw_request=raw_request, intercept_func=intercept_func + ) + + def _process_request( + self, request_data, raw_request, session_manager: Optional[SessionManager] + ): + session_request = _parse_session_request(request_data) + session_id = get_session_id_from_request(raw_request) + if not session_request: + return self._process_invocations_request( + session_id, request_data, raw_request + ) + else: + return self._process_session_request( + session_request, session_id, raw_request + ) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/__init__.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/__init__.py new file mode 100644 index 0000000..c0b27d5 --- /dev/null +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/__init__.py @@ -0,0 +1,15 @@ +from .close_session import CloseSessionApiTransform +from .create_session import CreateSessionApiTransform + + +def resolve_engine_session_transform(handler_type: str): + """Resolve the appropriate transform class for engine session handlers. 
+ + :param str handler_type: Type of session handler ('create_session' or 'close_session') + :return: Transform class or None if handler type is not recognized + """ + if handler_type == "create_session": + return CreateSessionApiTransform + elif handler_type == "close_session": + return CloseSessionApiTransform + return None diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/base_engine_session_api_transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/base_engine_session_api_transform.py new file mode 100644 index 0000000..62ab3ae --- /dev/null +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/base_engine_session_api_transform.py @@ -0,0 +1,139 @@ +import abc +import json +from http import HTTPStatus + +from fastapi import Request, Response +from fastapi.exceptions import HTTPException +from pydantic import BaseModel + +from ....common import BaseApiTransform, BaseTransformRequestOutput + + +class BaseEngineSessionApiTransform(BaseApiTransform): + """Base abstract class for engine-specific session API transformations. + + This class provides the foundation for transforming HTTP requests and responses + for engines that implement their own session management APIs. It handles common + response normalization and routing logic, while subclasses implement specific + transformation behavior for create/close session operations. + """ + + async def transform_request( + self, raw_request: Request + ) -> BaseTransformRequestOutput: + """Transform an incoming HTTP request for engine session operations. + + Parses JSON request body, applies JMESPath transformations, and validates + any session-specific requirements. Subclasses can override to add custom + validation logic before or after the base transformation. 
+ + :param Request raw_request: The incoming FastAPI request object + :return BaseTransformRequestOutput: Transformed request data and metadata + :raises HTTPException: If JSON parsing fails or validation errors occur + """ + # Subclasses can override _validate_request_preconditions for early validation + self._validate_request_preconditions(raw_request) + + try: + request_data = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e + + transformed_request = self._transform_request(request_data, raw_request) + raw_request._body = json.dumps(transformed_request).encode("utf-8") + + return BaseTransformRequestOutput( + request=transformed_request, + raw_request=raw_request, + intercept_func=None, + ) + + def _validate_request_preconditions(self, raw_request: Request) -> None: + """Validate request preconditions before transformation. + + Subclasses can override this method to perform early validation + (e.g., checking for required headers). Default implementation does nothing. + + :param Request raw_request: The incoming request to validate + :raises HTTPException: If validation fails + """ + pass + + def transform_response( + self, response: Response, transform_request_output: BaseTransformRequestOutput + ) -> Response: + """Transform the response based on the request processing results. + + Normalizes various response types to FastAPI Response objects and routes + to appropriate transformation method based on HTTP status code. 
+ + :param Response response: The response to transform (may be Response, BaseModel, dict, or str) + :param BaseTransformRequestOutput transform_request_output: Output from the request transformation + :return Response: Transformed response + """ + # Normalize response to Response object + response = self._normalize_response(response) + + # Route based on status code + if response.status_code == HTTPStatus.OK.value: + return self._transform_ok_response( + response, transform_request_output=transform_request_output + ) + else: + return self._transform_error_response(response) + + def _normalize_response(self, response): + """Convert various response types to FastAPI Response object. + + Handles responses that may be BaseModel instances, dictionaries, strings, + or already Response objects. If the response doesn't have a status_code, + it's assumed to be a successful response (200 OK) from the engine handler. + + Note: This method only normalizes the response format. Validation of required + fields (like session IDs) should be done in _transform_ok_response() to provide + appropriate error responses if the engine returns invalid data. + + :param response: Response in various formats + :return Response: Normalized FastAPI Response object + """ + if not hasattr(response, "status_code"): + # Handle the case where the response is not a Response object + # Assume success if the handler returned data without explicit status + if isinstance(response, BaseModel): + response = response.model_dump_json() + elif not isinstance(response, str): + response = json.dumps(response) + response = Response( + status_code=HTTPStatus.OK.value, + content=response, + ) + return response + + @abc.abstractmethod + def _transform_ok_response(self, response: Response, **kwargs) -> Response: + """Transform successful (200 OK) responses. + + Subclasses must implement this method to handle session-specific response + formatting and header management. 
+ + :param Response response: The successful response to transform + :param BaseTransformRequestOutput transform_request_output: Output from the request transformation + :return Response: Transformed response + :raises NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError() + + def _transform_error_response(self, response: Response, **kwargs) -> Response: + """Transform error responses. + + Default implementation passes through error responses unchanged. + Subclasses can override to add custom error handling. + + :param Response response: The error response to transform + :param BaseTransformRequestOutput transform_request_output: Output from the request transformation + :return Response: Transformed response (default: unchanged) + """ + return response diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/close_session.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/close_session.py new file mode 100644 index 0000000..1878892 --- /dev/null +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/close_session.py @@ -0,0 +1,70 @@ +from http import HTTPStatus +from typing import Any, Dict + +from fastapi import Request, Response +from fastapi.exceptions import HTTPException + +from ....common import BaseTransformRequestOutput +from ....logging_config import logger +from ..models import SageMakerSessionHeader +from ..utils import get_session_id_from_request +from .base_engine_session_api_transform import BaseEngineSessionApiTransform +from .constants import RESPONSE_CONTENT_KEY + + +class CloseSessionApiTransform(BaseEngineSessionApiTransform): + def __init__( + self, request_shape: Dict[str, Any], response_shape: Dict[str, Any] = {} + ): + try: + assert RESPONSE_CONTENT_KEY in response_shape.keys() + except AssertionError as e: + raise ValueError( + f"Response shape must contain {RESPONSE_CONTENT_KEY} key" + ) from e + + super().__init__(request_shape, 
response_shape) + + def _validate_request_preconditions(self, raw_request: Request) -> None: + """Validate that session ID exists in request headers before processing. + + :param Request raw_request: The incoming request to validate + :raises HTTPException: If session ID is missing from headers + """ + session_id = get_session_id_from_request(raw_request) + if not session_id: + logger.error("No session ID found in request headers for close session") + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Session ID is required in request headers to close a session", + ) + + def _transform_ok_response(self, response: Response, **kwargs) -> Response: + """Transform successful close session response. + + Extracts session ID from request headers and content from engine response, + validates them, and returns formatted response with CLOSED_SESSION_ID header. + + :param Response response: The successful response to transform + :param BaseTransformRequestOutput transform_request_output: Output from the request transformation + :return Response: Transformed response with session headers + """ + transform_request_output: BaseTransformRequestOutput = kwargs.get("transform_request_output") # type: ignore + # Session ID already validated in transform_request, safe to extract + session_id = get_session_id_from_request(transform_request_output.raw_request) + + transformed_response_data = self._transform_response(response) + content = transformed_response_data.get(RESPONSE_CONTENT_KEY) + + # Validate that content was extracted from the response + if not content: + logger.debug( + f"No content extracted from close session response for session {session_id}" + ) + + logger.info(f"Session {session_id}: {content}") + return Response( + status_code=HTTPStatus.OK.value, + content=f"Session {session_id}: {content}", + headers={SageMakerSessionHeader.CLOSED_SESSION_ID: session_id}, + ) diff --git 
a/python/model_hosting_container_standards/sagemaker/sessions/transforms/constants.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/constants.py new file mode 100644 index 0000000..3c696bd --- /dev/null +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/constants.py @@ -0,0 +1,4 @@ +"""Constants for engine session transforms.""" + +# Key used in response_shape to specify where to extract content from engine response +RESPONSE_CONTENT_KEY = "content" diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/create_session.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/create_session.py new file mode 100644 index 0000000..f92762a --- /dev/null +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/create_session.py @@ -0,0 +1,65 @@ +from http import HTTPStatus +from typing import Any, Dict + +from fastapi import Response +from fastapi.exceptions import HTTPException + +from ....logging_config import logger +from ..models import SageMakerSessionHeader +from .base_engine_session_api_transform import BaseEngineSessionApiTransform +from .constants import RESPONSE_CONTENT_KEY + + +class CreateSessionApiTransform(BaseEngineSessionApiTransform): + def __init__( + self, request_shape: Dict[str, Any], response_shape: Dict[str, Any] = {} + ): + try: + assert SageMakerSessionHeader.NEW_SESSION_ID in response_shape.keys() + assert RESPONSE_CONTENT_KEY in response_shape.keys() + except AssertionError as e: + raise ValueError( + f"Response shape must contain {SageMakerSessionHeader.NEW_SESSION_ID} and {RESPONSE_CONTENT_KEY} keys" + ) from e + + super().__init__(request_shape, response_shape) + + def _transform_ok_response(self, response: Response, **kwargs) -> Response: + """Transform successful create session response. + + Extracts session ID and content from engine response, validates them, + and returns formatted response with NEW_SESSION_ID header. 
+ + :param Response response: The successful response to transform + :return Response: Transformed response with session headers + :raises HTTPException: If session ID cannot be extracted from response + """ + transformed_response_data = self._transform_response(response) + content = transformed_response_data.get(RESPONSE_CONTENT_KEY) + session_id = transformed_response_data.get( + SageMakerSessionHeader.NEW_SESSION_ID + ) + + # Validate that session_id was extracted from the response + if not session_id: + logger.error( + f"Failed to extract session ID from engine response. " + f"Response data: {transformed_response_data}" + ) + raise HTTPException( + status_code=HTTPStatus.BAD_GATEWAY.value, + detail="Engine failed to return a valid session ID in the response", + ) + + # Validate that content was extracted from the response + if not content: + logger.debug( + f"No content extracted from create session response for session {session_id}" + ) + + logger.info(f"Session {session_id}: {content}") + return Response( + status_code=HTTPStatus.OK.value, + content=f"Session {session_id}: {content}", + headers={SageMakerSessionHeader.NEW_SESSION_ID: session_id}, + ) diff --git a/python/tests/integration/test_custom_session_handlers_integration.py b/python/tests/integration/test_custom_session_handlers_integration.py new file mode 100644 index 0000000..a0c0914 --- /dev/null +++ b/python/tests/integration/test_custom_session_handlers_integration.py @@ -0,0 +1,901 @@ +"""Integration tests for custom session handlers functionality. 
+ +Tests the integration of custom session handlers with engine-specific session APIs using +the proper decorator-based registration pattern: +- @register_create_session_handler decorator +- @register_close_session_handler decorator +- Mixed scenarios (custom + default handlers) +- Transform request/response shape mapping via decorators +- Error handling in custom handlers +- Handler registration and resolution + +Key Testing Pattern: + These tests simulate real-world scenarios where an inference engine + has its own session management API. We use the + proper decorators to register handlers and verify that: + 1. Decorators properly register and invoke custom handlers + 2. Transforms correctly map between SageMaker and engine formats + 3. Session lifecycle works end-to-end with custom handlers + 4. Error cases are handled gracefully +""" + +import json +import os +import shutil +import tempfile +import uuid +from typing import Optional + +import pytest +from fastapi import APIRouter, FastAPI, Request +from fastapi.exceptions import HTTPException +from fastapi.responses import Response +from fastapi.testclient import TestClient +from pydantic import BaseModel + +import model_hosting_container_standards.sagemaker as sagemaker_standards +from model_hosting_container_standards.common.handler.registry import handler_registry +from model_hosting_container_standards.sagemaker.sessions.manager import ( + init_session_manager_from_env, +) +from model_hosting_container_standards.sagemaker.sessions.models import ( + SageMakerSessionHeader, +) + +DEFAULT_SESSION_ID = "default-session" + + +class CreateSessionRequest(BaseModel): + capacity_of_str_len: int + session_id: Optional[str] = None + + +class CloseSessionRequest(BaseModel): + session_id: str + + +@pytest.fixture(autouse=True) +def enable_sessions_for_integration(monkeypatch): + """Automatically enable sessions for all integration tests in this module.""" + temp_dir = tempfile.mkdtemp() + + 
monkeypatch.setenv("SAGEMAKER_ENABLE_STATEFUL_SESSIONS", "true") + monkeypatch.setenv("SAGEMAKER_SESSIONS_PATH", temp_dir) + monkeypatch.setenv("SAGEMAKER_SESSIONS_EXPIRATION", "600") + + init_session_manager_from_env() + + yield + + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + monkeypatch.delenv("SAGEMAKER_ENABLE_STATEFUL_SESSIONS", raising=False) + monkeypatch.delenv("SAGEMAKER_SESSIONS_PATH", raising=False) + monkeypatch.delenv("SAGEMAKER_SESSIONS_EXPIRATION", raising=False) + init_session_manager_from_env() + + +@pytest.fixture(autouse=True) +def cleanup_handler_registry(): + """Clean up handler registry after each test.""" + yield + handler_registry.remove_handler("create_session") + handler_registry.remove_handler("close_session") + + +def extract_session_id_from_header(header_value: str) -> str: + """Extract session ID from SageMaker session header.""" + if ";" in header_value: + return header_value.split(";")[0].strip() + return header_value.strip() + + +class BaseCustomHandlerIntegrationTest: + """Base class for custom handler integration tests with common setup. + + Provides: + - FastAPI app and router setup + - Mock engine client for simulating engine APIs + - Handler call tracking + - TestClient for making requests + - Common setup/teardown patterns + + Subclasses should override setup_handlers() to register their specific + custom handlers using the appropriate decorators. + """ + + def setup_method(self): + """Common setup for all custom handler integration tests.""" + self.app = FastAPI() + self.router = APIRouter() + + # Track handler invocations for verification + self.handler_calls = {"create": 0, "close": 0} + + # Setup handlers (to be overridden by subclasses) + self.setup_handlers() + + # Bootstrap the app with SageMaker standards + self.app.include_router(self.router) + sagemaker_standards.bootstrap(self.app) + self.client = TestClient(self.app) + + def setup_handlers(self): + """Override in subclasses to register custom handlers. 
+ + This method should: + 1. Define custom handler functions + 2. Register them using @register_create_session_handler or @register_close_session_handler + 3. Set up the /invocations endpoint with @stateful_session_manager + """ + self.setup_common_handlers() + self.setup_invocation_handler() + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + # Implement in child classes + pass + + def custom_close_session(self, obj: CloseSessionRequest, request: Request): + # Implement in child classes + pass + + def setup_common_handlers(self): + @sagemaker_standards.register_create_session_handler( + engine_request_session_id_path="session_id", + engine_response_session_id_path="body", + additional_request_shape={ + "capacity_of_str_len": "`1024`", + }, + content_path="`successfully created session.`", + ) + @self.app.api_route("/open_session", methods=["GET", "POST"]) + async def create_session(obj: CreateSessionRequest, request: Request): + return self.custom_create_session(obj, request) + + @sagemaker_standards.register_close_session_handler( + engine_request_session_id_path="session_id", + content_path="`successfully closed session.`", + ) + @self.app.api_route("/close_session", methods=["GET", "POST"]) + async def close_session(obj: CloseSessionRequest, request: Request): + return self.custom_close_session(obj, request) + + async def custom_invocations(self, request: Request): + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + # Extract session ID from request headers if present + session_id = body.get("session_id") or request.headers.get( + SageMakerSessionHeader.SESSION_ID + ) + return Response( + status_code=200, + content=json.dumps( + { + "message": "Request in session", + "session_id": session_id or "no-session", + "echo": body, + } + ), + ) + + def setup_invocation_handler(self): + @self.router.post("/invocations") + @sagemaker_standards.stateful_session_manager() + async def invocations(request: Request): 
+ return await self.custom_invocations(request) + + # Helper methods for common test operations + def create_session(self) -> str: + """Helper to create a session and return the session ID.""" + response = self.client.post("/invocations", json={"requestType": "NEW_SESSION"}) + assert response.status_code == 200 + assert SageMakerSessionHeader.NEW_SESSION_ID in response.headers + return extract_session_id_from_header( + response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + def create_session_with_id(self, session_id: str) -> Response: + """Helper to create a session with a specific ID.""" + return self.client.post( + "/invocations", + json={"requestType": "NEW_SESSION"}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + + def close_session(self, session_id: str) -> Response: + """Helper to close a session.""" + return self.client.post( + "/invocations", + json={"requestType": "CLOSE"}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + + def invoke_with_session(self, session_id: str, body: dict) -> Response: + """Helper to make an invocation request with a session.""" + return self.client.post( + "/invocations", + json=body, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + + +class TestSimpleCreateSessionCustomHandler(BaseCustomHandlerIntegrationTest): + """Test basic custom create session handler with simple string return.""" + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + return DEFAULT_SESSION_ID + + def test_create_new_session(self): + """Test that custom handler returning a simple string works correctly. + + This validates the simplest case where a custom handler returns just a string + (the session ID) rather than a complex object. This is useful when the engine's + session API returns a simple session identifier. 
+ """ + # Send NEW_SESSION request to trigger custom create handler + response = self.client.post("/invocations", json={"requestType": "NEW_SESSION"}) + + # Verify successful session creation + assert response.status_code == 200 + assert SageMakerSessionHeader.NEW_SESSION_ID in response.headers + + # Extract session ID from response header + session_id = extract_session_id_from_header( + response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Verify the custom handler's return value (DEFAULT_SESSION_ID) is used as session ID + # This confirms the transform correctly extracted the session ID from the string response + assert session_id == DEFAULT_SESSION_ID + + +class TestErrorCreateSessionCustomHandler(BaseCustomHandlerIntegrationTest): + """Test error handling when custom create session handler fails.""" + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + raise HTTPException(status_code=400, detail="Engine failed to create session") + + def test_create_new_session_error(self): + """Test that errors from custom create handler are properly propagated. + + When the underlying engine fails to create a session (e.g., resource exhaustion, + invalid parameters), the error should be propagated to the client with appropriate + status code and error message. 
+ """ + # Attempt to create session - custom handler will raise HTTPException + response = self.client.post("/invocations", json={"requestType": "NEW_SESSION"}) + + # Verify error status code is returned + assert response.status_code == 400 + + # Verify error message from custom handler is included in response + assert "Engine failed to create session" in response.text + + # Verify no session header is present on error (session was not created) + assert SageMakerSessionHeader.NEW_SESSION_ID not in response.headers + + +class TestErrorCloseSessionCustomHandler(BaseCustomHandlerIntegrationTest): + """Test error handling when custom close session handler fails.""" + + def setup_method(self): + self.sessions = {} + super().setup_method() + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + self.sessions[session_id] = session_id + return session_id + + def custom_close_session(self, obj: CloseSessionRequest, request: Request): + if obj.session_id in self.sessions: + self.sessions.pop(obj.session_id) + return Response( + status_code=200, content=f"Session {obj.session_id} closed." + ) + raise HTTPException( + status_code=404, detail=f"Session {obj.session_id} does not exist." + ) + + def test_duplicate_close_session(self): + """Test that closing an already-closed session returns 404. + + This validates idempotency handling - attempting to close a session that's + already been closed should return a 404 error rather than succeeding silently. + This is important for detecting client-side bugs or race conditions. 
+ """ + # Create a new session for testing + session_id = self.create_session() + + # First close should succeed - session exists in custom handler's storage + success_response = self.close_session(session_id) + assert success_response.status_code == 200 + assert SageMakerSessionHeader.CLOSED_SESSION_ID in success_response.headers + + # Second close should fail - session no longer exists (was removed on first close) + # Custom handler raises HTTPException(404) when session not found + duplicate_response = self.close_session(session_id) + assert duplicate_response.status_code == 404 + + +class TestCustomSessionEndToEndFlow(BaseCustomHandlerIntegrationTest): + """Test complete end-to-end flows with custom session handlers.""" + + def setup_method(self): + self.sessions = {} + super().setup_method() + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + self.handler_calls["create"] += 1 + if not obj.session_id: + obj.session_id = str(uuid.uuid4()) + if obj.session_id in self.sessions: + return Response(status_code=400) + self.sessions[obj.session_id] = obj.session_id + return {"session_id": obj.session_id} + + def custom_close_session(self, obj: CloseSessionRequest, request: Request): + self.handler_calls["close"] += 1 + if obj.session_id not in self.sessions: + raise HTTPException( + status_code=404, detail=f"Session {obj.session_id} does not exist." 
+ ) + self.sessions.pop(obj.session_id) + return Response(status_code=200, content=f"Session {obj.session_id} closed.") + + def setup_common_handlers(self): + @sagemaker_standards.register_create_session_handler( + engine_request_session_id_path="session_id", + engine_response_session_id_path="body.session_id", # Nested + additional_request_shape={ + "capacity_of_str_len": "`1024`", + }, + content_path="`successfully created session.`", + ) + @self.app.api_route("/open_session", methods=["GET", "POST"]) + async def create_session(obj: CreateSessionRequest, request: Request): + return self.custom_create_session(obj, request) + + @sagemaker_standards.register_close_session_handler( + engine_request_session_id_path="session_id", + content_path="`successfully closed session.`", + ) + @self.app.api_route("/close_session", methods=["GET", "POST"]) + async def close_session(obj: CloseSessionRequest, request: Request): + return self.custom_close_session(obj, request) + + def setup_invocation_handler(self): + @self.router.post("/invocations") + @sagemaker_standards.stateful_session_manager( + engine_request_session_id_path="session_id" + ) + async def invocations(request: Request): + return await self.custom_invocations(request) + + def test_create_existing_session_error_handling(self): + """Test that attempting to create a session with existing ID fails. + + This validates that the custom handler properly rejects attempts to create + a session with a duplicate ID. This prevents session ID collisions and ensures + session uniqueness. + """ + # Create initial session + session_id = self.create_session() + + # Try to create another session with the same ID by passing it in the header + # Custom handler checks if session_id already exists and returns 400 if it does + header_response = self.create_session_with_id(session_id) + assert header_response.status_code == 400 + + def test_end_to_end_simple(self): + """Test complete session lifecycle: create -> use -> close. 
+ + This is the primary happy path test that validates the full session workflow + works correctly with custom handlers. This simulates a typical client interaction + pattern for stateful ML inference (e.g., multi-turn conversation with an LLM). + """ + # Step 1: Create session via custom handler + session_id = self.create_session() + + # Step 2: Use session for inference request + # Session ID is passed in header and should be available to the handler + invoke_response = self.invoke_with_session(session_id, {"prompt": "hello"}) + assert invoke_response.status_code == 200 + # Verify session ID is echoed back, confirming session context was maintained + assert session_id in invoke_response.text + + # Step 3: Close session via custom handler + close_response = self.close_session(session_id) + assert close_response.status_code == 200 + # Verify closed session header is returned + assert SageMakerSessionHeader.CLOSED_SESSION_ID in close_response.headers + + def test_handler_call_tracking(self): + """Test that custom handlers are actually being invoked. + + This validates that the decorator registration system correctly routes session + requests to the custom handlers rather than using default handlers. The counters + prove the custom handler code is executing. 
+ """ + # Reset counters to ensure clean state + self.handler_calls = {"create": 0, "close": 0} + + # Create session - should increment create counter + session_id = self.create_session() + assert self.handler_calls["create"] == 1 # Custom create handler was called + assert self.handler_calls["close"] == 0 # Close handler not called yet + + # Close session - should increment close counter + close_response = self.close_session(session_id) + assert close_response.status_code == 200 + assert self.handler_calls["create"] == 1 # Create counter unchanged + assert self.handler_calls["close"] == 1 # Custom close handler was called + + def test_multiple_sessions_independent_state(self): + """Test that multiple sessions maintain independent state in custom handlers. + + This validates session isolation - multiple concurrent sessions should not + interfere with each other. This is critical for multi-tenant scenarios where + different users/clients have active sessions simultaneously. + """ + # Create two independent sessions + session1_id = self.create_session() + session2_id = self.create_session() + + # Verify both sessions exist in custom handler's storage + assert session1_id in self.sessions + assert session2_id in self.sessions + # Verify sessions have unique IDs + assert session1_id != session2_id + + # Close first session only + self.close_session(session1_id) + + # Verify only first session was removed from storage + assert session1_id not in self.sessions + # Verify second session still exists and is unaffected + assert session2_id in self.sessions + + # Verify second session is still functional after first session closed + response = self.invoke_with_session(session2_id, {"prompt": "test"}) + assert response.status_code == 200 + + +class TestCustomHandlerResponseFormats(BaseCustomHandlerIntegrationTest): + """Test that custom handlers can return different response formats.""" + + def setup_method(self): + self.sessions = {} + self.response_format = "dict" # Can be 
"dict", "string", or "response_object" + super().setup_method() + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + self.sessions[session_id] = True + + if self.response_format == "dict": + return {"session_id": session_id, "metadata": {"engine": "custom"}} + elif self.response_format == "string": + return session_id + elif self.response_format == "response_object": + return Response( + status_code=201, + content=json.dumps({"session_id": session_id}), + media_type="application/json", + ) + + def custom_close_session(self, obj: CloseSessionRequest, request: Request): + if obj.session_id in self.sessions: + del self.sessions[obj.session_id] + return Response(status_code=200, content="Closed") + + def setup_common_handlers(self): + # Use different response_session_id_path based on format + response_path = "body.session_id" if self.response_format == "dict" else "body" + + @sagemaker_standards.register_create_session_handler( + engine_request_session_id_path="session_id", + engine_response_session_id_path=response_path, + additional_request_shape={"capacity_of_str_len": "`1024`"}, + content_path="`successfully created session.`", + ) + @self.app.api_route("/open_session", methods=["GET", "POST"]) + async def create_session(obj: CreateSessionRequest, request: Request): + return self.custom_create_session(obj, request) + + @sagemaker_standards.register_close_session_handler( + engine_request_session_id_path="session_id", + content_path="`successfully closed session.`", + ) + @self.app.api_route("/close_session", methods=["GET", "POST"]) + async def close_session(obj: CloseSessionRequest, request: Request): + return self.custom_close_session(obj, request) + + def test_dict_response_with_metadata(self): + """Test custom handler returning dict with additional metadata. + + Many engine APIs return rich response objects with metadata alongside the + session ID. 
This validates that the transform can extract the session ID + from a nested path while preserving other response data. + """ + self.response_format = "dict" + # Create session - handler returns {"session_id": "...", "metadata": {...}} + session_id = self.create_session() + + # Verify session was created successfully + assert session_id in self.sessions + # Verify session ID is in UUID format (36 characters with hyphens) + assert len(session_id) == 36 + + def test_dict_response_with_nested_session_id(self): + """Test custom handler returning dict with nested session ID path. + + This validates that the engine_response_session_id_path configuration correctly + extracts the session ID from nested response structures (e.g., body.session_id). + """ + self.response_format = "dict" + # Create session with nested response structure + session_id = self.create_session() + + # Verify session was created and can be used for subsequent requests + response = self.invoke_with_session(session_id, {"test": "data"}) + assert response.status_code == 200 + + +class TestCustomHandlerMultipleInvocations(BaseCustomHandlerIntegrationTest): + """Test multiple invocations within the same session with custom handlers.""" + + def setup_method(self): + self.sessions = {} + self.invocation_counts = {} + super().setup_method() + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + self.sessions[session_id] = {"created": True} + self.invocation_counts[session_id] = 0 + return {"session_id": session_id} + + def custom_close_session(self, obj: CloseSessionRequest, request: Request): + if obj.session_id in self.sessions: + del self.sessions[obj.session_id] + if obj.session_id in self.invocation_counts: + del self.invocation_counts[obj.session_id] + return Response(status_code=200) + + def setup_common_handlers(self): + @sagemaker_standards.register_create_session_handler( + engine_request_session_id_path="session_id", +
engine_response_session_id_path="body.session_id", + additional_request_shape={"capacity_of_str_len": "`1024`"}, + content_path="`successfully created session.`", + ) + @self.app.api_route("/open_session", methods=["GET", "POST"]) + async def create_session(obj: CreateSessionRequest, request: Request): + return self.custom_create_session(obj, request) + + @sagemaker_standards.register_close_session_handler( + engine_request_session_id_path="session_id", + content_path="`successfully closed session.`", + ) + @self.app.api_route("/close_session", methods=["GET", "POST"]) + async def close_session(obj: CloseSessionRequest, request: Request): + return self.custom_close_session(obj, request) + + async def custom_invocations(self, request: Request): + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + session_id = request.headers.get(SageMakerSessionHeader.SESSION_ID) + + # Track invocation count per session + if session_id and session_id in self.invocation_counts: + self.invocation_counts[session_id] += 1 + + return Response( + status_code=200, + content=json.dumps( + { + "message": "success", + "session_id": session_id, + "invocation_count": self.invocation_counts.get(session_id, 0), + "echo": body, + } + ), + ) + + def test_multiple_invocations_same_session(self): + """Test that multiple invocations work correctly within the same session. + + This validates that session state (invocation count) accumulates correctly + across multiple requests. This is essential for stateful ML scenarios like + maintaining conversation context or tracking request history. 
+ """ + session_id = self.create_session() + + # Make 5 sequential invocations to the same session + for i in range(5): + response = self.invoke_with_session(session_id, {"request_num": i + 1}) + assert response.status_code == 200 + data = json.loads(response.text) + # Verify invocation count increments with each request + assert data["invocation_count"] == i + 1 + # Verify session ID remains consistent + assert data["session_id"] == session_id + + def test_invocation_counts_independent_across_sessions(self): + """Test that invocation counts are independent across different sessions. + + This validates session isolation at the invocation level - each session + maintains its own independent counter. Critical for ensuring one user's + session activity doesn't affect another user's session. + """ + # Create two separate sessions + session1_id = self.create_session() + session2_id = self.create_session() + + # Make 3 invocations to session 1 + for i in range(3): + self.invoke_with_session(session1_id, {"msg": "session1"}) + + # Make 5 invocations to session 2 + for i in range(5): + self.invoke_with_session(session2_id, {"msg": "session2"}) + + # Verify each session has its own independent count + assert self.invocation_counts[session1_id] == 3 + assert self.invocation_counts[session2_id] == 5 + + +class TestCustomHandlerWithSessionIdInjection(BaseCustomHandlerIntegrationTest): + """Test custom handlers with request_session_id_path parameter.""" + + def setup_method(self): + self.sessions = {} + super().setup_method() + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + self.sessions[session_id] = {"created": True} + return {"session_id": session_id} + + def custom_close_session(self, obj: CloseSessionRequest, request: Request): + if obj.session_id in self.sessions: + del self.sessions[obj.session_id] + return Response(status_code=200, content="Session closed") + raise HTTPException(status_code=404, 
detail="Session not found") + + def setup_common_handlers(self): + @sagemaker_standards.register_create_session_handler( + engine_request_session_id_path="session_id", + engine_response_session_id_path="body.session_id", + additional_request_shape={"capacity_of_str_len": "`1024`"}, + content_path="`successfully created session.`", + ) + @self.app.api_route("/open_session", methods=["GET", "POST"]) + async def create_session(obj: CreateSessionRequest, request: Request): + return self.custom_create_session(obj, request) + + @sagemaker_standards.register_close_session_handler( + engine_request_session_id_path="session_id", + content_path="`successfully closed session.`", + ) + @self.app.api_route("/close_session", methods=["GET", "POST"]) + async def close_session(obj: CloseSessionRequest, request: Request): + return self.custom_close_session(obj, request) + + def setup_invocation_handler(self): + @self.router.post("/invocations") + @sagemaker_standards.stateful_session_manager( + engine_request_session_id_path="metadata.session_id" + ) + async def invocations(request: Request): + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + + # Extract session ID from nested path + session_id = body.get("metadata", {}).get("session_id") + + return Response( + status_code=200, + content=json.dumps( + {"message": "success", "session_id": session_id, "body": body} + ), + ) + + def test_session_id_injected_into_nested_path(self): + """Test that session ID is injected into nested path in request body. + + Some ML engines expect the session ID to be in the request body rather than + just in headers. The engine_request_session_id_path parameter allows automatic + injection of the session ID into a specified path in the request body + (e.g., metadata.session_id). This test validates that the session ID is + correctly injected when the metadata dict already exists.
+ """ + # Create session + session_id = self.create_session() + + # Make request with session - note we don't include session_id in the body + # The framework should inject it automatically at metadata.session_id + response = self.invoke_with_session( + session_id, {"prompt": "test", "metadata": {"user": "test_user"}} + ) + + assert response.status_code == 200 + data = json.loads(response.text) + + # Verify session ID was automatically injected into the nested path + assert data["session_id"] == session_id + assert data["body"]["metadata"]["session_id"] == session_id + # Verify original metadata fields are preserved + assert data["body"]["metadata"]["user"] == "test_user" + + def test_session_id_injected_creates_missing_metadata_dict(self): + """Test that session ID injection creates missing parent structures. + + When the request path expects metadata.session_id but the request doesn't + include a "metadata" dict, the set_value function should create the missing + parent structure and inject the session ID. This tests the create_parent=True + functionality in set_value. 
+ """ + # Create session + session_id = self.create_session() + + # Make request with session - note we don't include "metadata" dict at all + # The framework should create the missing "metadata" dict and inject session_id + response = self.invoke_with_session( + session_id, {"prompt": "test", "user": "test_user"} + ) + + assert response.status_code == 200 + data = json.loads(response.text) + + # Verify session ID was automatically injected and metadata dict was created + assert data["session_id"] == session_id + assert data["body"]["metadata"]["session_id"] == session_id + # Verify original fields at root level are preserved + assert data["body"]["prompt"] == "test" + assert data["body"]["user"] == "test_user" + + +class TestCustomHandlerSessionPersistence(BaseCustomHandlerIntegrationTest): + """Test that session state persists correctly across invocations with custom handlers.""" + + def setup_method(self): + self.sessions = {} + super().setup_method() + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + # Store session with initial state for ML inference + self.sessions[session_id] = { + "conversation_history": [], + "inference_params": {}, + "created_at": "2024-01-01", + } + return {"session_id": session_id} + + def custom_close_session(self, obj: CloseSessionRequest, request: Request): + if obj.session_id in self.sessions: + del self.sessions[obj.session_id] + return Response(status_code=200) + + def setup_common_handlers(self): + @sagemaker_standards.register_create_session_handler( + engine_request_session_id_path="session_id", + engine_response_session_id_path="body.session_id", + additional_request_shape={"capacity_of_str_len": "`1024`"}, + content_path="`successfully created session.`", + ) + @self.app.api_route("/open_session", methods=["GET", "POST"]) + async def create_session(obj: CreateSessionRequest, request: Request): + return self.custom_create_session(obj, request) + + 
@sagemaker_standards.register_close_session_handler( + engine_request_session_id_path="session_id", + content_path="`successfully closed session.`", + ) + @self.app.api_route("/close_session", methods=["GET", "POST"]) + async def close_session(obj: CloseSessionRequest, request: Request): + return self.custom_close_session(obj, request) + + async def custom_invocations(self, request: Request): + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + session_id = request.headers.get(SageMakerSessionHeader.SESSION_ID) + + # Simulate updating session state for ML inference + if session_id and session_id in self.sessions: + if "message" in body: + self.sessions[session_id]["conversation_history"].append( + body["message"] + ) + if "inference_params" in body: + self.sessions[session_id]["inference_params"].update( + body["inference_params"] + ) + + session_data = self.sessions.get(session_id, {}) + + return Response( + status_code=200, + content=json.dumps( + { + "session_id": session_id, + "conversation_history": session_data.get( + "conversation_history", [] + ), + "inference_params": session_data.get("inference_params", {}), + } + ), + ) + + def test_conversation_history_persists(self): + """Test that conversation history accumulates across invocations. + + This simulates a multi-turn conversation with an LLM where each message + is added to the session's conversation history. This is a common pattern + for chatbots and conversational AI where context from previous turns + needs to be maintained. 
+ """ + session_id = self.create_session() + + # Send multiple messages in sequence (simulating a conversation) + messages = ["Hello", "How are you?", "Tell me a joke"] + for msg in messages: + # Each message is added to the session's conversation history + response = self.invoke_with_session(session_id, {"message": msg}) + assert response.status_code == 200 + + # Make a final request to retrieve the accumulated history + final_response = self.invoke_with_session(session_id, {}) + data = json.loads(final_response.text) + # Verify all messages were stored in order + assert data["conversation_history"] == messages + + def test_inference_parameters_persist(self): + """Test that ML inference parameters are maintained across invocations. + + This validates that ML-specific inference parameters (temperature, max_tokens, top_p) + can be set incrementally and persist across the session. This is useful for: + - LLM inference where users want consistent generation parameters + - A/B testing different parameter combinations within a session + - Gradual parameter tuning based on user feedback + """ + session_id = self.create_session() + + # Set inference parameters incrementally across multiple requests + # Temperature: controls randomness in text generation (0.0 = deterministic, 1.0 = creative) + self.invoke_with_session(session_id, {"inference_params": {"temperature": 0.7}}) + # Max tokens: limits the length of generated output + self.invoke_with_session(session_id, {"inference_params": {"max_tokens": 512}}) + # Top-p (nucleus sampling): controls diversity of token selection + self.invoke_with_session(session_id, {"inference_params": {"top_p": 0.9}}) + + # Retrieve accumulated parameters + response = self.invoke_with_session(session_id, {}) + data = json.loads(response.text) + # Verify all parameters were stored and are accessible + assert data["inference_params"]["temperature"] == 0.7 + assert data["inference_params"]["max_tokens"] == 512 + assert 
data["inference_params"]["top_p"] == 0.9 + + def test_session_state_cleared_after_close(self): + """Test that session state is properly cleared when session is closed. + + This validates proper cleanup of session resources. When a session is closed, + all associated state (conversation history, parameters, etc.) should be + removed to prevent memory leaks and ensure data privacy. + """ + session_id = self.create_session() + + # Add some state to the session + self.invoke_with_session(session_id, {"message": "test"}) + # Verify state was stored + assert len(self.sessions[session_id]["conversation_history"]) == 1 + + # Close the session - should trigger cleanup in custom handler + self.close_session(session_id) + + # Verify session and all its state was completely removed from storage + # This is important for memory management and data privacy + assert session_id not in self.sessions diff --git a/python/tests/integration/test_sagemaker_sessions_integration.py b/python/tests/integration/test_sagemaker_sessions_integration.py index baddf77..716d023 100644 --- a/python/tests/integration/test_sagemaker_sessions_integration.py +++ b/python/tests/integration/test_sagemaker_sessions_integration.py @@ -635,5 +635,225 @@ def test_regular_requests_with_session_header_when_disabled( assert SESSION_DISABLED_ERROR_DETAIL in response.text -if __name__ == "__main__": - pytest.main([__file__, "-v"]) +class TestSessionIdPathInjection(BaseSessionIntegrationTest): + """Test request_session_id_path parameter for injecting session ID into request body.""" + + def setup_handlers(self): + """Define handlers with request_session_id_path parameter.""" + + @self.router.post("/invocations-with-path") + @sagemaker_standards.stateful_session_manager( + engine_request_session_id_path="session_id" + ) + async def invocations_with_path(request: Request): + """Handler that injects session ID into request body at 'session_id' key.""" + body_bytes = await request.body() + body = 
json.loads(body_bytes.decode()) + + # Capture for test verification + self.capture.capture( + "invocation_with_path", body.get("session_id"), {"body": body} + ) + + return Response( + status_code=200, + content=json.dumps( + { + "message": "success", + "session_id_from_body": body.get("session_id"), + "echo": body, + } + ), + ) + + @self.router.post("/invocations-nested-path") + @sagemaker_standards.stateful_session_manager( + engine_request_session_id_path="metadata.session_id" + ) + async def invocations_nested_path(request: Request): + """Handler that injects session ID into nested path in request body.""" + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + + # Capture for test verification + session_id = ( + body.get("metadata", {}).get("session_id") + if isinstance(body.get("metadata"), dict) + else None + ) + self.capture.capture("invocation_nested_path", session_id, {"body": body}) + + return Response( + status_code=200, + content=json.dumps( + { + "message": "success", + "session_id_from_body": session_id, + "echo": body, + } + ), + ) + + def test_session_id_injected_into_body(self): + """Test that session ID from header is injected into request body.""" + # Create a session + create_response = self.client.post( + "/invocations-with-path", json={"requestType": "NEW_SESSION"} + ) + session_id = extract_session_id_from_header( + create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Make request with session ID in header + self.capture.clear() + response = self.client.post( + "/invocations-with-path", + json={"prompt": "test request"}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + + assert response.status_code == 200 + data = json.loads(response.text) + + # Verify session ID was injected into body + assert data["session_id_from_body"] == session_id + assert data["echo"]["session_id"] == session_id + assert data["echo"]["prompt"] == "test request" + + def 
test_session_id_injected_into_nested_path(self): + """Test that session ID is injected into nested path in request body.""" + # Create a session + create_response = self.client.post( + "/invocations-nested-path", json={"requestType": "NEW_SESSION"} + ) + session_id = extract_session_id_from_header( + create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Make request with session ID in header + self.capture.clear() + response = self.client.post( + "/invocations-nested-path", + json={"prompt": "test request", "metadata": {"user": "test"}}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + + assert response.status_code == 200 + data = json.loads(response.text) + + # Verify session ID was injected into nested path + assert data["session_id_from_body"] == session_id + assert data["echo"]["metadata"]["session_id"] == session_id + assert data["echo"]["metadata"]["user"] == "test" + assert data["echo"]["prompt"] == "test request" + + def test_session_id_not_injected_without_header(self): + """Test that session ID is not injected when header is not present.""" + response = self.client.post( + "/invocations-with-path", + json={"prompt": "test request"}, + # No session header + ) + + assert response.status_code == 200 + data = json.loads(response.text) + + # Verify session ID was not injected + assert data["session_id_from_body"] is None + assert "session_id" not in data["echo"] or data["echo"]["session_id"] is None + + def test_session_id_injection_with_multiple_requests(self): + """Test that session ID injection works across multiple requests.""" + # Create a session + create_response = self.client.post( + "/invocations-with-path", json={"requestType": "NEW_SESSION"} + ) + session_id = extract_session_id_from_header( + create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Make multiple requests with the same session ID + for i in range(3): + response = self.client.post( + "/invocations-with-path", + json={"prompt": 
f"request {i+1}"}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + + assert response.status_code == 200 + data = json.loads(response.text) + assert data["session_id_from_body"] == session_id + + def test_different_sessions_inject_different_ids(self): + """Test that different sessions inject their respective IDs.""" + # Create two sessions + create1 = self.client.post( + "/invocations-with-path", json={"requestType": "NEW_SESSION"} + ) + session1_id = extract_session_id_from_header( + create1.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + create2 = self.client.post( + "/invocations-with-path", json={"requestType": "NEW_SESSION"} + ) + session2_id = extract_session_id_from_header( + create2.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Make requests with each session + response1 = self.client.post( + "/invocations-with-path", + json={"prompt": "session 1"}, + headers={SageMakerSessionHeader.SESSION_ID: session1_id}, + ) + response2 = self.client.post( + "/invocations-with-path", + json={"prompt": "session 2"}, + headers={SageMakerSessionHeader.SESSION_ID: session2_id}, + ) + + # Verify each request got the correct session ID + data1 = json.loads(response1.text) + data2 = json.loads(response2.text) + + assert data1["session_id_from_body"] == session1_id + assert data2["session_id_from_body"] == session2_id + assert session1_id != session2_id + + def test_session_id_injection_preserves_existing_body_fields(self): + """Test that session ID injection doesn't overwrite other body fields.""" + # Create a session + create_response = self.client.post( + "/invocations-with-path", json={"requestType": "NEW_SESSION"} + ) + session_id = extract_session_id_from_header( + create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Make request with multiple body fields + original_body = { + "prompt": "test", + "temperature": 0.7, + "max_tokens": 100, + "metadata": {"user": "test_user", "request_id": "123"}, + } + + response = 
self.client.post( + "/invocations-with-path", + json=original_body, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + + assert response.status_code == 200 + data = json.loads(response.text) + + # Verify session ID was added + assert data["echo"]["session_id"] == session_id + + # Verify all original fields are preserved + assert data["echo"]["prompt"] == "test" + assert data["echo"]["temperature"] == 0.7 + assert data["echo"]["max_tokens"] == 100 + assert data["echo"]["metadata"]["user"] == "test_user" + assert data["echo"]["metadata"]["request_id"] == "123" diff --git a/python/tests/sagemaker/sessions/test_init.py b/python/tests/sagemaker/sessions/test_init.py new file mode 100644 index 0000000..fccbd13 --- /dev/null +++ b/python/tests/sagemaker/sessions/test_init.py @@ -0,0 +1,208 @@ +"""Unit tests for sessions module public API.""" + +from unittest.mock import patch + +import pytest + +from model_hosting_container_standards.sagemaker.sessions import ( + build_session_request_shape, +) +from model_hosting_container_standards.sagemaker.sessions.models import ( + SageMakerSessionHeader, +) + + +class TestBuildSessionRequestShape: + """Test build_session_request_shape function.""" + + def test_creates_basic_request_shape_with_session_id_only(self): + """Test creates request shape with only session ID path.""" + result = build_session_request_shape("session_id") + + assert result == { + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' + } + + def test_creates_request_shape_with_nested_session_id_path(self): + """Test creates request shape with nested session ID path.""" + result = build_session_request_shape("metadata.session_id") + + assert result == { + "metadata.session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' + } + + def test_merges_additional_shape_without_conflicts(self): + """Test merges additional shape when no conflicts exist.""" + additional = { + "capacity": "`1024`", + "model_name": "`gpt-4`", + } + + result = 
build_session_request_shape("session_id", additional) + + assert result == { + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"', + "capacity": "`1024`", + "model_name": "`gpt-4`", + } + + @patch("model_hosting_container_standards.sagemaker.sessions.logger") + def test_overwrites_conflicting_session_id_path_and_warns(self, mock_logger): + """Test overwrites session ID path in additional shape and logs warning.""" + additional = { + "session_id": "some_other_value", + "capacity": "`1024`", + } + + result = build_session_request_shape("session_id", additional) + + # Session ID should be overwritten with the correct value + assert result == { + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"', + "capacity": "`1024`", + } + + # Should have logged a warning + mock_logger.warning.assert_called_once() + warning_message = mock_logger.warning.call_args[0][0] + assert "session_id" in warning_message + assert "some_other_value" in warning_message + assert "overwritten" in warning_message.lower() + + def test_handles_none_additional_shape(self): + """Test handles None as additional shape gracefully.""" + result = build_session_request_shape("session_id", None) + + assert result == { + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' + } + + def test_handles_empty_additional_shape(self): + """Test handles empty dict as additional shape.""" + result = build_session_request_shape("session_id", {}) + + assert result == { + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' + } + + def test_preserves_all_additional_fields(self): + """Test preserves all fields from additional shape.""" + additional = { + "field1": "value1", + "field2": "value2", + "field3": "value3", + "nested.field": "nested_value", + } + + result = build_session_request_shape("session_id", additional) + + assert result == { + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"', + "field1": "value1", + "field2": "value2", + "field3": "value3", + 
"nested.field": "nested_value", + } + + @patch("model_hosting_container_standards.sagemaker.sessions.logger") + def test_session_id_always_takes_precedence(self, mock_logger): + """Test session ID value always takes precedence even after merge.""" + additional = { + "session_id": "wrong_value", + "other_field": "other_value", + } + + result = build_session_request_shape("session_id", additional) + + # Verify session_id has the correct value, not the one from additional + assert result["session_id"] == f'headers."{SageMakerSessionHeader.SESSION_ID}"' + assert result["session_id"] != "wrong_value" + assert result["other_field"] == "other_value" + + def test_works_with_complex_jmespath_expressions(self): + """Test works with complex JMESPath expressions in additional shape.""" + additional = { + "model": 'headers."X-Model-Name"', + "temperature": "`0.7`", + "max_tokens": "body.parameters.max_tokens", + } + + result = build_session_request_shape("request.session_id", additional) + + assert result == { + "request.session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"', + "model": 'headers."X-Model-Name"', + "temperature": "`0.7`", + "max_tokens": "body.parameters.max_tokens", + } + + @patch("model_hosting_container_standards.sagemaker.sessions.logger") + def test_no_warning_when_no_conflict(self, mock_logger): + """Test no warning is logged when there's no conflict.""" + additional = { + "capacity": "`1024`", + "model": "`gpt-4`", + } + + result = build_session_request_shape("session_id", additional) + + # Should not have logged any warning + mock_logger.warning.assert_not_called() + assert result["session_id"] == f'headers."{SageMakerSessionHeader.SESSION_ID}"' + + def test_none_session_id_path_returns_only_additional_shape(self): + """Test that None session_id_path returns only additional shape.""" + additional = { + "capacity": "`1024`", + "model": "`gpt-4`", + } + + result = build_session_request_shape(None, additional) + + # Should only have additional 
fields, no session ID + assert result == additional + assert f'headers."{SageMakerSessionHeader.SESSION_ID}"' not in result.values() + + def test_none_session_id_path_with_no_additional_shape(self): + """Test that None session_id_path with no additional shape returns empty dict.""" + result = build_session_request_shape(None, None) + + assert result == {} + + def test_none_session_id_path_with_empty_additional_shape(self): + """Test that None session_id_path with empty additional shape returns empty dict.""" + result = build_session_request_shape(None, {}) + + assert result == {} + + +class TestRegisterCloseSessionHandler: + """Test register_close_session_handler validation.""" + + def test_raises_error_when_engine_request_session_id_path_is_none(self): + """Test raises ValueError when engine_request_session_id_path is None.""" + from model_hosting_container_standards.sagemaker import ( + register_close_session_handler, + ) + + with pytest.raises( + ValueError, match="engine_request_session_id_path is required" + ): + register_close_session_handler( + engine_request_session_id_path=None, + content_path="`Session closed`", + ) + + def test_raises_error_when_engine_request_session_id_path_is_empty(self): + """Test raises ValueError when engine_request_session_id_path is empty string.""" + from model_hosting_container_standards.sagemaker import ( + register_close_session_handler, + ) + + with pytest.raises( + ValueError, match="engine_request_session_id_path is required" + ): + register_close_session_handler( + engine_request_session_id_path="", + content_path="`Session closed`", + ) diff --git a/python/tests/sagemaker/sessions/test_transform.py b/python/tests/sagemaker/sessions/test_transform.py index 70f0e53..080b08c 100644 --- a/python/tests/sagemaker/sessions/test_transform.py +++ b/python/tests/sagemaker/sessions/test_transform.py @@ -16,13 +16,12 @@ from model_hosting_container_standards.sagemaker.sessions.manager import SessionManager from 
model_hosting_container_standards.sagemaker.sessions.models import ( SageMakerSessionHeader, + SessionRequest, SessionRequestType, ) from model_hosting_container_standards.sagemaker.sessions.transform import ( SessionApiTransform, _parse_session_request, - _validate_session_if_present, - process_session_request, ) @@ -82,165 +81,154 @@ def test_raises_http_exception_for_extra_fields(self): assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value -class TestValidateSessionIfPresent: - """Test _validate_session_if_present function.""" +class TestValidateSessionId: + """Test _validate_session_id method.""" - def test_does_not_raise_when_no_session_id_present( - self, mock_request, mock_session_manager - ): - """Test does not raise exception when no session ID in request.""" - # Should not raise any exception - _validate_session_if_present(mock_request, mock_session_manager) - - def test_does_not_raise_when_session_id_valid(self, mock_session_manager): + def test_does_not_raise_when_session_id_valid(self, enable_sessions_env): """Test does not raise exception when session ID is valid.""" + transform = SessionApiTransform(request_shape={}, response_shape={}) mock_request = Mock(spec=Request) mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "valid-session"} with patch( - "model_hosting_container_standards.sagemaker.sessions.transform.get_session_id_from_request" - ) as mock_get_id: - mock_get_id.return_value = "valid-session" - - with patch( - "model_hosting_container_standards.sagemaker.sessions.transform.get_session" - ) as mock_get_session: - mock_session = Mock() - mock_get_session.return_value = mock_session + "model_hosting_container_standards.sagemaker.sessions.transform.get_session" + ) as mock_get_session: + mock_session = Mock() + mock_get_session.return_value = mock_session - # Should not raise any exception - _validate_session_if_present(mock_request, mock_session_manager) + # Should not raise any exception + result = 
transform._validate_session_id("valid-session", mock_request) + assert result == "valid-session" - def test_raises_http_exception_when_session_not_found(self, mock_session_manager): + def test_raises_http_exception_when_session_not_found(self, enable_sessions_env): """Test raises HTTPException when session ID not found.""" + transform = SessionApiTransform(request_shape={}, response_shape={}) mock_request = Mock(spec=Request) mock_request.headers = { SageMakerSessionHeader.SESSION_ID: "nonexistent-session" } with patch( - "model_hosting_container_standards.sagemaker.sessions.transform.get_session_id_from_request" - ) as mock_get_id: - mock_get_id.return_value = "nonexistent-session" - - with patch( - "model_hosting_container_standards.sagemaker.sessions.transform.get_session" - ) as mock_get_session: - mock_get_session.side_effect = ValueError("session not found") + "model_hosting_container_standards.sagemaker.sessions.transform.get_session" + ) as mock_get_session: + mock_get_session.side_effect = ValueError("session not found") - with pytest.raises(HTTPException) as exc_info: - _validate_session_if_present(mock_request, mock_session_manager) + with pytest.raises(HTTPException) as exc_info: + transform._validate_session_id("nonexistent-session", mock_request) - assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value - def test_error_message_includes_original_error(self, mock_session_manager): + def test_error_message_includes_original_error(self, enable_sessions_env): """Test error message includes the original error message.""" + transform = SessionApiTransform(request_shape={}, response_shape={}) mock_request = Mock(spec=Request) mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "bad-session"} with patch( - "model_hosting_container_standards.sagemaker.sessions.transform.get_session_id_from_request" - ) as mock_get_id: - mock_get_id.return_value = "bad-session" - - with patch( 
- "model_hosting_container_standards.sagemaker.sessions.transform.get_session" - ) as mock_get_session: - mock_get_session.side_effect = ValueError("custom error message") + "model_hosting_container_standards.sagemaker.sessions.transform.get_session" + ) as mock_get_session: + mock_get_session.side_effect = ValueError("custom error message") - with pytest.raises(HTTPException) as exc_info: - _validate_session_if_present(mock_request, mock_session_manager) + with pytest.raises(HTTPException) as exc_info: + transform._validate_session_id("bad-session", mock_request) - assert "custom error message" in exc_info.value.detail + assert "custom error message" in exc_info.value.detail class TestProcessSessionRequest: - """Test process_session_request function.""" - - def test_returns_passthrough_for_non_session_request( - self, mock_request, mock_session_manager - ): - """Test returns passthrough output for non-session request.""" - request_data = {"data": "regular_data"} + """Test _process_session_request method.""" - result = process_session_request( - request_data, mock_request, mock_session_manager - ) - - assert isinstance(result, BaseTransformRequestOutput) - assert result.request is None - assert result.raw_request == mock_request - assert result.intercept_func is None + @pytest.fixture + def transform(self, enable_sessions_env): + """Create SessionApiTransform instance.""" + return SessionApiTransform(request_shape={}, response_shape={}) def test_returns_create_handler_for_new_session_request( - self, mock_request, mock_session_manager + self, transform, mock_request ): """Test returns create_session handler for NEW_SESSION request.""" - request_data = {"requestType": "NEW_SESSION"} + session_request = SessionRequest(requestType=SessionRequestType.NEW_SESSION) - result = process_session_request( - request_data, mock_request, mock_session_manager - ) + result = transform._process_session_request(session_request, None, mock_request) assert isinstance(result, 
BaseTransformRequestOutput) - assert result.request is None assert result.raw_request == mock_request assert result.intercept_func == create_session - def test_returns_close_handler_for_close_request( - self, mock_request, mock_session_manager - ): + def test_returns_close_handler_for_close_request(self, transform, mock_request): """Test returns close_session handler for CLOSE request.""" - request_data = {"requestType": "CLOSE"} + session_request = SessionRequest(requestType=SessionRequestType.CLOSE) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "test-session"} - result = process_session_request( - request_data, mock_request, mock_session_manager - ) + with patch( + "model_hosting_container_standards.sagemaker.sessions.transform.get_session" + ) as mock_get_session: + mock_session = Mock() + mock_get_session.return_value = mock_session - assert isinstance(result, BaseTransformRequestOutput) - assert result.request is None - assert result.raw_request == mock_request - assert result.intercept_func == close_session + result = transform._process_session_request( + session_request, "test-session", mock_request + ) + + assert isinstance(result, BaseTransformRequestOutput) + assert result.raw_request == mock_request + assert result.intercept_func == close_session - def test_validates_session_if_session_id_present(self, mock_session_manager): + def test_validates_session_if_session_id_present(self, transform): """Test validates session when session ID is present in headers.""" - request_data = {"data": "regular_data"} mock_request = Mock(spec=Request) mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "test-session"} with patch( - "model_hosting_container_standards.sagemaker.sessions.transform._validate_session_if_present" - ) as mock_validate: - process_session_request(request_data, mock_request, mock_session_manager) + "model_hosting_container_standards.sagemaker.sessions.transform.get_session" + ) as mock_get_session: + mock_session = Mock() + 
mock_get_session.return_value = mock_session + + transform._process_session_request( + SessionRequest(requestType=SessionRequestType.CLOSE), + "test-session", + mock_request, + ) - mock_validate.assert_called_once_with(mock_request, mock_session_manager) + # Should validate the session + mock_get_session.assert_called_once() - def test_raises_exception_for_invalid_session_request( - self, mock_request, mock_session_manager + def test_raises_exception_when_sessions_disabled( + self, mock_request, monkeypatch, temp_session_storage ): - """Test raises HTTPException for invalid session request.""" - request_data = {"requestType": "INVALID_TYPE"} + """Test raises HTTPException when sessions are disabled.""" + # Disable sessions + monkeypatch.delenv("SAGEMAKER_ENABLE_STATEFUL_SESSIONS", raising=False) + from model_hosting_container_standards.sagemaker.sessions.manager import ( + init_session_manager_from_env, + ) - with pytest.raises(HTTPException): - process_session_request(request_data, mock_request, mock_session_manager) + init_session_manager_from_env() + + transform = SessionApiTransform(request_shape={}, response_shape={}) + session_request = SessionRequest(requestType=SessionRequestType.NEW_SESSION) - def test_propagates_validation_errors(self, mock_session_manager): - """Test propagates validation errors from _validate_session_if_present.""" - request_data = {"data": "regular_data"} + with pytest.raises(HTTPException) as exc_info: + transform._process_session_request(session_request, None, mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + + def test_propagates_validation_errors(self, transform): + """Test propagates validation errors from session validation.""" mock_request = Mock(spec=Request) mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "invalid-session"} with patch( - "model_hosting_container_standards.sagemaker.sessions.transform._validate_session_if_present" - ) as mock_validate: - mock_validate.side_effect = 
HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail="Session validation failed", - ) + "model_hosting_container_standards.sagemaker.sessions.transform.get_session" + ) as mock_get_session: + mock_get_session.side_effect = ValueError("Session not found") with pytest.raises(HTTPException) as exc_info: - process_session_request( - request_data, mock_request, mock_session_manager + transform._process_session_request( + SessionRequest(requestType=SessionRequestType.CLOSE), + "invalid-session", + mock_request, ) assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value diff --git a/python/tests/sagemaker/sessions/transforms/__init__.py b/python/tests/sagemaker/sessions/transforms/__init__.py new file mode 100644 index 0000000..98a045f --- /dev/null +++ b/python/tests/sagemaker/sessions/transforms/__init__.py @@ -0,0 +1 @@ +"""Tests for session transforms.""" diff --git a/python/tests/sagemaker/sessions/transforms/test_base_engine_session_api_transform.py b/python/tests/sagemaker/sessions/transforms/test_base_engine_session_api_transform.py new file mode 100644 index 0000000..0b7407f --- /dev/null +++ b/python/tests/sagemaker/sessions/transforms/test_base_engine_session_api_transform.py @@ -0,0 +1,417 @@ +"""Unit tests for BaseEngineSessionApiTransform.""" + +import json +from http import HTTPStatus +from unittest.mock import AsyncMock, Mock + +import pytest +from fastapi import Request, Response +from fastapi.exceptions import HTTPException +from pydantic import BaseModel + +from model_hosting_container_standards.common import BaseTransformRequestOutput +from model_hosting_container_standards.sagemaker.sessions.transforms.base_engine_session_api_transform import ( + BaseEngineSessionApiTransform, +) + + +class ConcreteTransform(BaseEngineSessionApiTransform): + """Concrete implementation for testing the abstract base class.""" + + def _transform_ok_response(self, response: Response, **kwargs) -> Response: + """Simple implementation that just 
returns the response.""" + return response + + +class TestBaseEngineSessionApiTransformRequest: + """Test transform_request method.""" + + @pytest.fixture + def transform(self): + """Create concrete transform instance.""" + return ConcreteTransform( + request_shape={"field": "body.field"}, response_shape={} + ) + + @pytest.mark.asyncio + async def test_transforms_request_body(self, transform): + """Test that request body is transformed using JMESPath.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"field": "value"} + mock_request.headers = {} + + result = await transform.transform_request(mock_request) + + assert result.request["field"] == "value" + assert isinstance(result, BaseTransformRequestOutput) + + @pytest.mark.asyncio + async def test_updates_raw_request_body(self, transform): + """Test that raw request _body is updated with transformed data.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"field": "value"} + mock_request.headers = {} + + await transform.transform_request(mock_request) + + updated_body = json.loads(mock_request._body.decode()) + assert updated_body == {"field": "value"} + + @pytest.mark.asyncio + async def test_handles_json_decode_error(self, transform): + """Test that JSON decode errors raise HTTPException.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.side_effect = json.JSONDecodeError("Invalid", "doc", 0) + + with pytest.raises(HTTPException) as exc_info: + await transform.transform_request(mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + assert "JSON decode error" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_calls_validate_request_preconditions(self, transform): + """Test that _validate_request_preconditions is called.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {} + mock_request.headers = {} + + # Mock the validation method + 
transform._validate_request_preconditions = Mock() + + await transform.transform_request(mock_request) + + transform._validate_request_preconditions.assert_called_once_with(mock_request) + + @pytest.mark.asyncio + async def test_validation_errors_propagate(self, transform): + """Test that validation errors from preconditions propagate.""" + mock_request = AsyncMock(spec=Request) + mock_request.headers = {} + + # Make validation raise an exception + def raise_validation_error(req): + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, detail="Validation failed" + ) + + transform._validate_request_preconditions = raise_validation_error + + with pytest.raises(HTTPException) as exc_info: + await transform.transform_request(mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + assert "Validation failed" in exc_info.value.detail + + +class TestBaseEngineSessionApiTransformResponse: + """Test transform_response method.""" + + @pytest.fixture + def transform(self): + """Create concrete transform instance.""" + return ConcreteTransform(request_shape={}, response_shape={}) + + def test_routes_ok_response_to_transform_ok_response(self, transform): + """Test that 200 OK responses are routed to _transform_ok_response.""" + response = Response(status_code=HTTPStatus.OK.value, content=b"success") + transform_output = Mock(spec=BaseTransformRequestOutput) + + # Mock the _transform_ok_response method + transform._transform_ok_response = Mock(return_value=response) + + result = transform.transform_response(response, transform_output) + + transform._transform_ok_response.assert_called_once() + assert result == response + + def test_routes_error_response_to_transform_error_response(self, transform): + """Test that error responses are routed to _transform_error_response.""" + response = Response( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, content=b"error" + ) + transform_output = Mock(spec=BaseTransformRequestOutput) + + # Mock 
the _transform_error_response method + transform._transform_error_response = Mock(return_value=response) + + result = transform.transform_response(response, transform_output) + + transform._transform_error_response.assert_called_once_with(response) + assert result == response + + def test_normalizes_response_before_routing(self, transform): + """Test that response is normalized before routing.""" + # Pass a dict instead of Response object + response_dict = {"status": "success"} + transform_output = Mock(spec=BaseTransformRequestOutput) + + result = transform.transform_response(response_dict, transform_output) + + # Should be normalized to Response object + assert isinstance(result, Response) + assert result.status_code == HTTPStatus.OK.value + + +class TestNormalizeResponse: + """Test _normalize_response method.""" + + @pytest.fixture + def transform(self): + """Create concrete transform instance.""" + return ConcreteTransform(request_shape={}, response_shape={}) + + def test_passes_through_response_object(self, transform): + """Test that Response objects pass through unchanged.""" + response = Response(status_code=HTTPStatus.OK.value, content=b"test") + + normalized = transform._normalize_response(response) + + assert normalized is response + + def test_normalizes_dict_to_response(self, transform): + """Test that dict is normalized to Response.""" + response_dict = {"key": "value"} + + normalized = transform._normalize_response(response_dict) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + body = json.loads(normalized.body) + assert body["key"] == "value" + + def test_normalizes_string_to_response(self, transform): + """Test that string is normalized to Response.""" + response_str = "success message" + + normalized = transform._normalize_response(response_str) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + assert normalized.body == b"success message" + + def 
test_normalizes_pydantic_model_to_response(self, transform): + """Test that Pydantic model is normalized to Response.""" + + class TestModel(BaseModel): + field1: str + field2: int + + model = TestModel(field1="test", field2=42) + + normalized = transform._normalize_response(model) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + body = json.loads(normalized.body) + assert body["field1"] == "test" + assert body["field2"] == 42 + + def test_normalizes_none_to_response(self, transform): + """Test that None is normalized to Response.""" + normalized = transform._normalize_response(None) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + assert normalized.body == b"null" + + def test_normalizes_list_to_response(self, transform): + """Test that list is normalized to Response.""" + response_list = [{"id": 1}, {"id": 2}] + + normalized = transform._normalize_response(response_list) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + body = json.loads(normalized.body) + assert len(body) == 2 + assert body[0]["id"] == 1 + + def test_normalizes_int_to_response(self, transform): + """Test that integer is normalized to Response.""" + normalized = transform._normalize_response(42) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + assert normalized.body == b"42" + + def test_normalizes_bool_to_response(self, transform): + """Test that boolean is normalized to Response.""" + normalized = transform._normalize_response(True) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + assert normalized.body == b"true" + + def test_normalizes_empty_dict_to_response(self, transform): + """Test that empty dict is normalized to Response.""" + normalized = transform._normalize_response({}) + + assert isinstance(normalized, Response) + assert 
normalized.status_code == HTTPStatus.OK.value + assert normalized.body == b"{}" + + def test_normalizes_nested_structure_to_response(self, transform): + """Test that nested structure is normalized to Response.""" + response_data = { + "session": {"id": "sess-123", "metadata": {"user": "test"}}, + "status": "active", + } + + normalized = transform._normalize_response(response_data) + + assert isinstance(normalized, Response) + body = json.loads(normalized.body) + assert body["session"]["id"] == "sess-123" + assert body["session"]["metadata"]["user"] == "test" + + def test_preserves_response_with_error_status_code(self, transform): + """Test that Response with error status code is preserved.""" + response = Response( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, content=b"error" + ) + + normalized = transform._normalize_response(response) + + assert normalized is response + assert normalized.status_code == HTTPStatus.INTERNAL_SERVER_ERROR.value + + +class TestTransformErrorResponse: + """Test _transform_error_response method.""" + + @pytest.fixture + def transform(self): + """Create concrete transform instance.""" + return ConcreteTransform(request_shape={}, response_shape={}) + + def test_passes_through_error_response_unchanged(self, transform): + """Test that error responses pass through unchanged by default.""" + response = Response( + status_code=HTTPStatus.NOT_FOUND.value, content=b"Not found" + ) + + result = transform._transform_error_response(response) + + assert result is response + assert result.status_code == HTTPStatus.NOT_FOUND.value + assert result.body == b"Not found" + + def test_handles_various_error_status_codes(self, transform): + """Test that various error status codes are handled.""" + error_codes = [ + HTTPStatus.BAD_REQUEST.value, + HTTPStatus.UNAUTHORIZED.value, + HTTPStatus.FORBIDDEN.value, + HTTPStatus.NOT_FOUND.value, + HTTPStatus.INTERNAL_SERVER_ERROR.value, + HTTPStatus.BAD_GATEWAY.value, + HTTPStatus.SERVICE_UNAVAILABLE.value, + 
] + + for status_code in error_codes: + response = Response(status_code=status_code, content=b"error") + result = transform._transform_error_response(response) + assert result.status_code == status_code + + +class TestValidateRequestPreconditions: + """Test _validate_request_preconditions method.""" + + @pytest.fixture + def transform(self): + """Create concrete transform instance.""" + return ConcreteTransform(request_shape={}, response_shape={}) + + def test_default_implementation_does_nothing(self, transform): + """Test that default implementation doesn't raise exceptions.""" + mock_request = Mock(spec=Request) + mock_request.headers = {} + + # Should not raise any exception + transform._validate_request_preconditions(mock_request) + + def test_can_be_overridden_in_subclass(self): + """Test that subclasses can override validation.""" + + class CustomTransform(BaseEngineSessionApiTransform): + def _validate_request_preconditions(self, raw_request: Request) -> None: + if not raw_request.headers.get("X-Custom-Header"): + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing custom header", + ) + + def _transform_ok_response(self, response: Response, **kwargs) -> Response: + return response + + transform = CustomTransform(request_shape={}, response_shape={}) + mock_request = Mock(spec=Request) + mock_request.headers = {} + + with pytest.raises(HTTPException) as exc_info: + transform._validate_request_preconditions(mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + assert "Missing custom header" in exc_info.value.detail + + +class TestAbstractMethods: + """Test that abstract methods must be implemented.""" + + def test_transform_ok_response_must_be_implemented(self): + """Test that _transform_ok_response must be implemented by subclasses.""" + + # Try to instantiate without implementing abstract method + with pytest.raises(TypeError) as exc_info: + + class 
IncompleteTransform(BaseEngineSessionApiTransform): + pass + + IncompleteTransform(request_shape={}, response_shape={}) + + assert "_transform_ok_response" in str(exc_info.value) + + +class TestTransformRequestOutputStructure: + """Test the structure of transform_request output.""" + + @pytest.fixture + def transform(self): + """Create concrete transform instance.""" + return ConcreteTransform( + request_shape={"param": "body.param"}, response_shape={} + ) + + @pytest.mark.asyncio + async def test_output_contains_transformed_request(self, transform): + """Test that output contains the transformed request data.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"param": "value"} + mock_request.headers = {} + + result = await transform.transform_request(mock_request) + + assert result.request == {"param": "value"} + + @pytest.mark.asyncio + async def test_output_contains_raw_request(self, transform): + """Test that output contains the raw request object.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"param": "value"} + mock_request.headers = {} + + result = await transform.transform_request(mock_request) + + assert result.raw_request is mock_request + + @pytest.mark.asyncio + async def test_output_intercept_func_is_none(self, transform): + """Test that intercept_func is None for base transform.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"param": "value"} + mock_request.headers = {} + + result = await transform.transform_request(mock_request) + + assert result.intercept_func is None diff --git a/python/tests/sagemaker/sessions/transforms/test_close_session_transform.py b/python/tests/sagemaker/sessions/transforms/test_close_session_transform.py new file mode 100644 index 0000000..11a0ec8 --- /dev/null +++ b/python/tests/sagemaker/sessions/transforms/test_close_session_transform.py @@ -0,0 +1,304 @@ +"""Unit tests for CloseSessionApiTransform.""" + +import json +from 
http import HTTPStatus +from unittest.mock import AsyncMock, Mock + +import pytest +from fastapi import Request, Response +from fastapi.exceptions import HTTPException + +from model_hosting_container_standards.common import BaseTransformRequestOutput +from model_hosting_container_standards.sagemaker.sessions.models import ( + SageMakerSessionHeader, +) +from model_hosting_container_standards.sagemaker.sessions.transforms.close_session import ( + CloseSessionApiTransform, +) +from model_hosting_container_standards.sagemaker.sessions.transforms.constants import ( + RESPONSE_CONTENT_KEY, +) + + +class TestCloseSessionInitialization: + """Test CloseSessionApiTransform initialization.""" + + def test_requires_content_in_response_shape(self): + """Test that initialization requires RESPONSE_CONTENT_KEY in response_shape.""" + with pytest.raises(ValueError) as exc_info: + CloseSessionApiTransform(request_shape={}, response_shape={}) + + assert RESPONSE_CONTENT_KEY in str(exc_info.value) + + def test_successful_initialization(self): + """Test successful initialization with valid response_shape.""" + transform = CloseSessionApiTransform( + request_shape={}, + response_shape={RESPONSE_CONTENT_KEY: "body.message"}, + ) + assert transform is not None + + +class TestCloseSessionValidation: + """Test request validation.""" + + @pytest.fixture + def transform(self): + """Create transform.""" + return CloseSessionApiTransform( + request_shape={}, + response_shape={RESPONSE_CONTENT_KEY: "body.message"}, + ) + + @pytest.mark.asyncio + async def test_requires_session_id_header(self, transform): + """Test that session ID header is required.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {} + mock_request.headers = {} # No session ID + + with pytest.raises(HTTPException) as exc_info: + await transform.transform_request(mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + assert "Session ID is required" in 
exc_info.value.detail + + @pytest.mark.asyncio + async def test_succeeds_with_session_id_header(self, transform): + """Test that request succeeds with session ID header.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {} + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + + result = await transform.transform_request(mock_request) + + assert result is not None + + +class TestCloseSessionTransformRequest: + """Test request transformation.""" + + @pytest.fixture + def transform(self): + """Create transform with request shape.""" + return CloseSessionApiTransform( + request_shape={"reason": "body.reason"}, + response_shape={RESPONSE_CONTENT_KEY: "body.message"}, + ) + + @pytest.mark.asyncio + async def test_transforms_request_body(self, transform): + """Test that request body is transformed using JMESPath.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"reason": "timeout"} + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + + result = await transform.transform_request(mock_request) + + assert result.request["reason"] == "timeout" + + @pytest.mark.asyncio + async def test_updates_raw_request_body(self, transform): + """Test that raw request body is updated.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"reason": "timeout"} + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + + await transform.transform_request(mock_request) + + updated_body = json.loads(mock_request._body.decode()) + assert updated_body == {"reason": "timeout"} + + +class TestCloseSessionTransformResponse: + """Test response transformation.""" + + @pytest.fixture + def transform(self): + """Create transform.""" + return CloseSessionApiTransform( + request_shape={}, + response_shape={RESPONSE_CONTENT_KEY: "body.message"}, + ) + + def test_extracts_content_and_adds_header(self, transform): + """Test that content is extracted and session ID 
added to headers.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"message": "Session closed"}), + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + transform_output = BaseTransformRequestOutput( + raw_request=mock_request, intercept_func=None + ) + + result = transform.transform_response(response, transform_output) + + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.CLOSED_SESSION_ID] == "sess-123" + assert b"sess-123" in result.body + assert b"Session closed" in result.body + + def test_handles_missing_content(self, transform): + """Test that missing content is handled gracefully.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({}), + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + transform_output = BaseTransformRequestOutput( + raw_request=mock_request, intercept_func=None + ) + + result = transform.transform_response(response, transform_output) + + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.CLOSED_SESSION_ID] == "sess-123" + + def test_passes_through_error_responses(self, transform): + """Test that error responses pass through unchanged.""" + response = Response( + status_code=HTTPStatus.NOT_FOUND.value, + content=b"Session not found", + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + transform_output = BaseTransformRequestOutput( + raw_request=mock_request, intercept_func=None + ) + + result = transform.transform_response(response, transform_output) + + assert result.status_code == HTTPStatus.NOT_FOUND.value + assert result.body == b"Session not found" + + +class TestCloseSessionEdgeCases: + """Test edge cases for CloseSessionApiTransform.""" + + @pytest.fixture + def transform(self): + """Create 
transform.""" + return CloseSessionApiTransform( + request_shape={}, + response_shape={RESPONSE_CONTENT_KEY: "body.message"}, + ) + + def test_handles_none_content(self, transform): + """Test that None content is handled gracefully.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"message": None}), + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + transform_output = BaseTransformRequestOutput( + raw_request=mock_request, intercept_func=None + ) + + result = transform.transform_response(response, transform_output) + + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.CLOSED_SESSION_ID] == "sess-123" + + def test_handles_empty_string_content(self, transform): + """Test that empty string content is handled gracefully.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"message": ""}), + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + transform_output = BaseTransformRequestOutput( + raw_request=mock_request, intercept_func=None + ) + + result = transform.transform_response(response, transform_output) + + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.CLOSED_SESSION_ID] == "sess-123" + + def test_extracts_content_from_nested_path(self, transform): + """Test extraction of content from nested response structure.""" + transform_nested = CloseSessionApiTransform( + request_shape={}, + response_shape={RESPONSE_CONTENT_KEY: "body.result.message"}, + ) + + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"result": {"message": "Session closed successfully"}}), + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-nested-123"} + transform_output = BaseTransformRequestOutput( + raw_request=mock_request, 
intercept_func=None + ) + + result = transform_nested.transform_response(response, transform_output) + + assert result.status_code == HTTPStatus.OK.value + assert ( + result.headers[SageMakerSessionHeader.CLOSED_SESSION_ID] + == "sess-nested-123" + ) + assert b"Session closed successfully" in result.body + + def test_handles_malformed_json_in_response(self, transform): + """Test that malformed JSON in response is handled gracefully. + + The serialize_response function catches JSONDecodeError and keeps the body as a string, + so malformed JSON doesn't cause the transform to fail. + """ + response = Response( + status_code=HTTPStatus.OK.value, + content=b"not valid json {{{", + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + transform_output = BaseTransformRequestOutput( + raw_request=mock_request, intercept_func=None + ) + + # Should handle gracefully - malformed JSON is kept as string + result = transform.transform_response(response, transform_output) + + # Should still return a response with the session ID header + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.CLOSED_SESSION_ID] == "sess-123" + + @pytest.mark.asyncio + async def test_validates_session_id_before_transformation(self, transform): + """Test that session ID validation happens before request transformation.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"reason": "timeout"} + mock_request.headers = {} # Missing session ID + + # Should fail validation before even attempting transformation + with pytest.raises(HTTPException) as exc_info: + await transform.transform_request(mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + # json() should not have been called since validation failed first + mock_request.json.assert_not_called() + + @pytest.mark.asyncio + async def test_validates_empty_session_id(self, transform): + """Test 
that empty session ID is rejected.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {} + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: ""} + + with pytest.raises(HTTPException) as exc_info: + await transform.transform_request(mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value diff --git a/python/tests/sagemaker/sessions/transforms/test_create_session_transform.py b/python/tests/sagemaker/sessions/transforms/test_create_session_transform.py new file mode 100644 index 0000000..baa5b98 --- /dev/null +++ b/python/tests/sagemaker/sessions/transforms/test_create_session_transform.py @@ -0,0 +1,339 @@ +"""Unit tests for CreateSessionApiTransform.""" + +import json +from http import HTTPStatus +from unittest.mock import AsyncMock, Mock + +import pytest +from fastapi import Request, Response +from fastapi.exceptions import HTTPException +from pydantic import BaseModel + +from model_hosting_container_standards.sagemaker.sessions.models import ( + SageMakerSessionHeader, +) +from model_hosting_container_standards.sagemaker.sessions.transforms.constants import ( + RESPONSE_CONTENT_KEY, +) +from model_hosting_container_standards.sagemaker.sessions.transforms.create_session import ( + CreateSessionApiTransform, +) + + +class TestCreateSessionInitialization: + """Test CreateSessionApiTransform initialization.""" + + def test_requires_session_id_in_response_shape(self): + """Test that initialization requires NEW_SESSION_ID in response_shape.""" + with pytest.raises(ValueError) as exc_info: + CreateSessionApiTransform( + request_shape={}, + response_shape={RESPONSE_CONTENT_KEY: "body.message"}, + ) + assert SageMakerSessionHeader.NEW_SESSION_ID in str(exc_info.value) + + def test_requires_content_in_response_shape(self): + """Test that initialization requires RESPONSE_CONTENT_KEY in response_shape.""" + with pytest.raises(ValueError) as exc_info: + CreateSessionApiTransform( + request_shape={}, + 
response_shape={SageMakerSessionHeader.NEW_SESSION_ID: "body.id"}, + ) + assert RESPONSE_CONTENT_KEY in str(exc_info.value) + + def test_successful_initialization(self): + """Test successful initialization with valid response_shape.""" + transform = CreateSessionApiTransform( + request_shape={"model": "body.model"}, + response_shape={ + SageMakerSessionHeader.NEW_SESSION_ID: "body.session_id", + RESPONSE_CONTENT_KEY: "body.message", + }, + ) + assert transform is not None + + +class TestCreateSessionTransformRequest: + """Test request transformation.""" + + @pytest.fixture + def transform(self): + """Create transform with request shape.""" + return CreateSessionApiTransform( + request_shape={"model": "body.model"}, + response_shape={ + SageMakerSessionHeader.NEW_SESSION_ID: "body.session_id", + RESPONSE_CONTENT_KEY: "body.message", + }, + ) + + @pytest.mark.asyncio + async def test_transforms_request_body(self, transform): + """Test that request body is transformed using JMESPath.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"model": "llama-3"} + mock_request.headers = {} + + result = await transform.transform_request(mock_request) + + assert result.request["model"] == "llama-3" + + @pytest.mark.asyncio + async def test_updates_raw_request_body(self, transform): + """Test that raw request body is updated with transformed data.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"model": "llama-3"} + mock_request.headers = {} + + await transform.transform_request(mock_request) + + updated_body = json.loads(mock_request._body.decode()) + assert updated_body == {"model": "llama-3"} + + @pytest.mark.asyncio + async def test_handles_invalid_json(self, transform): + """Test that invalid JSON raises HTTPException.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.side_effect = json.JSONDecodeError("Invalid", "doc", 0) + + with pytest.raises(HTTPException) as exc_info: + await 
transform.transform_request(mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + + +class TestCreateSessionTransformResponse: + """Test response transformation.""" + + @pytest.fixture + def transform(self): + """Create transform.""" + return CreateSessionApiTransform( + request_shape={}, + response_shape={ + SageMakerSessionHeader.NEW_SESSION_ID: "body.session_id", + RESPONSE_CONTENT_KEY: "body.message", + }, + ) + + def test_extracts_session_id_from_response(self, transform): + """Test that session ID is extracted and added to headers.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"session_id": "sess-123", "message": "created"}), + ) + + result = transform.transform_response(response, Mock()) + + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.NEW_SESSION_ID] == "sess-123" + assert b"sess-123" in result.body + assert b"created" in result.body + + def test_fails_when_session_id_missing(self, transform): + """Test that missing session ID raises HTTPException.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"message": "created"}), + ) + + with pytest.raises(HTTPException) as exc_info: + transform.transform_response(response, Mock()) + + assert exc_info.value.status_code == HTTPStatus.BAD_GATEWAY.value + assert "session ID" in exc_info.value.detail + + def test_fails_when_session_id_empty(self, transform): + """Test that empty session ID raises HTTPException.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"session_id": "", "message": "created"}), + ) + + with pytest.raises(HTTPException) as exc_info: + transform.transform_response(response, Mock()) + + assert exc_info.value.status_code == HTTPStatus.BAD_GATEWAY.value + + def test_passes_through_error_responses(self, transform): + """Test that error responses pass through unchanged.""" + response = Response( + 
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + content=b"Engine error", + ) + + result = transform.transform_response(response, Mock()) + + assert result.status_code == HTTPStatus.INTERNAL_SERVER_ERROR.value + assert result.body == b"Engine error" + + +class TestCreateSessionNormalizeResponse: + """Test response normalization.""" + + @pytest.fixture + def transform(self): + """Create transform.""" + return CreateSessionApiTransform( + request_shape={}, + response_shape={ + SageMakerSessionHeader.NEW_SESSION_ID: "body.id", + RESPONSE_CONTENT_KEY: "body.msg", + }, + ) + + def test_normalizes_dict_response(self, transform): + """Test normalization of dict response.""" + response_dict = {"id": "sess-123", "msg": "created"} + + normalized = transform._normalize_response(response_dict) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + body = json.loads(normalized.body) + assert body["id"] == "sess-123" + + def test_normalizes_string_response(self, transform): + """Test normalization of string response.""" + response_str = "Session created" + + normalized = transform._normalize_response(response_str) + + assert isinstance(normalized, Response) + assert normalized.body == b"Session created" + + def test_normalizes_pydantic_response(self, transform): + """Test normalization of Pydantic model response.""" + + class SessionResponse(BaseModel): + id: str + msg: str + + response_model = SessionResponse(id="sess-123", msg="created") + + normalized = transform._normalize_response(response_model) + + assert isinstance(normalized, Response) + body = json.loads(normalized.body) + assert body["id"] == "sess-123" + + def test_passes_through_response_object(self, transform): + """Test that Response objects pass through unchanged.""" + response = Response(status_code=HTTPStatus.OK.value, content=b"test") + + normalized = transform._normalize_response(response) + + assert normalized is response + + def test_normalizes_none_response(self, transform):
+ """Test normalization of None response.""" + normalized = transform._normalize_response(None) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + assert normalized.body == b"null" + + def test_normalizes_list_response(self, transform): + """Test normalization of list response.""" + response_list = [{"id": "sess-1"}, {"id": "sess-2"}] + + normalized = transform._normalize_response(response_list) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + body = json.loads(normalized.body) + assert len(body) == 2 + assert body[0]["id"] == "sess-1" + + +class TestCreateSessionEdgeCases: + """Test edge cases for CreateSessionApiTransform.""" + + @pytest.fixture + def transform(self): + """Create transform.""" + return CreateSessionApiTransform( + request_shape={}, + response_shape={ + SageMakerSessionHeader.NEW_SESSION_ID: "body.session_id", + RESPONSE_CONTENT_KEY: "body.message", + }, + ) + + def test_fails_when_session_id_is_none(self, transform): + """Test that None session ID raises HTTPException.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"session_id": None, "message": "created"}), + ) + + with pytest.raises(HTTPException) as exc_info: + transform.transform_response(response, Mock()) + + assert exc_info.value.status_code == HTTPStatus.BAD_GATEWAY.value + + def test_handles_none_content(self, transform): + """Test that None content is handled gracefully.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"session_id": "sess-123", "message": None}), + ) + + result = transform.transform_response(response, Mock()) + + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.NEW_SESSION_ID] == "sess-123" + + def test_handles_empty_string_content(self, transform): + """Test that empty string content is handled gracefully.""" + response = Response( + 
status_code=HTTPStatus.OK.value, + content=json.dumps({"session_id": "sess-123", "message": ""}), + ) + + result = transform.transform_response(response, Mock()) + + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.NEW_SESSION_ID] == "sess-123" + + def test_extracts_session_id_from_nested_path(self, transform): + """Test extraction of session ID from nested response structure.""" + transform_nested = CreateSessionApiTransform( + request_shape={}, + response_shape={ + SageMakerSessionHeader.NEW_SESSION_ID: "body.data.session.id", + RESPONSE_CONTENT_KEY: "body.data.message", + }, + ) + + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps( + {"data": {"session": {"id": "sess-nested-123"}, "message": "created"}} + ), + ) + + result = transform_nested.transform_response(response, Mock()) + + assert result.status_code == HTTPStatus.OK.value + assert ( + result.headers[SageMakerSessionHeader.NEW_SESSION_ID] == "sess-nested-123" + ) + + def test_handles_malformed_json_in_response(self, transform): + """Test that malformed JSON in response is handled gracefully. + + The serialize_response function catches JSONDecodeError and keeps the body as a string, + but since we can't extract a session_id from a string, this should fail validation. + """ + response = Response( + status_code=HTTPStatus.OK.value, + content=b"not valid json {{{", + ) + + # Should fail because session_id cannot be extracted from malformed JSON + with pytest.raises(HTTPException) as exc_info: + transform.transform_response(response, Mock()) + + assert exc_info.value.status_code == HTTPStatus.BAD_GATEWAY.value + assert "session ID" in exc_info.value.detail