Skip to content

Commit 4c052bd

Browse files
authored
Add more exact types to request context interfaces and validate them (#465)
If we don't validate ourselfs then we'd be getting internal errors now because we're using pydantic models to pass user data from request context client to server.
1 parent c02e906 commit 4c052bd

File tree

5 files changed

+41
-9
lines changed

5 files changed

+41
-9
lines changed

src/tensorlake/applications/interface/request_context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def get(self, key: str, default: Any | None = None) -> Any | None:
3232
class RequestMetrics:
3333
"""Abstract interface for reporting application request metrics."""
3434

35-
def timer(self, name: str, value: float) -> None:
35+
def timer(self, name: str, value: int | float) -> None:
3636
"""Records a duration metric with the supplied name and value.
3737
3838
Raises TensorlakeError on error.
@@ -54,8 +54,8 @@ class FunctionProgress:
5454

5555
def update(
5656
self,
57-
current: float,
58-
total: float,
57+
current: int | float,
58+
total: int | float,
5959
message: str | None = None,
6060
attributes: dict[str, str] | None = None,
6161
) -> None:

src/tensorlake/applications/request_context/http_client/metrics.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,14 @@ def __init__(self, request_id: str, allocation_id: str, http_client: httpx.Clien
2222
self._allocation_id: str = allocation_id
2323
self._http_client: httpx.Client = http_client
2424

25-
def timer(self, name: str, value: float):
25+
def timer(self, name: str, value: int | float):
26+
# If we don't validate user supplied inputs here then there will be a Pydantic validation error
27+
# below which will raise an InternalError instead of SDKUsageError.
28+
if not isinstance(name, str):
29+
raise SDKUsageError(f"Timer name must be a string, got: {name}")
30+
if not isinstance(value, (int, float)):
31+
raise SDKUsageError(f"Timer value must be a number, got: {value}")
32+
2633
request_payload: AddMetricsRequest = AddMetricsRequest(
2734
request_id=self._request_id,
2835
allocation_id=self._allocation_id,
@@ -32,6 +39,13 @@ def timer(self, name: str, value: float):
3239
self._run_add_request(request_payload)
3340

3441
def counter(self, name: str, value: int = 1):
42+
# If we don't validate user supplied inputs here then there will be a Pydantic validation error
43+
# below which will raise an InternalError instead of SDKUsageError.
44+
if not isinstance(name, str):
45+
raise SDKUsageError(f"Counter name must be a string, got: {name}")
46+
if not isinstance(value, int):
47+
raise SDKUsageError(f"Counter value must be an int, got: {value}")
48+
3549
request_payload: AddMetricsRequest = AddMetricsRequest(
3650
request_id=self._request_id,
3751
allocation_id=self._allocation_id,

src/tensorlake/applications/request_context/http_client/progress.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import httpx
44

5-
from ...interface.exceptions import InternalError, SDKUsageError, SerializationError
5+
from ...interface.exceptions import InternalError, SDKUsageError
66
from ...interface.request_context import FunctionProgress
77
from ..http_server.handlers.progress_update import (
88
PROGRESS_UPDATE_PATH,
@@ -24,12 +24,20 @@ def __init__(self, request_id: str, allocation_id: str, http_client: httpx.Clien
2424

2525
def update(
2626
self,
27-
current: float,
28-
total: float,
27+
current: int | float,
28+
total: int | float,
2929
message: str | None = None,
3030
attributes: dict[str, str] | None = None,
3131
) -> None:
32-
# Instead of handling serialization errors on the Server, just validate attributes on client side.
32+
# If we don't validate user supplied inputs here then there will be a Pydantic validation error
33+
# below which will raise an InternalError instead of SDKUsageError.
34+
if not isinstance(current, (int, float)):
35+
raise SDKUsageError(f"'current' needs to be a number, got: {current}")
36+
if not isinstance(total, (int, float)):
37+
raise SDKUsageError(f"'total' needs to be a number, got: {total}")
38+
if message is not None and not isinstance(message, str):
39+
raise SDKUsageError(f"'message' needs to be a string, got: {message}")
40+
3341
if attributes is not None:
3442
if not isinstance(attributes, dict):
3543
raise SDKUsageError(

src/tensorlake/applications/request_context/http_client/state.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ def set(self, key: str, value: Any) -> None:
5353
# NB: This is called from user code, user code is blocked.
5454
# Any exception raised here goes directly to user code.
5555

56+
# If we don't validate user supplied inputs here then there will be a Pydantic validation error
57+
# below which will raise an InternalError instead of SDKUsageError.
58+
if not isinstance(key, str):
59+
raise SDKUsageError(f"State key must be a string, got: {key}")
5660
# Raises SerializationError to customer code on failure.
5761
serialized_value: bytes = REQUEST_STATE_USER_DATA_SERIALIZER.serialize(value)
5862

@@ -75,6 +79,12 @@ def set(self, key: str, value: Any) -> None:
7579
def get(self, key: str, default: Any | None = None) -> Any | None:
7680
# NB: This is called from user code, user code is blocked.
7781
# Any exception raised here goes directly to user code.
82+
83+
# If we don't validate user supplied inputs here then there will be a Pydantic validation error
84+
# below which will raise an InternalError instead of SDKUsageError.
85+
if not isinstance(key, str):
86+
raise SDKUsageError(f"State key must be a string, got: {key}")
87+
7888
try:
7989
blob: BLOB | None = self._get_read_only_blob(key=key)
8090
if blob is None:

src/tensorlake/applications/request_context/http_server/handlers/add_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
class AddTimerRequest(BaseModel):
1010
name: str
11-
value: float
11+
value: int | float
1212

1313

1414
class AddCounterRequest(BaseModel):

0 commit comments

Comments
 (0)