diff --git a/agentplatform/_genai/_datasets_utils.py b/agentplatform/_genai/_datasets_utils.py
index 410267d052..de8d60327f 100644
--- a/agentplatform/_genai/_datasets_utils.py
+++ b/agentplatform/_genai/_datasets_utils.py
@@ -280,6 +280,73 @@ async def save_dataframe_to_bigquery_async(
     await asyncio.to_thread(bq_client.delete_table, temp_table_id)
 
 
+def load_dataframe_from_bigquery(
+    *,
+    bigquery_uri: str,
+    project: str,
+    location: str,
+    credentials: google.auth.credentials.Credentials,
+) -> "bigframes.pandas.DataFrame":  # type: ignore # noqa: F821
+    """Loads a BigQuery table into a BigFrames DataFrame.
+
+    Opens a BigFrames session configured with the given project, location and
+    credentials, and reads the table through the session's `read_gbq`.
+
+    Args:
+        bigquery_uri: The URI of the BigQuery table, with or without the `bq://`
+            prefix.
+        project: The project to use for the BigFrames session.
+        location: The location to use for the BigFrames session.
+        credentials: The credentials to use for the BigFrames session.
+
+    Returns:
+        A BigFrames DataFrame backed by the BigQuery table.
+    """
+    bigframes = _try_import_bigframes()
+    session_options = bigframes.BigQueryOptions(
+        credentials=credentials,
+        project=project,
+        location=location,
+    )
+    # NOTE(review): the DataFrame is returned from inside the `with` block, so
+    # the session context has already been exited by the time the caller
+    # receives it. Confirm the DataFrame stays usable after that exit.
+    with bigframes.connect(session_options) as session:
+        # `read_gbq` receives the table id with any leading `bq://` stripped.
+        return session.read_gbq(bigquery_uri.removeprefix("bq://"))
+
+
+async def load_dataframe_from_bigquery_async(
+    *,
+    bigquery_uri: str,
+    project: str,
+    location: str,
+    credentials: google.auth.credentials.Credentials,
+) -> "bigframes.pandas.DataFrame":  # type: ignore # noqa: F821
+    """Asynchronously loads a BigQuery table into a BigFrames DataFrame.
+
+    Runs `load_dataframe_from_bigquery` in a separate thread via
+    `asyncio.to_thread` and awaits its result; arguments are forwarded as-is.
+
+    Args:
+        bigquery_uri: The URI of the BigQuery table, with or without the `bq://`
+            prefix.
+        project: The project to use for the BigFrames session.
+        location: The location to use for the BigFrames session.
+        credentials: The credentials to use for the BigFrames session.
+
+    Returns:
+        A BigFrames DataFrame backed by the BigQuery table.
+    """
+    return await asyncio.to_thread(
+        load_dataframe_from_bigquery,
+        bigquery_uri=bigquery_uri,
+        project=project,
+        location=location,
+        credentials=credentials,
+    )
+
+
 def resolve_dataset_name(resource_name_or_id: str, project: str, location: str) -> str:
     """Resolves a dataset name or ID to a full resource name."""
     if "/" not in resource_name_or_id:
diff --git a/agentplatform/_genai/datasets.py b/agentplatform/_genai/datasets.py
index 77cc3277e8..277efe786f 100644
--- a/agentplatform/_genai/datasets.py
+++ b/agentplatform/_genai/datasets.py
@@ -1324,8 +1324,9 @@ def assemble(
         gemini_request_read_config: Optional[
             types.GeminiRequestReadConfigOrDict
         ] = None,
+        load_dataframe: bool = True,
         config: Optional[types.AssembleDatasetConfigOrDict] = None,
-    ) -> str:
+    ) -> tuple[str, Optional["bigframes.pandas.DataFrame"]]:  # type: ignore # noqa: F821
         """Assemble the dataset into a BigQuery table.
 
         Waits for the assemble operation to complete before returning.
@@ -1338,12 +1339,20 @@ def assemble(
                 Optional. The read config to use to assemble the dataset. If
                 not provided, the read config attached to the dataset will be
                 used.
+            load_dataframe:
+                Optional. Whether to load the assembled BigQuery table into a
+                BigFrames DataFrame and return it. If False, the returned
+                DataFrame is None and no BigQuery read is performed. Defaults to
+                True.
             config:
                 Optional. A configuration for assembling the dataset. If not
                 provided, the default configuration will be used.
 
         Returns:
-            The URI of the bigquery table of the assembled dataset.
+            A tuple `(table_id, dataframe)`, where `table_id` is the BigQuery
+            table id of the assembled dataset (without the `bq://` prefix) and
+            `dataframe` is the assembled table loaded as a BigFrames DataFrame.
+            `dataframe` is None if `load_dataframe` is False.
         """
         if isinstance(config, dict):
             config = types.AssembleDatasetConfig(**config)
@@ -1363,7 +1372,17 @@ def assemble(
             operation=operation,
             timeout_seconds=config.timeout,
         )
