Skip to content

Commit fa96279

Browse files
Merge branch 'main' into restart
2 parents e9e1f20 + 5d40ceb commit fa96279

File tree

20 files changed

+762
-246
lines changed

20 files changed

+762
-246
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ ModelHostingContainerStandards/
2121
│ │ │ ├── custom_code_ref_resolver/ # Dynamic code loading
2222
│ │ │ └── handler/ # Handler specifications
2323
│ │ └── sagemaker/ # SageMaker integration
24-
│ │ └── lora/ # LoRA adapter support
24+
│ │ ├── lora/ # LoRA adapter support
25+
│ │ └── sessions/ # Stateful session management
2526
│ ├── tests/ # Package tests
2627
│ ├── examples/ # Python examples and demos
2728
│ ├── pyproject.toml # Python project configuration

python/README.md

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def invocations(request: Request) -> dict:
9898

9999
# Optional: Add LoRA adapter support
100100
@sagemaker_standards.register_invocation_handler
101-
@sagemaker_standards.inject_adapter_id("model")
101+
@sagemaker_standards.inject_adapter_id("model") # Replace mode
102102
async def invocations_with_lora(request: Request) -> dict:
103103
"""Invocation handler with LoRA adapter ID injection."""
104104
body = await request.json()
@@ -180,7 +180,8 @@ The system automatically resolves handlers in this order:
180180
@sagemaker_standards.register_invocation_handler
181181

182182
# LoRA adapter support
183-
@sagemaker_standards.inject_adapter_id("model")
183+
@sagemaker_standards.inject_adapter_id("model") # Replace mode (default)
184+
@sagemaker_standards.inject_adapter_id("model", append=True, separator=":") # Append mode
184185
```
185186
186187
### Customer Decorators (for model customization)
@@ -193,6 +194,10 @@ The system automatically resolves handlers in this order:
193194
# LoRA transform decorators
194195
@sagemaker_standards.register_load_adapter_handler(request_shape={...}, response_shape={...})
195196
@sagemaker_standards.register_unload_adapter_handler(request_shape={...}, response_shape={...})
197+
198+
# LoRA adapter injection modes
199+
@sagemaker_standards.inject_adapter_id("model") # Replace mode (default)
200+
@sagemaker_standards.inject_adapter_id("model", append=True, separator=":") # Append mode
196201
```
197202
198203
## Framework Examples
@@ -209,15 +214,24 @@ import json
209214

210215
# Create router like real vLLM does
211216
router = APIRouter()
212-
217+
@router.post("/ping", response_class=Response)
218+
@router.get("/ping", response_class=Response)
213219
@sagemaker_standards.register_ping_handler
214220
async def ping(raw_request: Request) -> Response:
215221
"""Default vLLM ping handler with automatic routing."""
216222
return Response(
217223
content='{"status": "healthy", "source": "vllm_default", "message": "Default ping from vLLM server"}',
218224
media_type="application/json",
219225
)
220-
226+
@router.post(
227+
"/invocations",
228+
dependencies=[Depends(validate_json_request)],
229+
responses={
230+
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
231+
HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse},
232+
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
233+
},
234+
)
221235
@sagemaker_standards.register_invocation_handler
222236
@sagemaker_standards.inject_adapter_id("model")
223237
async def invocations(raw_request: Request) -> Response:
@@ -247,6 +261,32 @@ async def invocations(raw_request: Request) -> Response:
247261
media_type="application/json",
248262
)
249263

