Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions mp_api/_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Define testing utils that need to imported."""

# pragma: exclude file

from __future__ import annotations

try:
import pytest
except ImportError as exc:
raise ImportError(
"You must `pip install 'mp-api[test]' to use these testing utilities."
) from exc

import os
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from typing import Any

requires_api_key = pytest.mark.skipif(
os.getenv("MP_API_KEY") is None,
reason="No API key found.",
)

NUM_DOCS = 5


def client_search_testing(
search_method: Callable,
excluded_params: list[str],
alt_name_dict: dict[str, str],
custom_field_tests: dict[str, Any],
sub_doc_fields: list[str],
int_bounds: tuple[int, int] = (-100, 100),
float_bounds: tuple[float, float] = (-100.12, 100.12),
):
"""Function to test a client using its search method.
Each parameter is used to query for data, which is then checked.

Args:
search_method (Callable): Client search method
excluded_params (list[str]): List of parameters to exclude from testing
alt_name_dict (dict[str, str]): Alternative names for parameters used in the projection and subsequent data checking
custom_field_tests (dict[str, Any]): Custom queries for specific fields.
sub_doc_fields (list[str]): Prefixes for fields to check in resulting data. Useful when data to be tested is nested.
int_bounds (tuple[int,int]) : integer bounds to use in testing int-type query arguments
float_bounds (tuple[float,float]) : float bounds to use in testing float-type query arguments
"""
if search_method is None:
return
# Get list of parameters
param_tuples = list(search_method.__annotations__.items())

# Query API for each numeric and boolean parameter and check if returned
for entry in param_tuples:
param = entry[0]

if param not in excluded_params + ["return"]:
param_type = entry[1]
q: dict[str, Any] = {"chunk_size": 1, "num_chunks": 1}

if "tuple[int, int]" in param_type:
q[param] = int_bounds
elif "tuple[float, float]" in param_type:
q[param] = float_bounds
elif "bool" in param_type:
q[param] = False
elif param in custom_field_tests:
q[param] = custom_field_tests[param]
else:
raise ValueError(
f"Parameter '{param}' with type '{param_type}' was not "
"properly identified in the generic search method test."
)

if len(docs := search_method(**q)) > 0:
doc = docs[0].model_dump()
else:
raise ValueError("No documents returned")

for sub_field in sub_doc_fields:
if sub_field in doc:
doc = doc[sub_field]

assert doc[alt_name_dict.get(param, param)] is not None


def client_pagination(search_method: Callable, id_name: str):
page_1 = search_method(_page=1, chunk_size=NUM_DOCS, fields=[id_name])
page_2 = search_method(_page=2, chunk_size=NUM_DOCS, fields=[id_name])
assert all(len(results) == NUM_DOCS for results in (page_1, page_2))
assert {str(getattr(doc, id_name)) for doc in page_1}.intersection(
{str(getattr(doc, id_name)) for doc in page_2}
) == set()


def client_sort(search_method: Callable, sort_fields: str | Sequence[str]):
for sort_field in [sort_fields] if isinstance(sort_fields, str) else sort_fields:
asc = search_method(
_page=1, _sort_fields=sort_field, chunk_size=NUM_DOCS, fields=[sort_field]
)
desc = search_method(
_page=1,
_sort_fields=f"-{sort_field}",
chunk_size=NUM_DOCS,
fields=[sort_field],
)

idxs = list(range(NUM_DOCS))
assert sorted(idxs, key=lambda idx: getattr(asc[idx], sort_field)) == idxs

assert (
sorted(
idxs,
key=lambda idx: getattr(desc[idx], sort_field),
reverse=True,
)
== idxs
)
7 changes: 7 additions & 0 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,10 @@ def _submit_requests( # noqa
# No splitting needed - get first page
total_data = {"data": []}
initial_criteria = copy(criteria)
if isinstance(
initial_criteria.get("_page"), int
) and not initial_criteria.get("_per_page"):
initial_criteria["_per_page"] = initial_criteria.get("_limit")
data, total_num_docs = self._submit_request_and_process(
url=url,
verify=True,
Expand Down Expand Up @@ -1438,6 +1442,9 @@ def _search(
# This method should be customized for each end point to give more user friendly,
# documented kwargs.

# If user specifies page, ensure only one chunk is returned
if isinstance(kwargs.get("_page"), int) and num_chunks is None:
num_chunks = 1
return self._get_all_documents(
kwargs,
all_fields=all_fields,
Expand Down
80 changes: 33 additions & 47 deletions mp_api/client/routes/materials/electrodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class BaseElectrodeRester(BaseRester):
primary_key = "battery_id"
_exclude_search_fields: list[str] | None = None

def search( # pragma: ignore
def search(
self,
battery_ids: str | list[str] | None = None,
average_voltage: tuple[float, float] | None = None,
Expand All @@ -39,6 +39,8 @@ def search( # pragma: ignore
chunk_size: int = 1000,
all_fields: bool = True,
fields: list[str] | None = None,
_page: int | None = None,
_sort_fields: str | None = None,
) -> list[InsertionElectrodeDoc | ConversionElectrodeDoc] | list[dict]:
"""Query using a variety of search criteria.

