Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ repos:
types:
- python
args:
- "--max-line-length=90"
- "--max-line-length=100"
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
Expand Down
2 changes: 1 addition & 1 deletion vetiver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
) # noqa
from .vetiver_model import VetiverModel # noqa
from .server import VetiverAPI, vetiver_endpoint, predict # noqa
from .mock import get_mock_data, get_mock_model # noqa
from .mock import get_mock_data, get_mock_model, get_mtcars_model # noqa
from .pin_read_write import vetiver_pin_write # noqa
from .attach_pkgs import load_pkgs, get_board_pkgs # noqa
from .meta import VetiverMeta # noqa
Expand Down
2 changes: 1 addition & 1 deletion vetiver/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def handler_startup():
"""
...

def handler_predict(self, input_data, check_prototype):
def handler_predict(self, input_data, check_prototype, **kw):
"""Generates method for /predict endpoint in VetiverAPI

The `handler_predict` function executes at each API call. Use this
Expand Down
18 changes: 12 additions & 6 deletions vetiver/handlers/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class SKLearnHandler(BaseHandler):
model_class = staticmethod(lambda: sklearn.base.BaseEstimator)
pip_name = "scikit-learn"

def handler_predict(self, input_data, check_prototype):
def handler_predict(self, input_data, check_prototype: bool, **kw):
"""
Generates method for /predict endpoint in VetiverAPI

Expand All @@ -28,16 +28,22 @@ def handler_predict(self, input_data, check_prototype):
----------
input_data:
Test data
check_prototype: bool
prediction_type: str
Type of prediction to make. One of "predict", "predict_proba",
or "predict_log_proba". Default is "predict".

Returns
-------
prediction:
Prediction from model
"""
prediction_type = kw.get("prediction_type", "predict")

if not check_prototype or isinstance(input_data, pd.DataFrame):
prediction = self.model.predict(input_data)
else:
prediction = self.model.predict([input_data])
input_data = (
[input_data]
if check_prototype and not isinstance(input_data, pd.DataFrame)
else input_data
)

return prediction.tolist()
return getattr(self.model, prediction_type)(input_data).tolist()
2 changes: 1 addition & 1 deletion vetiver/handlers/spacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def construct_prototype(self):

return prototype

def handler_predict(self, input_data, check_prototype):
def handler_predict(self, input_data, check_prototype, **kw):
"""
Generates method for /predict endpoint in VetiverAPI

Expand Down
12 changes: 5 additions & 7 deletions vetiver/handlers/statsmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class StatsmodelsHandler(BaseHandler):
if sm_exists:
pip_name = "statsmodels"

def handler_predict(self, input_data, check_prototype):
def handler_predict(self, input_data, check_prototype, **kw):
"""
Generates method for /predict endpoint in VetiverAPI

Expand All @@ -43,9 +43,7 @@ def handler_predict(self, input_data, check_prototype):
if not sm_exists:
raise ImportError("Cannot import `statsmodels`")

if isinstance(input_data, (list, pd.DataFrame)):
prediction = self.model.predict(input_data)
else:
prediction = self.model.predict([input_data])

return prediction.tolist()
input_data = (
input_data if isinstance(input_data, (list, pd.DataFrame)) else [input_data]
)
return self.model.predict(input_data).tolist()
2 changes: 1 addition & 1 deletion vetiver/handlers/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TorchHandler(BaseHandler):
if torch_exists:
pip_name = "torch"

def handler_predict(self, input_data, check_prototype):
def handler_predict(self, input_data, check_prototype, **kw):
"""
Generates method for /predict endpoint in VetiverAPI

Expand Down
2 changes: 1 addition & 1 deletion vetiver/handlers/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class XGBoostHandler(BaseHandler):
if xgb_exists:
pip_name = "xgboost"

def handler_predict(self, input_data, check_prototype):
def handler_predict(self, input_data, check_prototype, **kw):
"""
Generates method for /predict endpoint in VetiverAPI

Expand Down
22 changes: 19 additions & 3 deletions vetiver/mock.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from sklearn.dummy import DummyRegressor
import pandas as pd
import numpy as np

from sklearn.dummy import DummyRegressor
from sklearn.linear_model import LogisticRegression

from .data import mtcars


def get_mock_data():
"""Create mock data for testing
Expand All @@ -26,5 +30,17 @@ def get_mock_model():
model : sklearn.dummy.DummyRegressor
Arbitrary model for testing purposes
"""
model = DummyRegressor()
return model
return DummyRegressor()


def get_mtcars_model():
"""Create mock model for testing

Returns
-------
model : sklearn.dummy.DummyRegressor
Arbitrary model for testing purposes
"""
return LogisticRegression(max_iter=1000, random_state=500).fit(
mtcars.drop(columns="cyl"), mtcars["cyl"]
)
105 changes: 67 additions & 38 deletions vetiver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from fastapi.exceptions import RequestValidationError
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse

