Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2604,15 +2604,30 @@ def data_file_statistics_from_parquet_metadata(
)


def _resolve_row_group_size(arrow_table: pa.Table, row_group_limit: int | None, row_group_size_bytes: int | None) -> int | None:
if not row_group_size_bytes or arrow_table.num_rows == 0:
return row_group_limit
bytes_per_row = max(1, arrow_table.nbytes // arrow_table.num_rows)
rows_for_byte_budget = max(1, row_group_size_bytes // bytes_per_row)
if row_group_limit is None:
return rows_for_byte_budget
return min(row_group_limit, rows_for_byte_budget)


def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, TableProperties

parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)
row_group_size = property_as_int(
row_group_limit = property_as_int(
properties=table_metadata.properties,
property_name=TableProperties.PARQUET_ROW_GROUP_LIMIT,
default=TableProperties.PARQUET_ROW_GROUP_LIMIT_DEFAULT,
)
row_group_size_bytes = property_as_int(
properties=table_metadata.properties,
property_name=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES,
default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT,
)
location_provider = load_location_provider(table_location=table_metadata.location, table_properties=table_metadata.properties)

def write_parquet(task: WriteTask) -> DataFile:
Expand All @@ -2636,6 +2651,7 @@ def write_parquet(task: WriteTask) -> DataFile:
for batch in task.record_batches
]
arrow_table = pa.Table.from_batches(batches)
row_group_size = _resolve_row_group_size(arrow_table, row_group_limit, row_group_size_bytes)
file_path = location_provider.new_data_location(
data_file_name=task.generate_data_file_filename("parquet"),
partition_key=task.partition_key,
Expand Down Expand Up @@ -2819,7 +2835,6 @@ def _get_parquet_writer_kwargs(table_properties: Properties) -> dict[str, Any]:
from pyiceberg.table import TableProperties

for key_pattern in [
TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES,
TableProperties.PARQUET_BLOOM_FILTER_MAX_BYTES,
f"{TableProperties.PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX}.*",
]:
Expand Down
84 changes: 84 additions & 0 deletions tests/io/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
_determine_partitions,
_primitive_to_physical,
_read_deletes,
_resolve_row_group_size,
_task_to_record_batches,
_to_requested_schema,
bin_pack_arrow_table,
Expand Down Expand Up @@ -3045,6 +3046,89 @@ def test_write_file_rejects_timestamptz_to_timestamp(tmp_path: Path) -> None:
list(write_file(io=PyArrowFileIO(), table_metadata=table_metadata, tasks=iter([task])))


@pytest.mark.parametrize(
    "arrow_table,row_group_limit,row_group_size_bytes,expected",
    [
        # Byte limit tighter than row limit — 2 int64 cols => 16 bytes/row,
        # 1024-byte budget => 64 rows/group.
        pytest.param(
            pa.table({"a": list(range(1000)), "b": list(range(1000))}),
            10_000,
            1024,
            64,
            id="byte-limit-tighter",
        ),
        # Row limit tighter than byte limit.
        pytest.param(pa.table({"a": list(range(1000))}), 10, 10**9, 10, id="row-limit-tighter"),
        # Byte limit disabled (0) falls back to the row limit.
        pytest.param(pa.table({"a": list(range(1000))}), 500, 0, 500, id="byte-limit-disabled"),
        # Empty input falls back to the row limit.
        pytest.param(pa.table({"a": pa.array([], type=pa.int64())}), 500, 1024, 500, id="empty-table"),
        # No row limit at all: the byte budget alone decides (16 bytes/row => 64 rows).
        pytest.param(
            pa.table({"a": list(range(1000)), "b": list(range(1000))}),
            None,
            1024,
            64,
            id="no-row-limit",
        ),
    ],
)
def test__resolve_row_group_size(
    arrow_table: pa.Table, row_group_limit: int | None, row_group_size_bytes: int, expected: int
) -> None:
    """Pick min(row_group_limit, bytes/(bytes_per_row)) when byte limit is set.

    Also covers the fallbacks: a disabled byte limit or an empty table returns the
    row limit, and a missing row limit returns the byte-derived row count.
    """
    assert _resolve_row_group_size(arrow_table, row_group_limit, row_group_size_bytes) == expected


def test_write_file_byte_limit_produces_more_row_groups_than_row_limit_alone(tmp_path: Path) -> None:
    """Writing the same data with a small byte target yields several row groups instead of one."""
    from pyiceberg.table import WriteTask

    schema = Schema(
        NestedField(1, "a", LongType(), required=False),
        NestedField(2, "b", LongType(), required=False),
    )
    source = pa.table({"a": list(range(10_000)), "b": list(range(10_000))})

    def _written_file(properties: dict[str, str], subdir: str) -> Path:
        # Each call writes into its own sub-directory of tmp_path so the two
        # runs cannot interfere with each other.
        metadata = TableMetadataV2(
            location=f"file://{tmp_path}/{subdir}",
            last_column_id=2,
            format_version=2,
            schemas=[schema],
            partition_specs=[PartitionSpec()],
            properties=properties,
        )
        write_task = WriteTask(
            write_uuid=uuid.uuid4(),
            task_id=0,
            record_batches=source.to_batches(),
            schema=schema,
        )
        written = list(write_file(io=PyArrowFileIO(), table_metadata=metadata, tasks=iter([write_task])))
        first_path = written[0].file_path
        return Path(first_path.removeprefix("file://"))

    # Baseline: no table properties set.
    unconstrained_path = _written_file({}, "default")
    default_groups = pq.ParquetFile(unconstrained_path).num_row_groups

    # Same data, but with a 1024-byte row-group target.
    byte_limited_path = _written_file({TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES: "1024"}, "constrained")
    constrained_groups = pq.ParquetFile(byte_limited_path).num_row_groups

    assert default_groups == 1
    assert constrained_groups > 1


def test_write_file_byte_limit_respects_row_limit_upper_bound(tmp_path: Path) -> None:
    """A huge byte target must not override the configured row limit."""
    from pyiceberg.table import WriteTask

    schema = Schema(NestedField(1, "a", LongType(), required=False))
    source = pa.table({"a": list(range(10_000))})

    # 10,000 rows with a 1,000-row cap and a byte target far larger than the data.
    properties = {
        TableProperties.PARQUET_ROW_GROUP_LIMIT: "1000",
        TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES: str(10**12),
    }
    metadata = TableMetadataV2(
        location=f"file://{tmp_path}",
        last_column_id=1,
        format_version=2,
        schemas=[schema],
        partition_specs=[PartitionSpec()],
        properties=properties,
    )
    write_task = WriteTask(
        write_uuid=uuid.uuid4(),
        task_id=0,
        record_batches=source.to_batches(),
        schema=schema,
    )

    written = list(write_file(io=PyArrowFileIO(), table_metadata=metadata, tasks=iter([write_task])))
    local_path = written[0].file_path.removeprefix("file://")
    parquet_file = pq.ParquetFile(local_path)

    # 10,000 rows / 1,000 rows per group.
    assert parquet_file.num_row_groups == 10


def test__to_requested_schema_timestamps(
arrow_table_schema_with_all_timestamp_precisions: pa.Schema,
arrow_table_with_all_timestamp_precisions: pa.Table,
Expand Down
Loading