Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 255 additions & 3 deletions datafusion/expr/src/window_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@

//! Structures used to hold window function state (for implementing WindowUDFs)

use std::{collections::VecDeque, ops::Range, sync::Arc};
use std::{cmp::Ordering, collections::VecDeque, ops::Range, sync::Arc};

use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits};

use arrow::{
array::ArrayRef,
array::{Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray},
compute::{SortOptions, concat, concat_batches},
datatypes::{DataType, SchemaRef},
datatypes::{
DataType, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type,
Int64Type, SchemaRef, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
},
record_batch::RecordBatch,
};
use datafusion_common::{
Expand Down Expand Up @@ -446,6 +449,25 @@ impl WindowFrameStateRange {
} else {
current_row_values
};
// Fast path: a single primitive (integer/float) ORDER BY column can be
// scanned directly over its native values, avoiding the per-probe
// `Vec<ScalarValue>` allocation and dynamic `ScalarValue` comparison that
// `search_in_slice` performs. The target arithmetic above is unchanged,
// so decimal/temporal/overflow/underflow semantics are identical; only
// the comparison scan is specialized.
if range_columns.len() == 1
&& let Some(sort_options) = self.sort_options.first()
&& let Some(found) = search_single_primitive_range::<SIDE>(
&range_columns[0],
&end_range[0],
sort_options,
search_start,
length,
)
{
return Ok(found);
}

let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| {
let cmp = compare_rows(current, target, &self.sort_options)?;
Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() })
Expand All @@ -454,6 +476,116 @@ impl WindowFrameStateRange {
}
}

/// Orders the row at `idx` of a single primitive column relative to `target`.
///
/// Intended as the one-column counterpart of [`compare_rows`]: a null on
/// either side is placed according to `sort_options.nulls_first`, two nulls are
/// equal, and two non-null values are ordered with
/// [`ArrowNativeTypeOp::compare`], with the operands swapped when
/// `sort_options.descending` is set.
///
/// * `array` - the column being scanned.
/// * `idx` - row of `array` to compare.
/// * `target` - the boundary value; `None` represents a null target.
/// * `sort_options` - ordering direction and null placement of the column.
#[inline]
fn primitive_row_cmp<T: ArrowPrimitiveType>(
    array: &PrimitiveArray<T>,
    idx: usize,
    target: Option<T::Native>,
    sort_options: &SortOptions,
) -> Ordering
where
    T::Native: ArrowNativeTypeOp,
{
    let row_is_null = array.is_null(idx);
    match (row_is_null, target) {
        // Null against null never decides the ordering.
        (true, None) => Ordering::Equal,
        // Only the row is null: it sorts before the target under NULLS FIRST,
        // after it otherwise.
        (true, Some(_)) => {
            if sort_options.nulls_first {
                Ordering::Less
            } else {
                Ordering::Greater
            }
        }
        // Only the target is null: the mirror image of the case above.
        (false, None) => {
            if sort_options.nulls_first {
                Ordering::Greater
            } else {
                Ordering::Less
            }
        }
        // Both present: compare natively, flipping the operands for DESC.
        (false, Some(wanted)) => {
            let value = array.value(idx);
            if sort_options.descending {
                wanted.compare(value)
            } else {
                value.compare(wanted)
            }
        }
    }
}

/// Linear boundary scan over one primitive column, the native counterpart of
/// [`search_in_slice`].
///
/// Starting at `low`, rows are consumed for as long as the boundary predicate
/// holds and the scan stops at `high` at the latest. With `SIDE == true` the
/// predicate is "row orders strictly before `target`" (`is_lt`); with
/// `SIDE == false` it is "row orders before or equal to `target`" (`is_le`).
///
/// Returns the index of the first row in `low..high` that fails the predicate,
/// `high` if every row passes, or `low` unchanged when `low >= high`.
#[inline]
fn search_in_primitive_slice<const SIDE: bool, T: ArrowPrimitiveType>(
    array: &PrimitiveArray<T>,
    target: Option<T::Native>,
    sort_options: &SortOptions,
    low: usize,
    high: usize,
) -> usize
where
    T::Native: ArrowNativeTypeOp,
{
    let still_inside = |row: &usize| {
        let ordering = primitive_row_cmp(array, *row, target, sort_options);
        if SIDE {
            ordering.is_lt()
        } else {
            ordering.is_le()
        }
    };
    // An empty range (`low >= high`) contributes zero rows, leaving `low` as is.
    low + (low..high).take_while(still_inside).count()
}