from .helpers import api_data_to_frame, response_to_frame
from .handlers.sklearn import SKLearnHandler
from .meta import VetiverMeta
from .utils import _jupyter_nb, get_workbench_path
from .vetiver_model import VetiverModel
from .types import SklearnPredictionTypes


class VetiverAPI:
Expand Down Expand Up @@ -111,7 +112,6 @@ async def startup_event():

@app.get("/", include_in_schema=False)
def docs_redirect():

redirect = "__docs__"

return RedirectResponse(redirect)
Expand Down Expand Up @@ -200,65 +200,94 @@ async def validation_exception_handler(request, exc):

return app

def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
"""Create new POST endpoint that is aware of model input data
def vetiver_post(
self,
endpoint_fx: Union[Callable, SklearnPredictionTypes],
endpoint_name: str = None,
**kw,
):
"""Define a new POST endpoint that utilizes the model's input data.

Parameters
----------
endpoint_fx : typing.Callable
Custom function to be run at endpoint
endpoint_fx
: Union[typing.Callable, Literal["predict", "predict_proba", "predict_log_proba"]]
A callable function that specifies the custom logic to execute when the
endpoint is called. This function should take input data (e.g., a DataFrame
or dictionary) and return the desired output(e.g., predictions or transformed
data). For scikit-learn models, endpoint_fx can also be one of "predict",
"predict_proba", or "predict_log_proba" if the model supports these methods.

endpoint_name : str
Name of endpoint
The name of the endpoint to be created.

Examples
-------
```{python}
```python
from vetiver import mock, VetiverModel, VetiverAPI
X, y = mock.get_mock_data()
model = mock.get_mock_model().fit(X, y)

v = VetiverModel(model = model, model_name = "model", prototype_data = X)
v_api = VetiverAPI(model = v, check_prototype = True)
v = VetiverModel(model=model, model_name="model", prototype_data=X)
v_api = VetiverAPI(model=v, check_prototype=True)

def sum_values(x):
return x.sum()

v_api.vetiver_post(sum_values, "sums")
```
"""
if not endpoint_name:
endpoint_name = endpoint_fx.__name__

if endpoint_fx.__doc__ is not None:
api_desc = dedent(endpoint_fx.__doc__)
else:
api_desc = None
if not isinstance(endpoint_fx, Callable):
if endpoint_fx not in ["predict", "predict_proba", "predict_log_proba"]:
raise ValueError(
f"""
Prediction type {endpoint_fx} not available.
Available prediction types: {SklearnPredictionTypes}
"""
)
if not isinstance(self.model.handler_predict.__self__, SKLearnHandler):
raise ValueError(
"""
The 'endpoint_fx' parameter can only be a
string when using scikit-learn models.
"""
)
self.vetiver_post(
self.model.handler_predict,
endpoint_fx,
check_prototype=self.check_prototype,
prediction_type=endpoint_fx,
)
return

if self.check_prototype is True:
endpoint_name = endpoint_name or endpoint_fx.__name__
endpoint_doc = dedent(endpoint_fx.__doc__) if endpoint_fx.__doc__ else None

@self.app.post(
urljoin("/", endpoint_name),
name=endpoint_name,
description=api_desc,
)
async def custom_endpoint(input_data: List[self.model.prototype]):
_to_frame = api_data_to_frame(input_data)
predictions = endpoint_fx(_to_frame, **kw)
if isinstance(predictions, List):
return {endpoint_name: predictions}
else:
return predictions
# this must be split up this way to preserve the correct type hints for
# the input_data schema validation via Pydantic + FastAPI
input_data_type = (
List[self.model.prototype] if self.check_prototype else Request
)

else:
@self.app.post(
urljoin("/", endpoint_name),
name=endpoint_name,
description=endpoint_doc,
)
async def custom_endpoint(input_data: input_data_type):

@self.app.post(urljoin("/", endpoint_name))
async def custom_endpoint(input_data: Request):
served_data = await input_data.json()
predictions = endpoint_fx(served_data, **kw)
served_data = (
api_data_to_frame(input_data)
if self.check_prototype
else await input_data.json()
)
predictions = endpoint_fx(served_data, **kw)

if isinstance(predictions, List):
return {endpoint_name: predictions}
else:
return predictions
if isinstance(predictions, List):
return {endpoint_name: predictions}
else:
return predictions

def run(self, port: int = 8000, host: str = "127.0.0.1", quiet_open=False, **kw):
"""
Expand Down
32 changes: 0 additions & 32 deletions vetiver/tests/test_add_endpoint.py

This file was deleted.

Loading
Loading