Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 49 additions & 12 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@
from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
from pyiceberg.table.puffin import PuffinFile
from pyiceberg.transforms import IdentityTransform, TruncateTransform
from pyiceberg.typedef import EMPTY_DICT, Properties, Record, TableVersion
from pyiceberg.typedef import EMPTY_DICT, ArrowStreamExportable, Properties, Record, TableVersion
from pyiceberg.types import (
BinaryType,
BooleanType,
Expand Down Expand Up @@ -2680,30 +2680,45 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[list[
"""Bin-pack ``tbl`` into groups of RecordBatches, each ~``target_file_size``.

Note:
``target_file_size`` is measured in **uncompressed in-memory** Arrow bytes
(``Table.nbytes`` / ``RecordBatch.nbytes``), not compressed on-disk Parquet
bytes. The resulting Parquet file after compression (zstd by default,
plus dictionary/RLE encoding) is typically 3-10× smaller than
``target_file_size``. This is a coarse proxy for the spec-defined
``target_file_size`` is measured in **uncompressed in-memory** Arrow
bytes, not compressed on-disk Parquet bytes. The size estimate uses
``nbytes`` when available and falls back to referenced buffer size for
Arrow view types that do not support ``nbytes``. The resulting Parquet
file after compression (zstd by default, plus dictionary/RLE encoding)
is typically 3-10× smaller than ``target_file_size``. This is a coarse
proxy for the spec-defined
``write.target-file-size-bytes`` and will be tightened to true on-disk
bytes once the writer is switched to a rolling-``ParquetWriter`` with
``OutputStream.tell()`` (#2998).
"""
from pyiceberg.utils.bin_packing import PackingIterator

avg_row_size_bytes = tbl.nbytes / tbl.num_rows
avg_row_size_bytes = _arrow_data_size(tbl) / tbl.num_rows
target_rows_per_file = max(1, int(target_file_size / avg_row_size_bytes))
batches = tbl.to_batches(max_chunksize=target_rows_per_file)
bin_packed_record_batches = PackingIterator(
items=batches,
target_weight=target_file_size,
lookback=len(batches), # ignore lookback
weight_func=lambda x: x.nbytes,
weight_func=_arrow_data_size,
largest_bin_first=False,
)
return bin_packed_record_batches


def _arrow_data_size(data: pa.Table | pa.RecordBatch) -> int:
"""Estimate Arrow data size for writer bin-packing.

``nbytes`` is the better logical-size estimate, but PyArrow can raise for
view types such as ``string_view`` exported by libraries like Polars. Fall
back to total referenced buffer size so those streams can still be written.
"""
try:
return data.nbytes
except pyarrow.lib.ArrowTypeError:
return data.get_total_buffer_size()


def bin_pack_record_batches(batches: Iterable[pa.RecordBatch], target_file_size: int) -> Iterator[list[pa.RecordBatch]]:
"""Microbatch a single-pass stream of RecordBatches into target-sized groups.

Expand All @@ -2719,9 +2734,11 @@ def bin_pack_record_batches(batches: Iterable[pa.RecordBatch], target_file_size:

Note:
``target_file_size`` is measured in **uncompressed in-memory** Arrow
bytes (``RecordBatch.nbytes``), not compressed on-disk Parquet bytes.
The resulting Parquet file after compression is typically 3-10×
smaller than ``target_file_size``. Matches the existing
bytes, not compressed on-disk Parquet bytes. The size estimate uses
``nbytes`` when available and falls back to referenced buffer size for
Arrow view types that do not support ``nbytes``. The resulting Parquet
file after compression is typically 3-10× smaller than
``target_file_size``. Matches the existing
:func:`bin_pack_arrow_table` semantics; both will be tightened to true
on-disk bytes once the writer is switched to a rolling-
``ParquetWriter`` with ``OutputStream.tell()`` (#2998).
Expand All @@ -2730,7 +2747,7 @@ def bin_pack_record_batches(batches: Iterable[pa.RecordBatch], target_file_size:
buffer_bytes = 0
for batch in batches:
buffer.append(batch)
buffer_bytes += batch.nbytes
buffer_bytes += _arrow_data_size(batch)
if buffer_bytes >= target_file_size:
yield buffer
buffer = []
Expand Down Expand Up @@ -3033,3 +3050,23 @@ def _get_field_from_arrow_table(arrow_table: pa.Table, field_path: str) -> pa.Ar
field_array = arrow_table[path_parts[0]]
# Navigate into the struct using the remaining path parts
return pc.struct_field(field_array, path_parts[1:])


def _coerce_arrow_input(df: pa.Table | pa.RecordBatchReader | ArrowStreamExportable) -> pa.Table | pa.RecordBatchReader:
"""Normalize Arrow write input to a pa.Table or pa.RecordBatchReader.

Native pyarrow inputs pass through unchanged; any object implementing the
Arrow PyCapsule stream interface (``__arrow_c_stream__``) is imported as a
streaming RecordBatchReader.
"""
if isinstance(df, (pa.Table, pa.RecordBatchReader)):
return df

# Any object implementing the Arrow PyCapsule stream interface.
if hasattr(df, "__arrow_c_stream__"):
return pa.RecordBatchReader.from_stream(df)

raise ValueError(
f"Expected pa.Table, pa.RecordBatchReader, or an object implementing the "
f"Arrow PyCapsule interface (__arrow_c_stream__), got: {df!r}"
)
27 changes: 17 additions & 10 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import (
EMPTY_DICT,
ArrowStreamExportable,
IcebergBaseModel,
IcebergRootModel,
Identifier,
Expand Down Expand Up @@ -452,7 +453,7 @@ def update_statistics(self) -> UpdateStatistics:

def append(
self,
df: pa.Table | pa.RecordBatchReader,
df: pa.Table | pa.RecordBatchReader | ArrowStreamExportable,
snapshot_properties: dict[str, str] = EMPTY_DICT,
branch: str | None = MAIN_BRANCH,
) -> None:
Expand Down Expand Up @@ -505,10 +506,9 @@ def append(
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files
from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _coerce_arrow_input, _dataframe_to_data_files

if not isinstance(df, (pa.Table, pa.RecordBatchReader)):
raise ValueError(f"Expected pa.Table or pa.RecordBatchReader, got: {df}")
df = _coerce_arrow_input(df)

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_pyarrow_schema_compatible(
Expand Down Expand Up @@ -598,7 +598,7 @@ def dynamic_partition_overwrite(

def overwrite(
self,
df: pa.Table | pa.RecordBatchReader,
df: pa.Table | pa.RecordBatchReader | ArrowStreamExportable,
overwrite_filter: BooleanExpression | str = ALWAYS_TRUE,
snapshot_properties: dict[str, str] = EMPTY_DICT,
case_sensitive: bool = True,
Expand Down Expand Up @@ -662,10 +662,9 @@ def overwrite(
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files
from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _coerce_arrow_input, _dataframe_to_data_files

if not isinstance(df, (pa.Table, pa.RecordBatchReader)):
raise ValueError(f"Expected pa.Table or pa.RecordBatchReader, got: {df}")
df = _coerce_arrow_input(df)

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_pyarrow_schema_compatible(
Expand Down Expand Up @@ -1472,7 +1471,7 @@ def upsert(

def append(
self,
df: pa.Table | pa.RecordBatchReader,
df: pa.Table | pa.RecordBatchReader | ArrowStreamExportable,
snapshot_properties: dict[str, str] = EMPTY_DICT,
branch: str | None = MAIN_BRANCH,
) -> None:
Expand Down Expand Up @@ -1507,7 +1506,7 @@ def dynamic_partition_overwrite(

def overwrite(
self,
df: pa.Table | pa.RecordBatchReader,
df: pa.Table | pa.RecordBatchReader | ArrowStreamExportable,
overwrite_filter: BooleanExpression | str = ALWAYS_TRUE,
snapshot_properties: dict[str, str] = EMPTY_DICT,
case_sensitive: bool = True,
Expand Down Expand Up @@ -1716,6 +1715,10 @@ def __datafusion_table_provider__(self, session: Any | None = None) -> IcebergDa
).__datafusion_table_provider__
return provider(session)

def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
"""Export this Table as an Arrow C stream (PyCapsule interface)."""
return self.scan().to_arrow_batch_reader().__arrow_c_stream__(requested_schema)


class StaticTable(Table):
"""Load a table directly from a metadata file (i.e., without using a catalog)."""
Expand Down Expand Up @@ -2252,6 +2255,10 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
batches,
).cast(target_schema)

def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
"""Export this scan's result as an Arrow C stream (PyCapsule interface)."""
return self.to_arrow_batch_reader().__arrow_c_stream__(requested_schema)

def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
"""Read a Pandas DataFrame eagerly from this Iceberg table.

Expand Down
13 changes: 13 additions & 0 deletions pyiceberg/typedef.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,19 @@ def __setitem__(self, pos: int, value: Any) -> None:
"""Assign a value to a StructProtocol."""


@runtime_checkable
class ArrowStreamExportable(Protocol): # pragma: no cover
"""Any object implementing the Arrow PyCapsule stream interface.

Covers pa.Table, pa.RecordBatchReader, and third-party producers
(polars, arro3, nanoarrow, ...) without depending on any of them.
"""

@abstractmethod
def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
"""Export the object as an Arrow C stream PyCapsule."""


class IcebergBaseModel(BaseModel):
"""
This class extends the Pydantic BaseModel to set default values by overriding them.
Expand Down
2 changes: 1 addition & 1 deletion tests/catalog/test_catalog_behaviors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,7 @@ def test_append_invalid_input_type_raises(catalog: Catalog) -> None:
identifier = f"default.append_invalid_input_{catalog.name}"
pa_table = _simple_arrow_table()
tbl = catalog.create_table(identifier=identifier, schema=pa_table.schema)
with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader"):
with pytest.raises(ValueError, match="Expected pa.Table, pa.RecordBatchReader, or an object implementing"):
tbl.append("not an arrow object")


Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_writes/test_partitioned_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
properties={"format-version": "1"},
)

with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader, got: not a df"):
with pytest.raises(ValueError, match="Expected pa.Table, pa.RecordBatchReader, or an object implementing"):
tbl.append("not a df")


Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_writes/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,10 +791,10 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_
identifier = "default.arrow_data_files"
tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [])

with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader, got: not a df"):
with pytest.raises(ValueError, match="Expected pa.Table, pa.RecordBatchReader, or an object implementing"):
tbl.overwrite("not a df")

with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader, got: not a df"):
with pytest.raises(ValueError, match="Expected pa.Table, pa.RecordBatchReader, or an object implementing"):
tbl.append("not a df")


Expand Down
11 changes: 11 additions & 0 deletions tests/io/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2435,6 +2435,17 @@ def test_bin_pack_arrow_table_target_size_smaller_than_row(arrow_table_with_null
assert sum(batch.num_rows for bin_ in bin_packed for batch in bin_) == arrow_table_with_null.num_rows


def test_bin_pack_arrow_table_with_string_view() -> None:
if not hasattr(pa, "string_view"):
pytest.skip("pyarrow does not support string_view")

table = pa.table({"region": pa.array(["ca", "mx"], type=pa.string_view())})

bins = list(bin_pack_arrow_table(table, target_file_size=1))

assert sum(batch.num_rows for bin_ in bins for batch in bin_) == table.num_rows


def test_bin_pack_record_batches_single_bin(arrow_table_with_null: pa.Table) -> None:
batches = arrow_table_with_null.to_batches()
bins = list(bin_pack_record_batches(iter(batches), target_file_size=arrow_table_with_null.nbytes * 10))
Expand Down
Loading