def _resolve_row_group_size(arrow_table: pa.Table, row_group_limit: int | None, row_group_size_bytes: int | None) -> int | None:
    """Choose how many rows go into each row group when writing ``arrow_table``.

    The byte target is translated into a row count using the table's average
    in-memory row width (``nbytes / num_rows``); the tighter of that count and
    ``row_group_limit`` is returned.

    Args:
        arrow_table: Table about to be written. Only ``num_rows`` and ``nbytes`` are read.
        row_group_limit: Maximum number of rows per row group, or ``None`` for no row cap.
        row_group_size_bytes: Target row-group size in bytes. ``None``, ``0`` and negative
            values all disable the byte target.

    Returns:
        ``row_group_limit`` unchanged when the byte target is disabled or the table is
        empty; otherwise a row count of at least 1 that honours both limits.
    """
    byte_target_disabled = row_group_size_bytes is None or row_group_size_bytes <= 0
    if byte_target_disabled or arrow_table.num_rows == 0:
        # Nothing to derive a row width from (or no byte target): the row cap alone decides.
        # A negative target must land here too; otherwise it would floor-divide to a
        # negative count and be clamped to one row per group.
        return row_group_limit

    # Ceiling division: rounding the average row width *up* keeps the estimate on the
    # safe side, so narrow rows (e.g. 1.9 bytes/row) cannot overshoot the target ~2x.
    bytes_per_row = max(1, -(-arrow_table.nbytes // arrow_table.num_rows))
    # Always allow at least one row, even when a single row is wider than the target.
    rows_for_byte_budget = max(1, row_group_size_bytes // bytes_per_row)

    if row_group_limit is None:
        return rows_for_byte_budget
    return min(row_group_limit, rows_for_byte_budget)
@pytest.mark.parametrize(
    ("arrow_table", "row_group_limit", "row_group_size_bytes", "expected"),
    [
        # Two int64 columns are 16 bytes per row, so a 1024-byte budget allows
        # 64 rows per group, which is tighter than the 10_000-row limit.
        pytest.param(
            pa.table({"a": list(range(1000)), "b": list(range(1000))}),
            10_000,
            1024,
            64,
        ),
        # A huge byte budget leaves the row limit as the binding constraint.
        pytest.param(pa.table({"a": list(range(1000))}), 10, 10**9, 10),
        # A byte budget of 0 means "disabled": the row limit is returned as-is.
        pytest.param(pa.table({"a": list(range(1000))}), 500, 0, 500),
        # With no rows there is no row width to measure: the row limit is returned.
        pytest.param(pa.table({"a": pa.array([], type=pa.int64())}), 500, 1024, 500),
    ],
)
def test__resolve_row_group_size(arrow_table: pa.Table, row_group_limit: int, row_group_size_bytes: int, expected: int) -> None:
    """The resolved size is the tighter of the row limit and the byte-derived row count."""
    resolved = _resolve_row_group_size(arrow_table, row_group_limit, row_group_size_bytes)
    assert resolved == expected
def test_write_file_byte_limit_produces_more_row_groups_than_row_limit_alone(tmp_path: Path) -> None:
    """Writing the same data with a small byte target yields more than one row group."""
    from pyiceberg.table import WriteTask

    columns = (
        NestedField(1, "a", LongType(), required=False),
        NestedField(2, "b", LongType(), required=False),
    )
    table_schema = Schema(*columns)
    values = list(range(10_000))
    arrow_data = pa.table({"a": values, "b": values})

    def _row_groups_written(properties: dict[str, str], subdir: str) -> int:
        # Write arrow_data once under tmp_path/subdir with the given table
        # properties and report how many row groups the resulting file has.
        metadata = TableMetadataV2(
            location=f"file://{tmp_path}/{subdir}",
            last_column_id=2,
            format_version=2,
            schemas=[table_schema],
            partition_specs=[PartitionSpec()],
            properties=properties,
        )
        write_task = WriteTask(
            write_uuid=uuid.uuid4(),
            task_id=0,
            record_batches=arrow_data.to_batches(),
            schema=table_schema,
        )
        written = list(write_file(io=PyArrowFileIO(), table_metadata=metadata, tasks=iter([write_task])))
        local_path = Path(written[0].file_path.removeprefix("file://"))
        return pq.ParquetFile(local_path).num_row_groups

    # Baseline first (no properties), then the byte-constrained write.
    default_groups = _row_groups_written({}, "default")
    constrained_groups = _row_groups_written({TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES: "1024"}, "constrained")

    assert default_groups == 1
    assert constrained_groups > 1
def test_write_file_byte_limit_respects_row_limit_upper_bound(tmp_path: Path) -> None:
    """A very large byte target does not lift the row cap: 10_000 rows / 1000 per group = 10 groups."""
    from pyiceberg.table import WriteTask

    table_schema = Schema(NestedField(1, "a", LongType(), required=False))
    arrow_data = pa.table({"a": list(range(10_000))})

    # Row cap of 1000 combined with a byte target far larger than the data.
    write_properties = {
        TableProperties.PARQUET_ROW_GROUP_LIMIT: "1000",
        TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES: str(10**12),
    }
    metadata = TableMetadataV2(
        location=f"file://{tmp_path}",
        last_column_id=1,
        format_version=2,
        schemas=[table_schema],
        partition_specs=[PartitionSpec()],
        properties=write_properties,
    )
    write_task = WriteTask(
        write_uuid=uuid.uuid4(),
        task_id=0,
        record_batches=arrow_data.to_batches(),
        schema=table_schema,
    )

    written = list(write_file(io=PyArrowFileIO(), table_metadata=metadata, tasks=iter([write_task])))
    parquet_file = pq.ParquetFile(written[0].file_path.removeprefix("file://"))

    assert parquet_file.num_row_groups == 10