-        return response["bigqueryDestination"]  # type: ignore[no-any-return]
+        bigquery_uri = response["bigqueryDestination"]
+        table_id = bigquery_uri.removeprefix("bq://")
+        dataframe = None
+        if load_dataframe:
+            dataframe = _datasets_utils.load_dataframe_from_bigquery(
+                bigquery_uri=bigquery_uri,
+                project=self._api_client.project,
+                location=self._api_client.location,
+                credentials=self._api_client._credentials,
+            )
+        return (table_id, dataframe)
 
     def assess_tuning_resources(
         self,
@@ -2713,8 +2732,9 @@ async def assemble(
         gemini_request_read_config: Optional[
             types.GeminiRequestReadConfigOrDict
         ] = None,
+        load_dataframe: bool = True,
         config: Optional[types.AssembleDatasetConfigOrDict] = None,
-    ) -> str:
+    ) -> tuple[str, Optional["bigframes.pandas.DataFrame"]]:  # type: ignore # noqa: F821
         """Assemble the dataset into a BigQuery table.
 
         Waits for the assemble operation to complete before returning.
@@ -2727,12 +2747,20 @@ async def assemble(
                 Optional. The read config to use to assemble the dataset. If
                 not provided, the read config attached to the dataset will be
                 used.
+            load_dataframe:
+                Optional. Whether to load the assembled BigQuery table into a
+                BigFrames DataFrame and return it. If False, the returned
+                DataFrame is None and no BigQuery read is performed. Defaults to
+                True.
             config:
                 Optional. A configuration for assembling the dataset. If not
                 provided, the default configuration will be used.
 
         Returns:
-            The URI of the bigquery table of the assembled dataset.
+            A tuple `(table_id, dataframe)`, where `table_id` is the BigQuery
+            table id of the assembled dataset (without the `bq://` prefix) and
+            `dataframe` is the assembled table loaded as a BigFrames DataFrame.
+            `dataframe` is None if `load_dataframe` is False.
         """
         if isinstance(config, dict):
             config = types.AssembleDatasetConfig(**config)
@@ -2752,7 +2780,17 @@ async def assemble(
             operation=operation,
             timeout_seconds=config.timeout,
         )
-        return response["bigqueryDestination"]  # type: ignore[no-any-return]
+        bigquery_uri = response["bigqueryDestination"]
+        table_id = bigquery_uri.removeprefix("bq://")
+        dataframe = None
+        if load_dataframe:
+            dataframe = await _datasets_utils.load_dataframe_from_bigquery_async(
+                bigquery_uri=bigquery_uri,
+                project=self._api_client.project,
+                location=self._api_client.location,
+                credentials=self._api_client._credentials,
+            )
+        return (table_id, dataframe)
 
     async def assess_tuning_resources(
         self,
diff --git a/tests/unit/agentplatform/genai/replays/test_assemble_multimodal_datasets.py b/tests/unit/agentplatform/genai/replays/test_assemble_multimodal_datasets.py
index 2eb61e7ea9..d532e7a0a1 100644
--- a/tests/unit/agentplatform/genai/replays/test_assemble_multimodal_datasets.py
+++ b/tests/unit/agentplatform/genai/replays/test_assemble_multimodal_datasets.py
@@ -14,7 +14,10 @@
 #
 # pylint: disable=protected-access,bad-continuation,missing-function-docstring
 
+from unittest import mock
+
 from tests.unit.agentplatform.genai.replays import pytest_helper
+from agentplatform._genai import _datasets_utils
 from agentplatform._genai import types
 
 import pytest
@@ -26,6 +29,24 @@
 DATASET = "projects/vertex-sdk-dev/locations/us-central1/datasets/8810841321427173376"
 
 
+@pytest.fixture
+def mock_import_bigframes(is_replay_mode):
+    # `assemble` reads the assembled table via a bigframes session, which is not
+    # part of the recorded Vertex interactions, so it must be mocked in replay
+    # mode.
+    if is_replay_mode:
+        with mock.patch.object(
+            _datasets_utils, "_try_import_bigframes"
+        ) as mock_import_bigframes:
+            bigframes = mock.MagicMock()
+            session = bigframes.connect.return_value.__enter__.return_value
+            session.read_gbq.return_value = mock.MagicMock()
+            mock_import_bigframes.return_value = bigframes
+            yield mock_import_bigframes
+    else:
+        yield None
+
+
 def test_assemble_dataset(client):
     operation = client.datasets._assemble_multimodal_dataset(
         name=DATASET,
@@ -38,8 +59,9 @@ def test_assemble_dataset(client):
     assert isinstance(operation, types.MultimodalDatasetOperation)
 
 
+@pytest.mark.usefixtures("mock_import_bigframes")
 def test_assemble_dataset_public(client):
-    bigquery_destination = client.datasets.assemble(
+    table_id, dataframe = client.datasets.assemble(
         name=DATASET,
         gemini_request_read_config=types.GeminiRequestReadConfig(
             template_config=types.GeminiTemplateConfig(
@@ -54,8 +76,11 @@ def test_assemble_dataset_public(client):
                 ),
             )
         ),
+        load_dataframe=True,
     )
-    assert bigquery_destination.startswith(f"bq://{BIGQUERY_TABLE_NAME}")
+    assert table_id.startswith(BIGQUERY_TABLE_NAME)
+    assert not table_id.startswith("bq://")
+    assert dataframe is not None
 
 
 pytestmark = pytest_helper.setup(
@@ -80,8 +105,9 @@ async def test_assemble_dataset_async(client):
 
 
 @pytest.mark.asyncio
+@pytest.mark.usefixtures("mock_import_bigframes")
 async def test_assemble_dataset_public_async(client):
-    bigquery_destination = await client.aio.datasets.assemble(
+    table_id, dataframe = await client.aio.datasets.assemble(
         name=DATASET,
         gemini_request_read_config=types.GeminiRequestReadConfig(
             template_config=types.GeminiTemplateConfig(
@@ -97,4 +123,6 @@ async def test_assemble_dataset_public_async(client):
             )
         ),
     )
-    assert bigquery_destination.startswith(f"bq://{BIGQUERY_TABLE_NAME}")
+    assert table_id.startswith(BIGQUERY_TABLE_NAME)
+    assert not table_id.startswith("bq://")
+    assert dataframe is not None