264+
# Alternative: append mode for model field
265+
@sagemaker_standards.register_invocation_handler
266+
@sagemaker_standards.inject_adapter_id("model", append=True, separator=":")
267+
async def invocations_append_mode(raw_request: Request) -> Response:
268+
"""vLLM invocation handler with adapter ID appending."""
269+
body_bytes = await raw_request.body()
270+
try:
271+
body = json.loads(body_bytes.decode()) if body_bytes else {}
272+
except (json.JSONDecodeError, UnicodeDecodeError):
273+
body = {}
274+
275+
# If body has {"model": "Qwen-7B"} and header has "my-lora"
276+
# Result will be {"model": "Qwen-7B:my-lora"}
277+
model_with_adapter = body.get("model", "base-model")
278+
279+
response_data = {
280+
"predictions": ["Generated text from vLLM"],
281+
"model_used": model_with_adapter,
282+
"message": f"Response using model: {model_with_adapter}",
283+
}
284+
285+
return Response(
286+
content=json.dumps(response_data),
287+
media_type="application/json",
288+
)
289+
250290
# Setup FastAPI app like real vLLM
251291
app = FastAPI(title="vLLM Server", version="1.0.0")
252292
app.include_router(router)

python/model_hosting_container_standards/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
- FastAPI: from .common.fastapi import EnvVars, ENV_CONFIG
66
"""
77

8-
__version__ = "0.1.2"
8+
__version__ = "0.1.4"

python/model_hosting_container_standards/common/handler/resolver.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,11 +220,11 @@ def resolve_handler(self, handler_type: str) -> Optional[Callable]:
220220
if handler:
221221
return handler
222222

223-
# No handler found anywhere, use the framework default
224-
handler = self.registry.get_framework_default(handler_type)
225-
if handler:
226-
logger.info(f"Use {handler_type} handler registered in framework")
227-
return handler
223+
# No handler found anywhere, let us just do nothing
224+
# handler = self.registry.get_framework_default(handler_type)
225+
# if handler:
226+
# logger.info(f"Use {handler_type} handler registered in framework")
227+
# return handler
228228

229229
logger.debug(f"No {handler_type} handler found anywhere")
230230
return None

python/model_hosting_container_standards/sagemaker/__init__.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""SageMaker integration decorators."""
22

3-
from typing import List, Optional
3+
from typing import Dict, List, Optional, Union
44

55
from fastapi import FastAPI
66

@@ -18,6 +18,7 @@
1818
SageMakerLoRAApiHeader,
1919
create_lora_transform_decorator,
2020
)
21+
from .lora.models import AppendOperation
2122
from .sagemaker_loader import SageMakerFunctionLoader
2223
from .sagemaker_router import create_sagemaker_router
2324
from .sessions import create_session_transform_decorator
@@ -52,7 +53,9 @@ def register_unload_adapter_handler(
5253
)
5354

5455

55-
def inject_adapter_id(adapter_path: str):
56+
def inject_adapter_id(
57+
adapter_path: str, append: bool = False, separator: Optional[str] = None
58+
):
5659
"""Create a decorator that injects adapter ID from SageMaker headers into request body.
5760
5861
This decorator extracts the adapter identifier from the SageMaker LoRA API header
@@ -63,27 +66,53 @@ def inject_adapter_id(adapter_path: str):
6366
adapter_path: The JSON path where the adapter ID should be injected in the
6467
request body (e.g., "model", "body.model.lora_name", etc.).
6568
Supports both simple keys and nested paths using dot notation.
69+
append: If True, appends the adapter ID to the existing value at adapter_path
70+
using the specified separator. If False (default), replaces the value.
71+
When True, separator parameter is required.
72+
Example with append=True and separator=":":
73+
{"model": "base-model"} -> {"model": "base-model:adapter-123"}
74+
separator: The separator to use when append=True. Required when append=True.
75+
Common values include ":", "-", "_", etc.
6676
6777
Returns:
6878
A decorator function that can be applied to FastAPI route handlers to
6979
automatically inject adapter IDs from headers into the request body.
7080
81+
Raises:
82+
ValueError: If adapter_path is empty or not a string, or if append=True
83+
but separator is not provided.
84+
7185
Note:
7286
This is a transform-only decorator that does not create its own route.
7387
It must be applied to existing route handlers.
7488
"""
7589
# validate and preprocess
7690
if not adapter_path:
77-
logger.exception("adapter_path cannot be empty")
91+
logger.error("adapter_path cannot be empty")
7892
raise ValueError("adapter_path cannot be empty")
7993
if not isinstance(adapter_path, str):
80-
logger.exception("adapter_path must be a string")
94+
logger.error("adapter_path must be a string")
8195
raise ValueError("adapter_path must be a string")
82-
# create request_shape
83-
request_shape = {}
84-
request_shape[adapter_path] = (
85-
f'headers."{SageMakerLoRAApiHeader.ADAPTER_IDENTIFIER}"'
86-
)
96+
if append and separator is None:
97+
logger.error(f"separator must be provided when {append=}")
98+
raise ValueError(f"separator must be provided when {append=}")
99+
if separator and not append:
100+
logger.error(f"separator is specified {separator} but {append=}")
101+
raise ValueError(f"separator is specified {separator} but {append=}")
102+
103+
# create request_shape with operation encoding
104+
request_shape: Dict[str, Union[str, AppendOperation]] = {}
105+
header_expr = f'headers."{SageMakerLoRAApiHeader.ADAPTER_IDENTIFIER}"'
106+
107+
if append:
108+
# Encode append operation as a Pydantic model
109+
request_shape[adapter_path] = AppendOperation(
110+
separator=separator, expression=header_expr
111+
)
112+
else:
113+
# Default replace behavior (backward compatible)
114+
request_shape[adapter_path] = header_expr
115+
87116
return create_lora_transform_decorator(LoRAHandlerType.INJECT_ADAPTER_ID)(
88117
request_shape=request_shape, response_shape=None
89118
)