Expand Down Expand Up @@ -77,63 +79,45 @@ def search( # pragma: ignore
all_fields (bool): Whether to return all fields in the document. Defaults to True.
fields (List[str]): List of fields in InsertionElectrodeDoc or ConversionElectrodeDoc to return data for.
Default is battery_id and last_updated if all_fields is False.
_page (int or None) : Page of the results to skip to.
_sort_fields (str or None) : Field to sort on. Including a leading "-" sign will reverse sort order.

Returns:
([InsertionElectrodeDoc or ConversionElectrodeDoc], [dict]) List of insertion/conversion electrode documents or dictionaries.
"""
query_params: dict = defaultdict(dict)

if battery_ids:
if isinstance(battery_ids, str):
battery_ids = [battery_ids]

query_params.update({"battery_ids": ",".join(validate_ids(battery_ids))})

if working_ion:
if isinstance(working_ion, (str, Element)):
working_ion = [working_ion] # type: ignore

query_params.update(
{"working_ion": ",".join([str(ele) for ele in working_ion])} # type: ignore
)

if formula:
if isinstance(formula, str):
formula = [formula]

query_params.update({"formula": ",".join(formula)})

if elements:
query_params.update({"elements": ",".join(elements)})

if num_elements:
if isinstance(num_elements, int):
num_elements = (num_elements, num_elements)
query_params.update(
{"nelements_min": num_elements[0], "nelements_max": num_elements[1]}
)

if exclude_elements:
query_params.update({"exclude_elements": ",".join(exclude_elements)})

for param, value in locals().items():
if (
param
not in [
"__class__",
"self",
"working_ion",
"query_params",
"num_elements",
]
and value
):
if isinstance(value, tuple):
if param not in {"__class__", "self", "query_params"} and value is not None:
if param == "num_elements": # this must come first
if isinstance(num_elements, int):
num_elements = (num_elements, num_elements)
query_params.update(
{
"nelements_min": num_elements[0], # type: ignore[index]
"nelements_max": num_elements[1], # type: ignore[index]
}
)

elif isinstance(value, tuple):
query_params.update(
{f"{param}_min": value[0], f"{param}_max": value[1]}
)
elif param == "battery_ids":
query_params[param] = ",".join(validate_ids(value))
elif param == "working_ion":
query_params["working_ion"] = ",".join(
str(ele)
for ele in (
[value] if isinstance(value, str | Element) else value
)
)
elif param in ("formula", "elements", "exclude_elements"):
query_params[param] = ",".join(
[value] if isinstance(value, str) else value
)
else:
query_params.update({param: value})
query_params[param] = value

excluded_fields = self._exclude_search_fields or []
ignored_fields = {
Expand Down Expand Up @@ -177,4 +161,6 @@ class ConversionElectrodeRester(BaseElectrodeRester):
"stability_charge",
"stability_discharge",
"exclude_elements",
"_page",
"_sort_fields",
]
Loading
Loading