diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index d085c6fd87..f5cab70080 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -16,18 +16,25 @@ # under the License. from __future__ import annotations +import importlib +import json from collections import deque from enum import Enum +from types import UnionType from typing import ( TYPE_CHECKING, Any, + Union, + get_args, + get_origin, + get_type_hints, ) from urllib.parse import quote, unquote from pydantic import ConfigDict, Field, TypeAdapter, field_validator from requests import HTTPError, Session from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt -from typing_extensions import override +from typing_extensions import NotRequired, TypedDict, override from pyiceberg import __version__ from pyiceberg.catalog import BOTOCORE_SESSION, TOKEN, URI, WAREHOUSE_LOCATION, Catalog, PropertiesUpdateSummary @@ -396,6 +403,134 @@ class ListViewsResponse(IcebergBaseModel): _PLANNING_RESPONSE_ADAPTER = TypeAdapter(PlanningResponse) +class ParsedAuthConfig(TypedDict): + auth_type: str + auth_manager_name: str + auth_type_config: dict[str, Any] + + +class AuthConfigEnvelope(TypedDict): + type: str + impl: NotRequired[str] + + +def _get_auth_manager_class(class_or_name: str) -> type[AuthManager]: + if class_or_name in AuthManagerFactory._registry: + return AuthManagerFactory._registry[class_or_name] + + try: + module_path, class_name = class_or_name.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + except Exception as err: + raise ValueError(f"Could not load AuthManager class for '{class_or_name}'") from err + + +def _coerce_auth_option_value(key: str, value: Any, annotation: Any) -> Any: + if not isinstance(value, str): + return value + + origin = get_origin(annotation) + if origin is list: + try: + parsed = json.loads(value) + except json.JSONDecodeError as err: + raise ValueError(f"Failed to parse auth configuration value '{key}' as JSON array") from err + + if not isinstance(parsed, list) or not all(isinstance(item, str) for item in parsed): + raise ValueError(f"auth configuration value '{key}' must be a JSON array of strings") + return parsed + + if origin in (Union, UnionType): + non_none_args = [arg for arg in get_args(annotation) if arg is not type(None)] + if len(non_none_args) == 1: + return _coerce_auth_option_value(key, value, non_none_args[0]) + + if origin is not None: + if origin is list: + try: + parsed = json.loads(value) + except json.JSONDecodeError as err: + raise ValueError(f"Failed to parse auth configuration value '{key}' as JSON array") from err + + if not isinstance(parsed, list) or not all(isinstance(item, str) for item in parsed): + raise ValueError(f"auth configuration value '{key}' must be a JSON array of strings") + return parsed + + if annotation is int: + try: + return int(value) + except ValueError as err: + raise ValueError(f"Failed to parse auth configuration value '{key}' as integer") from err + + return value + + +def _coerce_auth_config_values(class_or_name: str, config: dict[str, Any]) -> dict[str, Any]: + manager_class = _get_auth_manager_class(class_or_name) + hints = get_type_hints(manager_class.__init__) + return {key: _coerce_auth_option_value(key, value, hints.get(key, Any)) for key, value in config.items()} + + +def _load_auth_config_from_properties(properties: Properties) -> AuthConfigEnvelope | dict[str, Any] | None: + raw_auth = properties.get(AUTH) + if isinstance(raw_auth, str): + try: + decoded_auth = json.loads(raw_auth) + except json.JSONDecodeError as e: + raise ValueError("Failed to parse auth configuration as JSON") from e + if decoded_auth is not None and not isinstance(decoded_auth, dict): + raise ValueError("auth configuration must be a dictionary") + return decoded_auth + + if raw_auth is not None: + if not isinstance(raw_auth, dict): + raise ValueError("auth configuration must be a dictionary") + return raw_auth + + if auth_type := properties.get(f"{AUTH}.type"): + type_prefix = f"{AUTH}.{auth_type}." + return { + "type": auth_type, + "impl": properties.get(f"{AUTH}.impl"), + auth_type: { + key[len(type_prefix) :].replace("-", "_"): value + for key, value in properties.items() + if key.startswith(type_prefix) + }, + } + + return None + + +def _resolve_auth_config(auth_config: AuthConfigEnvelope | dict[str, Any]) -> ParsedAuthConfig: + auth_type = auth_config.get("type") + if not isinstance(auth_type, str): + raise ValueError("auth.type must be defined") + + auth_type_config = auth_config.get(auth_type, {}) + if not isinstance(auth_type_config, dict): + raise ValueError(f"auth.{auth_type} must be a dictionary") + + auth_impl = auth_config.get("impl") + if auth_impl is not None and not isinstance(auth_impl, str): + raise ValueError("auth.impl must be a string") + + auth_manager_name = auth_impl or auth_type + + if auth_type == CUSTOM and not auth_impl: + raise ValueError("auth.impl must be specified when using custom auth.type") + + if auth_type != CUSTOM and auth_impl: + raise ValueError("auth.impl can only be specified when using custom auth.type") + + return { + "auth_type": auth_type, + "auth_manager_name": auth_manager_name, + "auth_type_config": auth_type_config, + } + + class RestCatalog(Catalog): uri: str _session: Session @@ -435,20 +570,13 @@ def _create_session(self) -> Session: elif ssl_client_cert := ssl_client.get(CERT): session.cert = ssl_client_cert - if auth_config := self.properties.get(AUTH): - auth_type = auth_config.get("type") - if auth_type is None: - raise ValueError("auth.type must be defined") - auth_type_config = auth_config.get(auth_type, {}) - auth_impl = auth_config.get("impl") - - if auth_type == CUSTOM and not auth_impl: - raise ValueError("auth.impl must be specified when using custom auth.type") - - if auth_type != CUSTOM and auth_impl: - raise ValueError("auth.impl can only be specified when using custom auth.type") - - self._auth_manager = AuthManagerFactory.create(auth_impl or auth_type, auth_type_config) + auth_config = _load_auth_config_from_properties(self.properties) + if auth_config: + resolved_auth = _resolve_auth_config(auth_config) + typed_auth_type_config = _coerce_auth_config_values( + resolved_auth["auth_manager_name"], resolved_auth["auth_type_config"] + ) + self._auth_manager = AuthManagerFactory.create(resolved_auth["auth_manager_name"], typed_auth_type_config) session.auth = AuthManagerAdapter(self._auth_manager) else: self._auth_manager = self._create_legacy_oauth2_auth_manager(session) diff --git a/tests/catalog/test_rest.py b/tests/catalog/test_rest.py index 1eb9f26a56..b6fecb049a 100644 --- a/tests/catalog/test_rest.py +++ b/tests/catalog/test_rest.py @@ -18,6 +18,7 @@ from __future__ import annotations import base64 +import json import os from collections.abc import Callable from typing import Any, cast @@ -2470,6 +2471,186 @@ def test_rest_catalog_oauth2_non_200_token_response(requests_mock: Mocker) -> No RestCatalog("rest", **catalog_properties) # type: ignore +def _rest_catalog_properties_from_environment() -> RecursiveDict: + env_config = Config._from_environment_variables({}) + catalogs = cast(RecursiveDict, env_config["catalog"]) + return cast(RecursiveDict, catalogs["rest"]) + + +@mock.patch.dict( + os.environ, + { + "PYICEBERG_CATALOG__REST__URI": TEST_URI, + "PYICEBERG_CATALOG__REST__AUTH": json.dumps({"type": "basic", "basic": {"username": "one", "password": "two"}}), + }, + clear=True, +) +def test_rest_catalog_with_basic_auth_json_environment_variable(rest_mock: Mocker) -> None: + rest_mock.get(f"{TEST_URI}v1/config", json={"defaults": {}, "overrides": {}}, status_code=200) + + RestCatalog("rest", **_rest_catalog_properties_from_environment()) # type: ignore + + encoded_user_pass = base64.b64encode(b"one:two").decode() + assert rest_mock.last_request.headers["Authorization"] == f"Basic {encoded_user_pass}" + + +@mock.patch.dict( + os.environ, + { + "PYICEBERG_CATALOG__REST__URI": TEST_URI, + "PYICEBERG_CATALOG__REST__AUTH": json.dumps( + { + "type": "oauth2", + "oauth2": { + "client_id": "some_client_id", + "client_secret": "some_client_secret", + "token_url": f"{TEST_URI}oauth2/token", + }, + } + ), + }, + clear=True, +) +def test_rest_catalog_with_oauth2_auth_json_environment_variable(requests_mock: Mocker) -> None: + requests_mock.post( + f"{TEST_URI}oauth2/token", + json={"access_token": TEST_TOKEN, "token_type": "Bearer", "expires_in": 3600}, + status_code=200, + ) + requests_mock.get(f"{TEST_URI}v1/config", json={"defaults": {}, "overrides": {}}, status_code=200) + + catalog = RestCatalog("rest", **_rest_catalog_properties_from_environment()) # type: ignore + + assert catalog.uri == TEST_URI + + +@mock.patch.dict( + os.environ, + { + "PYICEBERG_CATALOG__REST__URI": TEST_URI, + "PYICEBERG_CATALOG__REST__AUTH": "not-valid-json", + }, + clear=True, +) +def test_rest_catalog_with_invalid_json_auth_environment_variable() -> None: + with pytest.raises(ValueError, match="Failed to parse auth configuration as JSON"): + RestCatalog("rest", **_rest_catalog_properties_from_environment()) # type: ignore + + +@mock.patch.dict( + os.environ, + { + "PYICEBERG_CATALOG__REST__URI": TEST_URI, + "PYICEBERG_CATALOG__REST__AUTH__TYPE": "basic", + "PYICEBERG_CATALOG__REST__AUTH__BASIC__USERNAME": "one", + "PYICEBERG_CATALOG__REST__AUTH__BASIC__PASSWORD": "two", + }, + clear=True, +) +def test_rest_catalog_with_basic_auth_flat_environment_variables(rest_mock: Mocker) -> None: + rest_mock.get(f"{TEST_URI}v1/config", json={"defaults": {}, "overrides": {}}, status_code=200) + + RestCatalog("rest", **_rest_catalog_properties_from_environment()) # type: ignore + + encoded_user_pass = base64.b64encode(b"one:two").decode() + assert rest_mock.last_request.headers["Authorization"] == f"Basic {encoded_user_pass}" + + +@mock.patch.dict( + os.environ, + { + "PYICEBERG_CATALOG__REST__URI": TEST_URI, + "PYICEBERG_CATALOG__REST__AUTH__TYPE": "oauth2", + "PYICEBERG_CATALOG__REST__AUTH__OAUTH2__CLIENT_ID": "some_client_id", + "PYICEBERG_CATALOG__REST__AUTH__OAUTH2__CLIENT_SECRET": "some_client_secret", + "PYICEBERG_CATALOG__REST__AUTH__OAUTH2__TOKEN_URL": f"{TEST_URI}oauth2/token", + }, + clear=True, +) +def test_rest_catalog_with_oauth2_auth_flat_environment_variables(requests_mock: Mocker) -> None: + requests_mock.post( + f"{TEST_URI}oauth2/token", + json={"access_token": TEST_TOKEN, "token_type": "Bearer", "expires_in": 3600}, + status_code=200, + ) + requests_mock.get(f"{TEST_URI}v1/config", json={"defaults": {}, "overrides": {}}, status_code=200) + + catalog = RestCatalog("rest", **_rest_catalog_properties_from_environment()) # type: ignore + + assert catalog.uri == TEST_URI + + +@pytest.mark.parametrize( + "auth_type, env_overrides, expected_config", + [ + pytest.param( + "oauth2", + { + "PYICEBERG_CATALOG__REST__AUTH__OAUTH2__CLIENT_ID": "some_client_id", + "PYICEBERG_CATALOG__REST__AUTH__OAUTH2__CLIENT_SECRET": "some_client_secret", + "PYICEBERG_CATALOG__REST__AUTH__OAUTH2__TOKEN_URL": f"{TEST_URI}oauth2/token", + "PYICEBERG_CATALOG__REST__AUTH__OAUTH2__REFRESH_MARGIN": "90", + "PYICEBERG_CATALOG__REST__AUTH__OAUTH2__EXPIRES_IN": "3600", + }, + { + "client_id": "some_client_id", + "client_secret": "some_client_secret", + "token_url": f"{TEST_URI}oauth2/token", + "refresh_margin": 90, + "expires_in": 3600, + }, + id="oauth2-numeric-fields", + ), + pytest.param( + "google", + { + "PYICEBERG_CATALOG__REST__AUTH__GOOGLE__CREDENTIALS_PATH": "/fake/path.json", + "PYICEBERG_CATALOG__REST__AUTH__GOOGLE__SCOPES": '["scope-a", "scope-b"]', + }, + { + "credentials_path": "/fake/path.json", + "scopes": ["scope-a", "scope-b"], + }, + id="google-scopes", + ), + pytest.param( + "entra", + { + "PYICEBERG_CATALOG__REST__AUTH__ENTRA__SCOPES": '["scope-a", "scope-b"]', + }, + { + "scopes": ["scope-a", "scope-b"], + }, + id="entra-scopes", + ), + ], +) +def test_rest_catalog_with_typed_auth_flat_environment_variables( + rest_mock: Mocker, + auth_type: str, + env_overrides: dict[str, str], + expected_config: dict[str, Any], +) -> None: + rest_mock.get(f"{TEST_URI}v1/config", json={"defaults": {}, "overrides": {}}, status_code=200) + + fake_auth_manager = mock.Mock() + fake_auth_manager.auth_header.return_value = "" + env = { + "PYICEBERG_CATALOG__REST__URI": TEST_URI, + "PYICEBERG_CATALOG__REST__AUTH__TYPE": auth_type, + **env_overrides, + } + + with ( + mock.patch.dict(os.environ, env, clear=True), + mock.patch("pyiceberg.catalog.rest.AuthManagerFactory.create", return_value=fake_auth_manager) as create_auth_manager, + ): + catalog = RestCatalog("rest", **_rest_catalog_properties_from_environment()) # type: ignore + + assert catalog.uri == TEST_URI + assert create_auth_manager.call_args_list == [mock.call(auth_type, expected_config), mock.call(auth_type, expected_config)] + + EXAMPLE_ENV = {"PYICEBERG_CATALOG__PRODUCTION__URI": TEST_URI}