python/model_hosting_container_standards/sagemaker/handler_resolver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ def __init__(self) -> None:
9696

9797
def register_sagemaker_overrides():
9898
def set_handler(handler_type):
99-
handler_registry.set_handler(
100-
handler_type, _resolver.resolve_handler(handler_type)
101-
)
99+
handler = _resolver.resolve_handler(handler_type)
100+
if handler:
101+
handler_registry.set_handler(handler_type, handler)
102102

103103
set_handler("invoke")
104104
set_handler("ping")

python/model_hosting_container_standards/sagemaker/lora/FACTORY_USAGE.md

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -140,41 +140,54 @@ async def unload_adapter(data: SimpleNamespace, raw_request: Request):
140140
return Response(status_code=200)
141141
```
142142

143-
**3. `inject_adapter_id(adapter_path)`**
143+
**3. `inject_adapter_id(adapter_path, append=False, separator=None)`**
144144

145-
Creates a decorator for injecting adapter IDs from headers into the request body. Takes a simple string path specifying where to inject the adapter ID:
145+
Creates a decorator for injecting adapter IDs from headers into the request body. Supports both replace and append modes:
146146

147147
```python
148148
from model_hosting_container_standards.sagemaker import inject_adapter_id
149149

150+
# Replace mode (default)
150151
@inject_adapter_id("lora_id")
151-
async def inject_adapter_id(raw_request: Request):
152+
async def inject_adapter_replace(raw_request: Request):
152153
# The request body now contains the adapter ID from the header
153154
return Response(status_code=200)
155+
156+
# Append mode
157+
@inject_adapter_id("model", append=True, separator=":")
158+
async def inject_adapter_append(raw_request: Request):
159+
# Appends adapter ID to existing model field
160+
return Response(status_code=200)
154161
```
155162

156163
**How `inject_adapter_id` works:**
157-
- Takes a single `adapter_path` string parameter specifying where to inject the adapter ID in the request body
164+
- Takes an `adapter_path` string parameter specifying where to inject the adapter ID in the request body
158165
- Supports both simple keys (e.g., `"model"`) and nested paths using dot notation (e.g., `"body.model.lora_name"`)
159166
- Automatically extracts the adapter ID from the SageMaker header `X-Amzn-SageMaker-Adapter-Identifier`
160-
- Raises `ValueError` if `adapter_path` is empty or if `adapter_path` is not a string
167+
- **Replace mode (default)**: Replaces the existing value at the target path
168+
- **Append mode**: Appends the adapter ID to existing value using a separator
169+
- Raises `ValueError` if `adapter_path` is empty, not a string, or if `append=True` without `separator`
170+
171+
**Injection Modes:**
161172

162173
```python
163-
# Simple path - injects at top level
174+
# Replace mode (default)
164175
@inject_adapter_id("model")
165-
# Results in: {"model": "<adapter_id>"}
166176

