diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d74dc3f..eff9242 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -10,7 +10,7 @@ repos:
         types:
           - python
         args:
-          - "--max-line-length=90"
+          - "--max-line-length=100"
       - id: trailing-whitespace
       - id: end-of-file-fixer
       - id: check-yaml
diff --git a/vetiver/__init__.py b/vetiver/__init__.py
index 31a56bc..ae50039 100644
--- a/vetiver/__init__.py
+++ b/vetiver/__init__.py
@@ -10,7 +10,7 @@
 )  # noqa
 from .vetiver_model import VetiverModel  # noqa
 from .server import VetiverAPI, vetiver_endpoint, predict  # noqa
-from .mock import get_mock_data, get_mock_model  # noqa
+from .mock import get_mock_data, get_mock_model, get_mtcars_model  # noqa
 from .pin_read_write import vetiver_pin_write  # noqa
 from .attach_pkgs import load_pkgs, get_board_pkgs  # noqa
 from .meta import VetiverMeta  # noqa
diff --git a/vetiver/handlers/base.py b/vetiver/handlers/base.py
index 3f0044c..7139f2a 100644
--- a/vetiver/handlers/base.py
+++ b/vetiver/handlers/base.py
@@ -121,7 +121,7 @@ def handler_startup():
         """
         ...
 
-    def handler_predict(self, input_data, check_prototype):
+    def handler_predict(self, input_data, check_prototype, **kw):
         """Generates method for /predict endpoint in VetiverAPI
 
         The `handler_predict` function executes at each API call. Use this
diff --git a/vetiver/handlers/sklearn.py b/vetiver/handlers/sklearn.py
index a940118..fbac67a 100644
--- a/vetiver/handlers/sklearn.py
+++ b/vetiver/handlers/sklearn.py
@@ -16,7 +16,7 @@ class SKLearnHandler(BaseHandler):
     model_class = staticmethod(lambda: sklearn.base.BaseEstimator)
     pip_name = "scikit-learn"
 
-    def handler_predict(self, input_data, check_prototype):
+    def handler_predict(self, input_data, check_prototype: bool, **kw):
         """
         Generates method for /predict endpoint in VetiverAPI
 
@@ -28,16 +28,22 @@ def handler_predict(self, input_data, check_prototype):
         ----------
         input_data:
             Test data
+        check_prototype: bool
+        prediction_type: str
+            Type of prediction to make. One of "predict", "predict_proba",
+            or "predict_log_proba". Default is "predict".
 
         Returns
         -------
         prediction:
             Prediction from model
         """
+        prediction_type = kw.get("prediction_type", "predict")
 
-        if not check_prototype or isinstance(input_data, pd.DataFrame):
-            prediction = self.model.predict(input_data)
-        else:
-            prediction = self.model.predict([input_data])
+        input_data = (
+            [input_data]
+            if check_prototype and not isinstance(input_data, pd.DataFrame)
+            else input_data
+        )
 
-        return prediction.tolist()
+        return getattr(self.model, prediction_type)(input_data).tolist()
diff --git a/vetiver/handlers/spacy.py b/vetiver/handlers/spacy.py
index 80dfdac..eb4d0de 100644
--- a/vetiver/handlers/spacy.py
+++ b/vetiver/handlers/spacy.py
@@ -53,7 +53,7 @@ def construct_prototype(self):
 
         return prototype
 
-    def handler_predict(self, input_data, check_prototype):
+    def handler_predict(self, input_data, check_prototype, **kw):
         """
         Generates method for /predict endpoint in VetiverAPI
 
diff --git a/vetiver/handlers/statsmodels.py b/vetiver/handlers/statsmodels.py
index 084b5ff..ab39289 100644
--- a/vetiver/handlers/statsmodels.py
+++ b/vetiver/handlers/statsmodels.py
@@ -22,7 +22,7 @@ class StatsmodelsHandler(BaseHandler):
     if sm_exists:
         pip_name = "statsmodels"
 
-    def handler_predict(self, input_data, check_prototype):
+    def handler_predict(self, input_data, check_prototype, **kw):
         """
         Generates method for /predict endpoint in VetiverAPI
 
@@ -43,9 +43,7 @@ def handler_predict(self, input_data, check_prototype):
         if not sm_exists:
             raise ImportError("Cannot import `statsmodels`")
 
-        if isinstance(input_data, (list, pd.DataFrame)):
-            prediction = self.model.predict(input_data)
-        else:
-            prediction = self.model.predict([input_data])
-
-        return prediction.tolist()
+        input_data = (
+            input_data if isinstance(input_data, (list, pd.DataFrame)) else [input_data]
+        )
+        return self.model.predict(input_data).tolist()
diff --git a/vetiver/handlers/torch.py b/vetiver/handlers/torch.py
index 15dafa5..625d4cb 100644
--- a/vetiver/handlers/torch.py
+++ b/vetiver/handlers/torch.py
@@ -22,7 +22,7 @@ class TorchHandler(BaseHandler):
     if torch_exists:
         pip_name = "torch"
 
-    def handler_predict(self, input_data, check_prototype):
+    def handler_predict(self, input_data, check_prototype, **kw):
         """
         Generates method for /predict endpoint in VetiverAPI
 
diff --git a/vetiver/handlers/xgboost.py b/vetiver/handlers/xgboost.py
index 9c11234..21bbb6b 100644
--- a/vetiver/handlers/xgboost.py
+++ b/vetiver/handlers/xgboost.py
@@ -22,7 +22,7 @@ class XGBoostHandler(BaseHandler):
     if xgb_exists:
         pip_name = "xgboost"
 
-    def handler_predict(self, input_data, check_prototype):
+    def handler_predict(self, input_data, check_prototype, **kw):
         """
         Generates method for /predict endpoint in VetiverAPI
 
diff --git a/vetiver/mock.py b/vetiver/mock.py
index 780e4a5..0ed3de4 100644
--- a/vetiver/mock.py
+++ b/vetiver/mock.py
@@ -1,7 +1,11 @@
-from sklearn.dummy import DummyRegressor
 import pandas as pd
 import numpy as np
 
+from sklearn.dummy import DummyRegressor
+from sklearn.linear_model import LogisticRegression
+
+from .data import mtcars
+
 
 def get_mock_data():
     """Create mock data for testing
@@ -26,5 +30,17 @@ def get_mock_model():
     model : sklearn.dummy.DummyRegressor
         Arbitrary model for testing purposes
     """
-    model = DummyRegressor()
-    return model
+    return DummyRegressor()
+
+
+def get_mtcars_model():
+    """Create mock model for testing
+
+    Returns
+    -------
+    model : sklearn.linear_model.LogisticRegression
+        Logistic regression fit on the mtcars dataset for testing purposes
+    """
+    return LogisticRegression(max_iter=1000, random_state=500).fit(
+        mtcars.drop(columns="cyl"),
+        mtcars["cyl"]
+    )
diff --git a/vetiver/server.py b/vetiver/server.py
index ea55298..a5be420 100644
--- a/vetiver/server.py
+++ b/vetiver/server.py
@@ -15,11 +15,12 @@
 from fastapi.exceptions import RequestValidationError
 from fastapi.openapi.utils import get_openapi
 from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse
-
 from .helpers import api_data_to_frame, response_to_frame
+from .handlers.sklearn import SKLearnHandler
 from .meta import VetiverMeta
 from .utils import _jupyter_nb, get_workbench_path
 from .vetiver_model import VetiverModel
+from .types import SklearnPredictionTypes
 
 
 class VetiverAPI:
@@ -111,7 +112,6 @@ async def startup_event():
 
         @app.get("/", include_in_schema=False)
         def docs_redirect():
-            redirect = "__docs__"
 
             return RedirectResponse(redirect)
 
@@ -200,65 +200,94 @@ async def validation_exception_handler(request, exc):
 
         return app
 
-    def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
-        """Create new POST endpoint that is aware of model input data
+    def vetiver_post(
+        self,
+        endpoint_fx: Union[Callable, SklearnPredictionTypes],
+        endpoint_name: str = None,
+        **kw,
+    ):
+        """Define a new POST endpoint that utilizes the model's input data.
 
         Parameters
         ----------
-        endpoint_fx : typing.Callable
-            Custom function to be run at endpoint
+        endpoint_fx
+        : Union[typing.Callable, Literal["predict", "predict_proba", "predict_log_proba"]]
+            A callable function that specifies the custom logic to execute when the
+            endpoint is called. This function should take input data (e.g., a DataFrame
+            or dictionary) and return the desired output (e.g., predictions or transformed
+            data). For scikit-learn models, endpoint_fx can also be one of "predict",
+            "predict_proba", or "predict_log_proba" if the model supports these methods.
+
         endpoint_name : str
-            Name of endpoint
+            The name of the endpoint to be created.
 
         Examples
         -------
-        ```{python}
+        ```python
         from vetiver import mock, VetiverModel, VetiverAPI
         X, y = mock.get_mock_data()
         model = mock.get_mock_model().fit(X, y)
-        v = VetiverModel(model = model, model_name = "model", prototype_data = X)
-        v_api = VetiverAPI(model = v, check_prototype = True)
+        v = VetiverModel(model=model, model_name="model", prototype_data=X)
+        v_api = VetiverAPI(model=v, check_prototype=True)
 
         def sum_values(x):
             return x.sum()
+
         v_api.vetiver_post(sum_values, "sums")
         ```
         """
-        if not endpoint_name:
-            endpoint_name = endpoint_fx.__name__
-        if endpoint_fx.__doc__ is not None:
-            api_desc = dedent(endpoint_fx.__doc__)
-        else:
-            api_desc = None
+        if not isinstance(endpoint_fx, Callable):
+            if endpoint_fx not in ["predict", "predict_proba", "predict_log_proba"]:
+                raise ValueError(
+                    f"""
+                    Prediction type {endpoint_fx} not available.
+                    Available prediction types: {SklearnPredictionTypes}
+                    """
+                )
+            if not isinstance(self.model.handler_predict.__self__, SKLearnHandler):
+                raise ValueError(
+                    """
+                    The 'endpoint_fx' parameter can only be a
+                    string when using scikit-learn models.
+                    """
+                )
+            self.vetiver_post(
+                self.model.handler_predict,
+                endpoint_fx,
+                check_prototype=self.check_prototype,
+                prediction_type=endpoint_fx,
+            )
+            return
 
-        if self.check_prototype is True:
+        endpoint_name = endpoint_name or endpoint_fx.__name__
+        endpoint_doc = dedent(endpoint_fx.__doc__) if endpoint_fx.__doc__ else None
 
-            @self.app.post(
-                urljoin("/", endpoint_name),
-                name=endpoint_name,
-                description=api_desc,
-            )
-            async def custom_endpoint(input_data: List[self.model.prototype]):
-                _to_frame = api_data_to_frame(input_data)
-                predictions = endpoint_fx(_to_frame, **kw)
-                if isinstance(predictions, List):
-                    return {endpoint_name: predictions}
-                else:
-                    return predictions
+        # this must be split up this way to preserve the correct type hints for
+        # the input_data schema validation via Pydantic + FastAPI
+        input_data_type = (
+            List[self.model.prototype] if self.check_prototype else Request
+        )
 
-        else:
+        @self.app.post(
+            urljoin("/", endpoint_name),
+            name=endpoint_name,
+            description=endpoint_doc,
+        )
+        async def custom_endpoint(input_data: input_data_type):
 
-            @self.app.post(urljoin("/", endpoint_name))
-            async def custom_endpoint(input_data: Request):
-                served_data = await input_data.json()
-                predictions = endpoint_fx(served_data, **kw)
+            served_data = (
+                api_data_to_frame(input_data)
+                if self.check_prototype
+                else await input_data.json()
+            )
+            predictions = endpoint_fx(served_data, **kw)
 
-                if isinstance(predictions, List):
-                    return {endpoint_name: predictions}
-                else:
-                    return predictions
+            if isinstance(predictions, List):
+                return {endpoint_name: predictions}
+            else:
+                return predictions
 
     def run(self, port: int = 8000, host: str = "127.0.0.1", quiet_open=False, **kw):
         """
diff --git a/vetiver/tests/test_add_endpoint.py b/vetiver/tests/test_add_endpoint.py
deleted file mode 100644
index 5a5f1a2..0000000
--- a/vetiver/tests/test_add_endpoint.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import pytest
-import pandas as pd
-from vetiver import mock, VetiverModel
-
-
-@pytest.fixture()
-def model():
-    X, y = mock.get_mock_data()
-    model = mock.get_mock_model()
-
-    return VetiverModel(model.fit(X, y), "model", prototype_data=X)
-
-
-@pytest.fixture
-def data() -> pd.DataFrame:
-    return pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
-
-
-def test_endpoint_adds(client, data):
-    response = client.post("/sum/", data=data.to_json(orient="records"))
-
-    assert response.status_code == 200
-    assert response.json() == {"sum": [3, 6, 9]}
-
-
-def test_endpoint_adds_no_prototype(client_no_prototype, data):
-
-    data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
-    response = client_no_prototype.post("/sum/", data=data.to_json(orient="records"))
-
-    assert response.status_code == 200
-    assert response.json() == {"sum": [3, 6, 9]}
diff --git a/vetiver/tests/test_server.py b/vetiver/tests/test_server.py
index 97150c0..6d95378 100644
--- a/vetiver/tests/test_server.py
+++ b/vetiver/tests/test_server.py
@@ -1,3 +1,11 @@
+import pytest
+import sys
+import pandas as pd
+import numpy as np
+from fastapi.testclient import TestClient
+from pydantic import BaseModel, conint
+
+from vetiver.data import mtcars
 from vetiver import (
     mock,
     VetiverModel,
@@ -7,24 +15,18 @@
     vetiver_endpoint,
     predict,
 )
-from pydantic import BaseModel, conint
-from fastapi.testclient import TestClient
-import numpy as np
-import pytest
-import sys
 
 
 @pytest.fixture
 def model():
     np.random.seed(500)
-    X, y = mock.get_mock_data()
-    model = mock.get_mock_model().fit(X, y)
+    model = mock.get_mtcars_model()
 
     v = VetiverModel(
         model=model,
-        prototype_data=X,
+        prototype_data=mtcars.drop(columns="cyl"),
         model_name="my_model",
         versioned=None,
-        description="A regression model for testing purposes",
+        description="A logistic regression model for testing purposes",
     )
     return v
 
@@ -82,11 +84,29 @@ def test_get_prototype(client, model):
     assert response.status_code == 200, response.text
     assert response.json() == {
         "properties": {
-            "B": {"example": 55, "type": "integer"},
-            "C": {"example": 65, "type": "integer"},
-            "D": {"example": 17, "type": "integer"},
+            "mpg": {"example": 21.0, "type": "number"},
+            "disp": {"example": 160.0, "type": "number"},
+            "hp": {"example": 110.0, "type": "number"},
+            "drat": {"example": 3.9, "type": "number"},
+            "wt": {"example": 2.62, "type": "number"},
+            "qsec": {"example": 16.46, "type": "number"},
+            "vs": {"example": 0.0, "type": "number"},
+            "am": {"example": 1.0, "type": "number"},
+            "gear": {"example": 4.0, "type": "number"},
+            "carb": {"example": 4.0, "type": "number"},
         },
-        "required": ["B", "C", "D"],
+        "required": [
+            "mpg",
+            "disp",
+            "hp",
+            "drat",
+            "wt",
+            "qsec",
+            "vs",
+            "am",
+            "gear",
+            "carb",
+        ],
         "title": "prototype",
         "type": "object",
     }
@@ -125,3 +145,78 @@ def test_vetiver_endpoint():
     url = vetiver_endpoint(url_raw)
 
     assert url == "http://127.0.0.1:8000/predict"
+
+
+@pytest.fixture
+def data() -> pd.DataFrame:
+    return pd.DataFrame(
+        {
+            "mpg": [20, 20],
+            "disp": [160, 160],
+            "hp": [110, 110],
+            "drat": [3.9, 3.9],
+            "wt": [2.62, 2.62],
+            "qsec": [16.00, 16.00],
+            "vs": [0, 0],
+            "am": [1, 1],
+            "gear": [4, 4],
+            "carb": [4, 4],
+        }
+    )
+
+
+def test_endpoint_adds(client, data):
+
+    response = client.post("/sum/", data=data.to_json(orient="records"))
+
+    assert response.status_code == 200
+    assert response.json() == {"sum": [40, 320, 220, 7.8, 5.24, 32.00, 0, 2, 8, 8]}
+
+
+def test_endpoint_adds_no_prototype(client_no_prototype, data):
+
+    data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
+    response = client_no_prototype.post("/sum/", data=data.to_json(orient="records"))
+
+    assert response.status_code == 200
+    assert response.json() == {"sum": [3, 6, 9]}
+
+
+def test_vetiver_post_sklearn_predict(model, data):
+    api = VetiverAPI(model=model)
+    api.vetiver_post("predict_proba")
+
+    client = TestClient(api.app)
+    response = predict(endpoint="/predict_proba/", data=data, test_client=client)
+
+    assert isinstance(response, pd.DataFrame)
+    assert len(response) == 2
+    # Allow for slight differences in architecture or library versions
+    expected = {
+        "predict_proba": {
+            0: [
+                0.0063,
+                0.9937,
+                3.59e-12,
+            ],
+            1: [
+                0.0063,
+                0.9937,
+                3.59e-12,
+            ],
+        },
+    }
+
+    response_dict = response.to_dict()
+    for key, value in expected["predict_proba"].items():
+        assert response_dict["predict_proba"][key] == pytest.approx(value, rel=1e-2)
+
+
+def test_vetiver_post_invalid_sklearn_type(model):
+    vetiver_api = VetiverAPI(model=model)
+
+    with pytest.raises(
+        ValueError,
+        match="Prediction type invalid_type not available",
+    ):
+        vetiver_api.vetiver_post("invalid_type")
diff --git a/vetiver/types.py b/vetiver/types.py
index a097027..e36eeba 100644
--- a/vetiver/types.py
+++ b/vetiver/types.py
@@ -1,4 +1,5 @@
 from pydantic import BaseModel, create_model
+from typing import Literal
 
 all = ["Prototype", "create_prototype"]
 
@@ -7,5 +8,8 @@ class Prototype(BaseModel):
     pass
 
 
+SklearnPredictionTypes = Literal["predict", "predict_proba", "predict_log_proba"]
+
+
 def create_prototype(**dict_data):
     return create_model("prototype", __base__=Prototype, **dict_data)
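
Usage sketch (editorial note, not part of the diff): the snippet below pieces together the new string-based `vetiver_post` flow from the docstring example and `test_vetiver_post_sklearn_predict` above. The endpoint name and arguments mirror those tests, so treat it as an illustration under those assumptions rather than shipped documentation.

```python
# Sketch: exercising the string-based vetiver_post() against the mtcars mock
# model added in vetiver/mock.py above. Assumes vetiver with this change applied.
from fastapi.testclient import TestClient
from vetiver import mock, VetiverModel, VetiverAPI, predict
from vetiver.data import mtcars

model = mock.get_mtcars_model()
v = VetiverModel(
    model=model,
    model_name="my_model",
    prototype_data=mtcars.drop(columns="cyl"),
)
v_api = VetiverAPI(model=v, check_prototype=True)

# For scikit-learn models, a prediction-method name can now be passed directly;
# this registers a POST /predict_proba/ endpoint backed by model.predict_proba.
v_api.vetiver_post("predict_proba")

# Query the endpoint the same way the new test does, via a FastAPI test client.
client = TestClient(v_api.app)
probs = predict(
    endpoint="/predict_proba/",
    data=mtcars.drop(columns="cyl"),
    test_client=client,
)
print(probs.head())
```

Passing any other non-callable value raises the "Prediction type ... not available" ValueError exercised by `test_vetiver_post_invalid_sklearn_type`, and string values are rejected for non-scikit-learn handlers.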