/// Fast path for RANGE-frame boundary search when there is a single primitive
/// integer/float ORDER BY column. Returns `None` so the caller falls back to the
/// generic `ScalarValue` path for multi-column frames, unsupported scalar types
/// (e.g. decimal/temporal, whose scale/units would not match a raw native
/// comparison), or a column/target type mismatch.
fn search_single_primitive_range<const SIDE: bool>(
column: &ArrayRef,
target: &ScalarValue,
sort_options: &SortOptions,
low: usize,
high: usize,
) -> Option<usize> {
macro_rules! search {
($variant:path, $arrow_ty:ty) => {{
// `target` always shares the column's type here (it is either the
// current row value or the result of same-type arithmetic), but
// downcast defensively and fall back on any mismatch.
let $variant(target) = target else {
return None;
};
let array = column
.as_any()
.downcast_ref::<PrimitiveArray<$arrow_ty>>()?;
Some(search_in_primitive_slice::<SIDE, $arrow_ty>(
array,
*target,
sort_options,
low,
high,
))
}};
}

match target {
ScalarValue::Int8(_) => search!(ScalarValue::Int8, Int8Type),
ScalarValue::Int16(_) => search!(ScalarValue::Int16, Int16Type),
ScalarValue::Int32(_) => search!(ScalarValue::Int32, Int32Type),
ScalarValue::Int64(_) => search!(ScalarValue::Int64, Int64Type),
ScalarValue::UInt8(_) => search!(ScalarValue::UInt8, UInt8Type),
ScalarValue::UInt16(_) => search!(ScalarValue::UInt16, UInt16Type),
ScalarValue::UInt32(_) => search!(ScalarValue::UInt32, UInt32Type),
ScalarValue::UInt64(_) => search!(ScalarValue::UInt64, UInt64Type),
ScalarValue::Float16(_) => search!(ScalarValue::Float16, Float16Type),
ScalarValue::Float32(_) => search!(ScalarValue::Float32, Float32Type),
ScalarValue::Float64(_) => search!(ScalarValue::Float64, Float64Type),
_ => None,
}
}

// In GROUPS mode, rows with duplicate sorting values are grouped together.
// Therefore, there must be an ORDER BY clause in the window definition to use GROUPS mode.
// The syntax is as follows:
Expand Down Expand Up @@ -699,6 +831,126 @@ mod tests {
(range_columns, sort_options)
}

/// Differential test: for every start position, probe value and sort-option
/// combination, the single-primitive-column fast path must report the same
/// boundary index as the generic `ScalarValue` scan (`search_in_slice` driven
/// by `compare_rows`). The fixtures cover nulls, NaN, duplicates and signed
/// zero. The columns are deliberately left unsorted: the check is that both
/// implementations stop at the same row, not that the frame is meaningful.
#[test]
fn range_fast_path_matches_generic_scan() {
    use arrow::array::{Int32Array, UInt32Array};

    let int_column: ArrayRef = Arc::new(Int32Array::from(vec![
        Some(1),
        Some(3),
        None,
        Some(3),
        Some(5),
        None,
        Some(8),
    ]));
    let int_probes: Vec<ScalarValue> =
        [Some(0), Some(3), Some(4), Some(8), Some(100), None]
            .into_iter()
            .map(ScalarValue::Int32)
            .collect();

    let uint_column: ArrayRef = Arc::new(UInt32Array::from(vec![
        Some(0u32),
        Some(2),
        Some(2),
        None,
        Some(7),
    ]));
    let uint_probes: Vec<ScalarValue> = [Some(0), Some(2), Some(3), None]
        .into_iter()
        .map(ScalarValue::UInt32)
        .collect();

    let float_column: ArrayRef = Arc::new(Float64Array::from(vec![
        Some(1.0),
        Some(f64::NAN),
        None,
        Some(2.5),
        Some(f64::NAN),
        Some(-0.0),
    ]));
    let float_probes: Vec<ScalarValue> =
        [Some(1.0), Some(f64::NAN), Some(2.5), Some(0.0), None]
            .into_iter()
            .map(ScalarValue::Float64)
            .collect();

    // ASC/DESC x NULLS LAST/FIRST, direction varying slowest.
    let orderings: Vec<SortOptions> = [false, true]
        .into_iter()
        .flat_map(|descending| {
            [false, true].into_iter().map(move |nulls_first| SortOptions {
                descending,
                nulls_first,
            })
        })
        .collect();

    let cases = [
        (int_column, int_probes),
        (uint_column, uint_probes),
        (float_column, float_probes),
    ];

    for (column, probes) in cases {
        let row_count = column.len();
        for sort_options in &orderings {
            // Unpacked only so the assertion messages can name them.
            let SortOptions {
                descending,
                nulls_first,
            } = *sort_options;
            for target in &probes {
                // Reference answer from the generic scan; `side_lt` selects the
                // strict (`is_lt`) or inclusive (`is_le`) predicate.
                let reference = |side_lt: bool, low: usize| -> usize {
                    search_in_slice(
                        std::slice::from_ref(&column),
                        std::slice::from_ref(target),
                        |row, probe| {
                            compare_rows(row, probe, std::slice::from_ref(sort_options))
                                .map(|cmp| if side_lt { cmp.is_lt() } else { cmp.is_le() })
                        },
                        low,
                        row_count,
                    )
                    .unwrap()
                };

                for low in 0..=row_count {
                    let fast_start = search_single_primitive_range::<true>(
                        &column,
                        target,
                        sort_options,
                        low,
                        row_count,
                    )
                    .expect("primitive column should hit the fast path");
                    assert_eq!(
                        fast_start,
                        reference(true, low),
                        "start: target={target:?} desc={descending} nulls_first={nulls_first} low={low}"
                    );

                    let fast_end = search_single_primitive_range::<false>(
                        &column,
                        target,
                        sort_options,
                        low,
                        row_count,
                    )
                    .expect("primitive column should hit the fast path");
                    assert_eq!(
                        fast_end,
                        reference(false, low),
                        "end: target={target:?} desc={descending} nulls_first={nulls_first} low={low}"
                    );
                }
            }
        }
    }
}

fn assert_group_ranges(
window_frame: &Arc<WindowFrame>,
expected_results: Vec<(Range<usize>, usize)>,
Expand Down
Loading