167-
# Nested path - supports dot notation
168-
@inject_adapter_id("body.model.lora_name")
169-
# Results in: {"body": {"model": {"lora_name": "<adapter_id>"}}}
177+
# Append mode with colon separator
178+
@inject_adapter_id("model", append=True, separator=":")
179+
180+
# Custom separators
181+
@inject_adapter_id("model", append=True, separator="-") # Dash
182+
@inject_adapter_id("model", append=True, separator="") # Direct concatenation
170183
```
171184

172185
### Benefits of Convenience Functions
173186

174187
1. **Shorter imports**: Import from `sagemaker` instead of `sagemaker.lora.factory`
175188
2. **Clearer intent**: Function names explicitly state what they do
176189
3. **Less boilerplate**: No need to import and reference `LoRAHandlerType`
177-
4. **Built-in validation**: `inject_adapter_id` validates and auto-fills the header mapping
190+
4. **Built-in validation**: `inject_adapter_id` validates parameters and auto-fills the header mapping
178191
5. **Future-proof**: If the implementation changes, your code doesn't need updates
179192

180193
### When to Use Direct Factory Access
@@ -320,20 +333,23 @@ This example shows how to extract adapter information from HTTP headers and inje
320333
from fastapi import Request, Response
321334
from model_hosting_container_standards.sagemaker import inject_adapter_id
322335

323-
@inject_adapter_id(
324-
request_shape={
325-
"lora_id": None # Value is automatically filled with the SageMaker header
326-
}
327-
)
328-
async def inject_adapter_to_body(raw_request: Request):
329-
"""Inject adapter ID from header into request body for inference.
336+
# Replace mode example
337+
@inject_adapter_id("lora_id")
338+
async def inject_adapter_replace(raw_request: Request):
339+
"""Inject adapter ID from header into request body (replace mode).
330340
331-
This transformer modifies the request body in-place, adding the adapter ID
332-
extracted from the X-Amzn-SageMaker-Adapter-Identifier header.
341+
This transformer modifies the request body in-place, replacing the lora_id
342+
field with the adapter ID from the X-Amzn-SageMaker-Adapter-Identifier header.
333343
"""
334344
# The transformation has already modified raw_request._body
335345
# Just pass it through to the next handler
336346
return Response(status_code=200)
347+
348+
# Append mode example
349+
@inject_adapter_id("model", append=True, separator=":")
350+
async def inject_adapter_append(raw_request: Request):
351+
"""Inject adapter ID using append mode."""
352+
return Response(status_code=200)
337353
```
338354

339355
**SageMaker Request:**
@@ -488,7 +504,9 @@ bootstrap(app)
488504

489505
1. **Use the Convenience Functions:** Always use `register_load_adapter_handler`, `register_unload_adapter_handler`, and `inject_adapter_id` from the `sagemaker` module instead of directly using `create_lora_transform_decorator`. They provide better error messages, validation, and automatic header handling.
490506

491-
2. **Validate Adapter Sources:** Always validate that adapter sources are accessible and in the correct format (S3 paths, local paths, etc.).
507+
2. **Choose the Right Injection Mode:** Use `inject_adapter_id` replace mode (default) for most cases, but use append mode with appropriate separators for frameworks that expect concatenated model names.
508+
509+
3. **Validate Adapter Sources:** Always validate that adapter sources are accessible and in the correct format (S3 paths, local paths, etc.).
492510

493511
3. **Handle Adapter Loading Errors:** Wrap adapter loading in try-except blocks and return appropriate HTTP status codes:
494512
- 400 for invalid requests

0 commit comments

Comments
 (0)