diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs
index f4bc40cf35d5a..1120120fd038a 100644
--- a/datafusion/physical-plan/src/windows/window_agg_exec.rs
+++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs
@@ -265,6 +265,7 @@ impl ExecutionPlan for WindowAggExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
+        let batch_size = context.session_config().batch_size();
         let input = self.input.execute(partition, context)?;
         let stream = Box::pin(WindowAggStream::new(
             Arc::clone(&self.schema),
@@ -273,6 +274,7 @@ impl ExecutionPlan for WindowAggExec {
             BaselineMetrics::new(&self.metrics, partition),
             self.partition_by_sort_keys()?,
             self.ordered_partition_by_indices.clone(),
+            batch_size,
         )?);
         Ok(stream)
     }
@@ -327,6 +329,15 @@ pub struct WindowAggStream {
     partition_by_sort_keys: Vec<PhysicalSortExpr>,
     baseline_metrics: BaselineMetrics,
     ordered_partition_by_indices: Vec<usize>,
+    /// Target output batch size. The fully-computed result is emitted in
+    /// slices of at most this many rows so downstream operators are not forced
+    /// to hold one batch sized to the entire input.
+    batch_size: usize,
+    /// The fully-computed result, set once the input is exhausted. Emitted
+    /// incrementally as zero-copy slices via `emit_offset`.
+    computed: Option<RecordBatch>,
+    /// Number of result rows already emitted from `computed`.
+    emit_offset: usize,
 }
 
 impl WindowAggStream {
@@ -338,6 +349,7 @@ impl WindowAggStream {
         baseline_metrics: BaselineMetrics,
         partition_by_sort_keys: Vec<PhysicalSortExpr>,
         ordered_partition_by_indices: Vec<usize>,
+        batch_size: usize,
     ) -> Result<Self> {
         // In WindowAggExec all partition by columns should be ordered.
         assert_eq_or_internal_err!(
@@ -354,6 +366,10 @@ impl WindowAggStream {
             baseline_metrics,
             partition_by_sort_keys,
             ordered_partition_by_indices,
+            // Guard against a zero batch size, which would never make progress.
+            batch_size: batch_size.max(1),
+            computed: None,
+            emit_offset: 0,
         })
     }
 
@@ -425,29 +441,60 @@ impl WindowAggStream {
             return Poll::Ready(None);
         }
 
-        loop {
-            return Poll::Ready(Some(match ready!(self.input.poll_next_unpin(cx)) {
-                Some(Ok(batch)) => {
-                    self.batches.push(batch);
-                    continue;
+        // Phase 1: drain the input and compute the full result exactly once.
+        if self.computed.is_none() {
+            loop {
+                match ready!(self.input.poll_next_unpin(cx)) {
+                    Some(Ok(batch)) => {
+                        self.batches.push(batch);
+                    }
+                    Some(Err(e)) => {
+                        self.finished = true;
+                        return Poll::Ready(Some(Err(e)));
+                    }
+                    None => {
+                        // Release the input pipeline's resources before
+                        // computing the final aggregates.
+                        let input_schema = self.input.schema();
+                        self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));
+                        let Some(result) = self.compute_aggregates()? else {
+                            self.finished = true;
+                            return Poll::Ready(None);
+                        };
+                        self.computed = Some(result);
+                        break;
+                    }
                 }
-                Some(Err(e)) => Err(e),
-                None => {
-                    // Release the input pipeline's resources before computing
-                    // the final aggregates.
-                    let input_schema = self.input.schema();
-                    self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));
-                    let Some(result) = self.compute_aggregates()? else {
-                        return Poll::Ready(None);
-                    };
-                    self.finished = true;
-                    // Empty record batches should not be emitted.
-                    // They need to be treated as [`Option<RecordBatch>`]es and handled separately
-                    debug_assert!(result.num_rows() > 0);
-                    Ok(result)
-                }
-            }));
+            }
+        }
+
+        // Phase 2: emit the computed result in `batch_size` chunks. Slicing is
+        // zero-copy, so this only bounds the batch each downstream operator
+        // must hold at once; it does not re-copy the data.
+        Poll::Ready(self.next_output_batch().transpose())
+    }
+
+    /// Returns the next `batch_size`-row slice of the computed result, or
+    /// `None` once the whole result has been emitted.
+    fn next_output_batch(&mut self) -> Result<Option<RecordBatch>> {
+        let Some(computed) = self.computed.as_ref() else {
+            return Ok(None);
+        };
+        let total_rows = computed.num_rows();
+        if self.emit_offset >= total_rows {
+            self.finished = true;
+            return Ok(None);
         }
+        let length = self.batch_size.min(total_rows - self.emit_offset);
+        let batch = computed.slice(self.emit_offset, length);
+        self.emit_offset += length;
+        if self.emit_offset >= total_rows {
+            self.finished = true;
+        }
+        // Empty record batches should not be emitted; `compute_aggregates`
+        // already returns `None` for an empty result, so each slice is non-empty.
+        debug_assert!(batch.num_rows() > 0);
+        Ok(Some(batch))
     }
 }
 
@@ -500,4 +547,67 @@ mod tests {
         ));
         Ok(())
     }
+
+    #[tokio::test]
+    async fn window_agg_exec_emits_batch_size_chunks() -> Result<()> {
+        use crate::common::collect;
+        use arrow::array::{ArrayRef, Int64Array};
+        use datafusion_execution::config::SessionConfig;
+
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
+        // 10 rows in a single partition (no PARTITION BY).
+        let a: ArrayRef = Arc::new(Int64Array::from((0..10).collect::<Vec<i64>>()));
+        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![a])?;
+        let input =
+            TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?;
+
+        let args = vec![crate::expressions::col("a", &schema)?];
+        // Running COUNT over UNBOUNDED PRECEDING .. CURRENT ROW -> 1, 2, ..., 10.
+        let window_expr = create_window_expr(
+            &WindowFunctionDefinition::AggregateUDF(count_udaf()),
+            "count(a)".to_string(),
+            &args,
+            &[],
+            &[],
+            Arc::new(WindowFrame::new_bounds(
+                WindowFrameUnits::Rows,
+                WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+                WindowFrameBound::CurrentRow,
+            )),
+            Arc::clone(&schema),
+            false,
+            false,
+            None,
+        )?;
+        let window = Arc::new(WindowAggExec::try_new(vec![window_expr], input, true)?);
+
+        // A small batch size forces the single computed result to be emitted in
+        // multiple chunks instead of one batch sized to the whole input.
+        let task_ctx = Arc::new(
+            TaskContext::default()
+                .with_session_config(SessionConfig::new().with_batch_size(4)),
+        );
+
+        let stream = window.execute(0, task_ctx)?;
+        let batches = collect(stream).await?;
+
+        // 10 rows with batch_size 4 -> chunks of 4, 4, 2.
+        assert_eq!(
+            batches.iter().map(|b| b.num_rows()).collect::<Vec<_>>(),
+            vec![4, 4, 2]
+        );
+
+        // The running-count column is unaffected by chunking: it must read
+        // 1..=10 across the concatenation of the emitted slices.
+        let combined = concat_batches(&window.schema(), &batches)?;
+        let count_col = combined
+            .column(1)
+            .as_any()
+            .downcast_ref::<Int64Array>()
+            .expect("count column is Int64");
+        let expected = Int64Array::from((1..=10).collect::<Vec<i64>>());
+        assert_eq!(count_col, &expected);
+
+        Ok(())
+    }
 }