From c775004a9807b5bb0870296150a8c5936e955536 Mon Sep 17 00:00:00 2001 From: Alexander Bianchi Date: Thu, 26 Feb 2026 03:12:17 +0000 Subject: [PATCH] dictionary encoded group by's --- datafusion/physical-plan/Cargo.toml | 5 + .../benches/dict_group_values.rs | 146 ++++++ .../src/aggregates/group_values/mod.rs | 2 +- .../group_values/multi_group_by/dictionary.rs | 179 +++++++ .../group_values/multi_group_by/mod.rs | 469 ++++++++++-------- 5 files changed, 603 insertions(+), 198 deletions(-) create mode 100644 datafusion/physical-plan/benches/dict_group_values.rs create mode 100644 datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dictionary.rs diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 6a28486cca5dc..d77f6cfc0b439 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -107,3 +107,8 @@ required-features = ["test_utils"] harness = false name = "aggregate_vectorized" required-features = ["test_utils"] + +[[bench]] +harness = false +name = "dict_group_values" +required-features = ["test_utils"] diff --git a/datafusion/physical-plan/benches/dict_group_values.rs b/datafusion/physical-plan/benches/dict_group_values.rs new file mode 100644 index 0000000000000..cad9ddbf67cf1 --- /dev/null +++ b/datafusion/physical-plan/benches/dict_group_values.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for GROUP BY on dictionary-encoded columns. +//! +//! Compares three paths: +//! - `column_utf8`: GroupValuesColumn with plain Utf8 (fast-path baseline) +//! - `column_dict`: GroupValuesColumn with Dictionary(Int32, Utf8) (new path) +//! - `rows_dict`: GroupValuesRows with Dictionary(Int32, Utf8) (old fallback) + +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::compute::cast; +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{ + BenchmarkId, Criterion, criterion_group, criterion_main, +}; +use datafusion_physical_plan::aggregates::group_values::multi_group_by::GroupValuesColumn; +use datafusion_physical_plan::aggregates::group_values::row::GroupValuesRows; +use datafusion_physical_plan::aggregates::group_values::GroupValues; +use rand::Rng; +use rand::rngs::StdRng; +use rand::SeedableRng; + +const CARDINALITIES: [usize; 3] = [50, 1_000, 10_000]; +const BATCH_SIZES: [usize; 2] = [8_192, 65_536]; +const NUM_BATCHES: usize = 10; + +/// Generate `num_rows` random string values chosen from `cardinality` distinct strings, +/// returned as both plain Utf8 and Dictionary(Int32, Utf8) arrays. 
+fn generate_string_batches( + num_rows: usize, + cardinality: usize, + seed: u64, +) -> (ArrayRef, ArrayRef) { + let mut rng = StdRng::seed_from_u64(seed); + + // Build a pool of distinct strings + let pool: Vec<String> = (0..cardinality) + .map(|i| format!("group_value_{i:06}")) + .collect(); + + let values: Vec<&str> = (0..num_rows) + .map(|_| pool[rng.random_range(0..cardinality)].as_str()) + .collect(); + + let utf8_array: ArrayRef = Arc::new(StringArray::from(values)); + let dict_array = cast(&utf8_array, &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))) + .expect("cast to dictionary"); + + (utf8_array, dict_array) +} + +fn bench_dict_group_values(c: &mut Criterion) { + let mut group = c.benchmark_group("dict_group_values"); + + for &cardinality in &CARDINALITIES { + for &batch_size in &BATCH_SIZES { + // Pre-generate batches (both utf8 and dict variants) + let batches: Vec<(ArrayRef, ArrayRef)> = (0..NUM_BATCHES as u64) + .map(|seed| generate_string_batches(batch_size, cardinality, seed)) + .collect(); + + let param = format!("card_{cardinality}/batch_{batch_size}"); + + // ---- column_utf8: GroupValuesColumn with plain Utf8 (baseline fast path) ---- + { + let schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::Utf8, false), + ])); + let id = BenchmarkId::new("column_utf8", &param); + group.bench_function(id, |b| { + b.iter(|| { + let mut gv = GroupValuesColumn::<false>::try_new(Arc::clone(&schema)).unwrap(); + let mut groups = Vec::new(); + for (utf8, _dict) in &batches { + gv.intern(&[Arc::clone(utf8)], &mut groups).unwrap(); + } + }); + }); + } + + // ---- column_dict: GroupValuesColumn with Dictionary(Int32, Utf8) (new path) ---- + { + let schema = Arc::new(Schema::new(vec![ + Field::new( + "key", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + ), + ])); + let id = BenchmarkId::new("column_dict", &param); + group.bench_function(id, |b| { + b.iter(|| { + let mut gv =
GroupValuesColumn::<false>::try_new(Arc::clone(&schema)).unwrap(); + let mut groups = Vec::new(); + for (_utf8, dict) in &batches { + gv.intern(&[Arc::clone(dict)], &mut groups).unwrap(); + } + }); + }); + } + + // ---- rows_dict: GroupValuesRows with Dictionary(Int32, Utf8) (old fallback) ---- + { + let schema = Arc::new(Schema::new(vec![ + Field::new( + "key", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + ), + ])); + let id = BenchmarkId::new("rows_dict", &param); + group.bench_function(id, |b| { + b.iter(|| { + let mut gv = GroupValuesRows::try_new(Arc::clone(&schema)).unwrap(); + let mut groups = Vec::new(); + for (_utf8, dict) in &batches { + gv.intern(&[Arc::clone(dict)], &mut groups).unwrap(); + } + }); + }); + } + } + } + + group.finish(); +} + +criterion_group!(benches, bench_dict_group_values); +criterion_main!(benches); diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 2f3b1a19e7d73..be289f08c085b 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -30,7 +30,7 @@ use datafusion_expr::EmitTo; pub mod multi_group_by; -mod row; +pub mod row; mod single_group_by; use datafusion_physical_expr::binary_map::OutputType; use multi_group_by::GroupValuesColumn; diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dictionary.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dictionary.rs new file mode 100644 index 0000000000000..a419cbb4bc520 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dictionary.rs @@ -0,0 +1,179 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`DictionaryGroupValueBuilder`] for dictionary-encoded GROUP BY columns. + +use std::marker::PhantomData; + +use arrow::array::{Array, ArrayRef, DictionaryArray, new_null_array}; +use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType}; +use datafusion_common::Result; + +use super::GroupColumn; + +/// A [`GroupColumn`] wrapper that transparently handles dictionary-encoded +/// input arrays by resolving dictionary keys on-demand. +/// +/// Instead of materializing the full decoded array via `cast()` (which copies +/// O(batch_size) strings per batch), this builder looks up values through +/// dictionary keys, only copying data for rows that are actually appended as +/// new groups. Comparisons index directly into the dictionary's values array. +/// +/// The inner builder stores decoded values. On emit, the existing code in +/// [`GroupValuesColumn::emit`] re-encodes back to dictionary via `cast()`. 
+/// +/// [`GroupValuesColumn::emit`]: super::GroupValuesColumn +pub struct DictionaryGroupValueBuilder<K: ArrowDictionaryKeyType> { + /// Inner builder that operates on the dictionary's value type + inner: Box<dyn GroupColumn>, + /// A single-element null array of the value type, used to represent null + /// dictionary keys to the inner builder + null_array: ArrayRef, + _phantom: PhantomData<K>, +} + +impl<K: ArrowDictionaryKeyType> DictionaryGroupValueBuilder<K> { + pub fn new(inner: Box<dyn GroupColumn>, value_type: &DataType) -> Self { + let null_array = new_null_array(value_type, 1); + Self { + inner, + null_array, + _phantom: PhantomData, + } + } +} + +impl<K: ArrowDictionaryKeyType> GroupColumn for DictionaryGroupValueBuilder<K> { + fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { + let dict = array + .as_any() + .downcast_ref::<DictionaryArray<K>>() + .unwrap(); + if dict.is_null(rhs_row) { + return self.inner.equal_to(lhs_row, &self.null_array, 0); + } + let key = dict.keys().value(rhs_row).as_usize(); + self.inner.equal_to(lhs_row, dict.values(), key) + } + + fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()> { + let dict = array + .as_any() + .downcast_ref::<DictionaryArray<K>>() + .unwrap(); + if dict.is_null(row) { + return self.inner.append_val(&self.null_array, 0); + } + let key = dict.keys().value(row).as_usize(); + self.inner.append_val(dict.values(), key) + } + + fn vectorized_equal_to( + &self, + lhs_rows: &[usize], + array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut [bool], + ) { + let dict = array + .as_any() + .downcast_ref::<DictionaryArray<K>>() + .unwrap(); + let keys = dict.keys(); + let values = dict.values(); + + if dict.null_count() == 0 { + // Fast path: no null keys, remap indices and delegate to inner + let mapped_rhs: Vec<usize> = rhs_rows + .iter() + .map(|&row| keys.value(row).as_usize()) + .collect(); + self.inner + .vectorized_equal_to(lhs_rows, values, &mapped_rhs, equal_to_results); + } else { + // Null keys present: fall back to scalar comparison + for (i, (&lhs_row, &rhs_row)) in + lhs_rows.iter().zip(rhs_rows.iter()).enumerate() + { + if
!equal_to_results[i] { + continue; + } + if dict.is_null(rhs_row) { + equal_to_results[i] = + self.inner.equal_to(lhs_row, &self.null_array, 0); + } else { + let key = keys.value(rhs_row).as_usize(); + equal_to_results[i] = + self.inner.equal_to(lhs_row, values, key); + } + } + } + } + + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) -> Result<()> { + let dict = array + .as_any() + .downcast_ref::<DictionaryArray<K>>() + .unwrap(); + let keys = dict.keys(); + let values = dict.values(); + + if dict.null_count() == 0 { + // Fast path: no null keys, remap indices and delegate to inner + let mapped_rows: Vec<usize> = rows + .iter() + .map(|&row| keys.value(row).as_usize()) + .collect(); + self.inner.vectorized_append(values, &mapped_rows) + } else { + // Null keys present: process in order, chunking consecutive + // non-null rows for vectorized processing + let mut i = 0; + while i < rows.len() { + if dict.is_null(rows[i]) { + self.inner.append_val(&self.null_array, 0)?; + i += 1; + } else { + // Collect consecutive non-null rows + let mut chunk = Vec::new(); + while i < rows.len() && !dict.is_null(rows[i]) { + chunk.push(keys.value(rows[i]).as_usize()); + i += 1; + } + self.inner.vectorized_append(values, &chunk)?; + } + } + Ok(()) + } + } + + fn len(&self) -> usize { + self.inner.len() + } + + fn size(&self) -> usize { + self.inner.size() + self.null_array.get_array_memory_size() + } + + fn build(self: Box<Self>) -> ArrayRef { + self.inner.build() + } + + fn take_n(&mut self, n: usize) -> ArrayRef { + self.inner.take_n(n) + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index 479bff001e3c8..d6450244ab260 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -20,6 +20,7 @@ mod boolean; mod bytes; pub mod bytes_view; +mod dictionary; pub mod
primitive; use std::mem::{self, size_of}; @@ -27,7 +28,8 @@ use crate::aggregates::group_values::GroupValues; use crate::aggregates::group_values::multi_group_by::{ boolean::BooleanGroupValueBuilder, bytes::ByteGroupValueBuilder, - bytes_view::ByteViewGroupValueBuilder, primitive::PrimitiveGroupValueBuilder, + bytes_view::ByteViewGroupValueBuilder, dictionary::DictionaryGroupValueBuilder, + primitive::PrimitiveGroupValueBuilder, }; use ahash::RandomState; use arrow::array::{Array, ArrayRef}; @@ -870,186 +872,127 @@ impl GroupValuesColumn { } } -/// instantiates a [`PrimitiveGroupValueBuilder`] and pushes it into $v +/// Creates the appropriate [`GroupColumn`] builder for the given data type. /// -/// Arguments: -/// `$v`: the vector to push the new builder into -/// `$nullable`: whether the input can contains nulls -/// `$t`: the primitive type of the builder -macro_rules! instantiate_primitive { - ($v:expr, $nullable:expr, $t:ty, $data_type:ident) => { - if $nullable { - let b = PrimitiveGroupValueBuilder::<$t, true>::new($data_type.to_owned()); - $v.push(Box::new(b) as _) - } else { - let b = PrimitiveGroupValueBuilder::<$t, false>::new($data_type.to_owned()); - $v.push(Box::new(b) as _) +/// For dictionary types, creates a [`DictionaryGroupValueBuilder`] wrapping +/// an inner value-type builder that resolves keys on-demand. +fn make_group_column( + data_type: &DataType, + nullable: bool, +) -> Result<Box<dyn GroupColumn>> { + /// Helper: returns a boxed [`PrimitiveGroupValueBuilder`] with the correct + /// nullability const-generic. + macro_rules!
primitive { + ($t:ty, $data_type:expr) => { + if nullable { + Ok(Box::new(PrimitiveGroupValueBuilder::<$t, true>::new( + $data_type.to_owned(), + )) as _) + } else { + Ok(Box::new(PrimitiveGroupValueBuilder::<$t, false>::new( + $data_type.to_owned(), + )) as _) + } + }; + } + + match data_type { + DataType::Dictionary(key_type, value_type) => { + let inner = make_group_column(value_type.as_ref(), nullable)?; + macro_rules! wrap { + ($kt:ty) => { + Ok(Box::new(DictionaryGroupValueBuilder::<$kt>::new( + inner, + value_type.as_ref(), + )) as _) + }; + } + match key_type.as_ref() { + DataType::Int8 => wrap!(Int8Type), + DataType::Int16 => wrap!(Int16Type), + DataType::Int32 => wrap!(Int32Type), + DataType::Int64 => wrap!(Int64Type), + DataType::UInt8 => wrap!(UInt8Type), + DataType::UInt16 => wrap!(UInt16Type), + DataType::UInt32 => wrap!(UInt32Type), + DataType::UInt64 => wrap!(UInt64Type), + _ => not_impl_err!( + "dictionary key type {key_type} not supported in GroupValuesColumn" + ), + } } - }; + DataType::Int8 => primitive!(Int8Type, data_type), + DataType::Int16 => primitive!(Int16Type, data_type), + DataType::Int32 => primitive!(Int32Type, data_type), + DataType::Int64 => primitive!(Int64Type, data_type), + DataType::UInt8 => primitive!(UInt8Type, data_type), + DataType::UInt16 => primitive!(UInt16Type, data_type), + DataType::UInt32 => primitive!(UInt32Type, data_type), + DataType::UInt64 => primitive!(UInt64Type, data_type), + DataType::Float32 => primitive!(Float32Type, data_type), + DataType::Float64 => primitive!(Float64Type, data_type), + DataType::Date32 => primitive!(Date32Type, data_type), + DataType::Date64 => primitive!(Date64Type, data_type), + DataType::Time32(TimeUnit::Second) => { + primitive!(Time32SecondType, data_type) + } + DataType::Time32(TimeUnit::Millisecond) => { + primitive!(Time32MillisecondType, data_type) + } + DataType::Time64(TimeUnit::Microsecond) => { + primitive!(Time64MicrosecondType, data_type) + } + 
DataType::Time64(TimeUnit::Nanosecond) => { + primitive!(Time64NanosecondType, data_type) + } + DataType::Timestamp(TimeUnit::Second, _) => { + primitive!(TimestampSecondType, data_type) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + primitive!(TimestampMillisecondType, data_type) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + primitive!(TimestampMicrosecondType, data_type) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + primitive!(TimestampNanosecondType, data_type) + } + DataType::Decimal128(_, _) => primitive!(Decimal128Type, data_type), + DataType::Utf8 => { + Ok(Box::new(ByteGroupValueBuilder::<i32>::new(OutputType::Utf8))) + } + DataType::LargeUtf8 => { + Ok(Box::new(ByteGroupValueBuilder::<i64>::new(OutputType::Utf8))) + } + DataType::Binary => { + Ok(Box::new(ByteGroupValueBuilder::<i32>::new(OutputType::Binary))) + } + DataType::LargeBinary => { + Ok(Box::new(ByteGroupValueBuilder::<i64>::new(OutputType::Binary))) + } + DataType::Utf8View => { + Ok(Box::new(ByteViewGroupValueBuilder::<StringViewType>::new())) + } + DataType::BinaryView => { + Ok(Box::new(ByteViewGroupValueBuilder::<BinaryViewType>::new())) + } + DataType::Boolean => { + if nullable { + Ok(Box::new(BooleanGroupValueBuilder::<true>::new())) + } else { + Ok(Box::new(BooleanGroupValueBuilder::<false>::new())) + } + } + dt => not_impl_err!("{dt} not supported in GroupValuesColumn"), + } +} impl<const STREAMING: bool> GroupValues for GroupValuesColumn<STREAMING> { fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> { if self.group_values.is_empty() { let mut v = Vec::with_capacity(cols.len()); - for f in self.schema.fields().iter() { - let nullable = f.is_nullable(); - let data_type = f.data_type(); - match data_type { - &DataType::Int8 => { - instantiate_primitive!(v, nullable, Int8Type, data_type) - } - &DataType::Int16 => { - instantiate_primitive!(v, nullable, Int16Type, data_type) - } - &DataType::Int32 => { - instantiate_primitive!(v, nullable, Int32Type, data_type) - } - &DataType::Int64 => { - instantiate_primitive!(v, nullable,
Int64Type, data_type) - } - &DataType::UInt8 => { - instantiate_primitive!(v, nullable, UInt8Type, data_type) - } - &DataType::UInt16 => { - instantiate_primitive!(v, nullable, UInt16Type, data_type) - } - &DataType::UInt32 => { - instantiate_primitive!(v, nullable, UInt32Type, data_type) - } - &DataType::UInt64 => { - instantiate_primitive!(v, nullable, UInt64Type, data_type) - } - &DataType::Float32 => { - instantiate_primitive!(v, nullable, Float32Type, data_type) - } - &DataType::Float64 => { - instantiate_primitive!(v, nullable, Float64Type, data_type) - } - &DataType::Date32 => { - instantiate_primitive!(v, nullable, Date32Type, data_type) - } - &DataType::Date64 => { - instantiate_primitive!(v, nullable, Date64Type, data_type) - } - &DataType::Time32(t) => match t { - TimeUnit::Second => { - instantiate_primitive!( - v, - nullable, - Time32SecondType, - data_type - ) - } - TimeUnit::Millisecond => { - instantiate_primitive!( - v, - nullable, - Time32MillisecondType, - data_type - ) - } - _ => {} - }, - &DataType::Time64(t) => match t { - TimeUnit::Microsecond => { - instantiate_primitive!( - v, - nullable, - Time64MicrosecondType, - data_type - ) - } - TimeUnit::Nanosecond => { - instantiate_primitive!( - v, - nullable, - Time64NanosecondType, - data_type - ) - } - _ => {} - }, - &DataType::Timestamp(t, _) => match t { - TimeUnit::Second => { - instantiate_primitive!( - v, - nullable, - TimestampSecondType, - data_type - ) - } - TimeUnit::Millisecond => { - instantiate_primitive!( - v, - nullable, - TimestampMillisecondType, - data_type - ) - } - TimeUnit::Microsecond => { - instantiate_primitive!( - v, - nullable, - TimestampMicrosecondType, - data_type - ) - } - TimeUnit::Nanosecond => { - instantiate_primitive!( - v, - nullable, - TimestampNanosecondType, - data_type - ) - } - }, - &DataType::Decimal128(_, _) => { - instantiate_primitive! 
{ - v, - nullable, - Decimal128Type, - data_type - } - } - &DataType::Utf8 => { - let b = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8); - v.push(Box::new(b) as _) - } - &DataType::LargeUtf8 => { - let b = ByteGroupValueBuilder::<i64>::new(OutputType::Utf8); - v.push(Box::new(b) as _) - } - &DataType::Binary => { - let b = ByteGroupValueBuilder::<i32>::new(OutputType::Binary); - v.push(Box::new(b) as _) - } - &DataType::LargeBinary => { - let b = ByteGroupValueBuilder::<i64>::new(OutputType::Binary); - v.push(Box::new(b) as _) - } - &DataType::Utf8View => { - let b = ByteViewGroupValueBuilder::<StringViewType>::new(); - v.push(Box::new(b) as _) - } - &DataType::BinaryView => { - let b = ByteViewGroupValueBuilder::<BinaryViewType>::new(); - v.push(Box::new(b) as _) - } - &DataType::Boolean => { - if nullable { - let b = BooleanGroupValueBuilder::<true>::new(); - v.push(Box::new(b) as _) - } else { - let b = BooleanGroupValueBuilder::<false>::new(); - v.push(Box::new(b) as _) - } - } - dt => { - return not_impl_err!("{dt} not supported in GroupValuesColumn"); - } - } + v.push(make_group_column(f.data_type(), f.is_nullable())?); } self.group_values = v; } @@ -1211,31 +1154,34 @@ pub fn supported_schema(schema: &Schema) -> bool { /// In order to be supported, there must be a specialized implementation of /// [`GroupColumn`] for the data type, instantiated in [`GroupValuesColumn::intern`] fn supported_type(data_type: &DataType) -> bool { - matches!( - *data_type, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 - | DataType::Decimal128(_, _) - | DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Binary - | DataType::LargeBinary - | DataType::Date32 - | DataType::Date64 - | DataType::Time32(_) - | DataType::Timestamp(_, _) - | DataType::Utf8View - | DataType::BinaryView - | DataType::Boolean - ) + match data_type { + DataType::Dictionary(_, value_type) =>
supported_type(value_type.as_ref()), + _ => matches!( + *data_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Timestamp(_, _) + | DataType::Utf8View + | DataType::BinaryView + | DataType::Boolean + ), + } } ///Shows how many `null`s there are in an array @@ -1809,4 +1755,133 @@ mod tests { &mut group_values.map_size, ); } + + #[test] + fn test_supported_type_dictionary() { + use super::supported_type; + + // Dictionary with supported value types should be accepted + assert!(supported_type(&DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ))); + assert!(supported_type(&DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Int64), + ))); + assert!(supported_type(&DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Boolean), + ))); + + // Dictionary with unsupported value types should be rejected + assert!(!supported_type(&DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Struct(Default::default())), + ))); + } + + #[test] + fn test_intern_with_dictionary_columns() { + use arrow::array::DictionaryArray; + use arrow::compute::cast; + + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let schema = Arc::new(Schema::new(vec![ + Field::new("dict_key", dict_type.clone(), true), + Field::new("int_key", DataType::Int64, true), + ])); + + let mut gv = GroupValuesColumn::<false>::try_new(Arc::clone(&schema)).unwrap(); + let mut groups = Vec::new(); + + // Batch 1: some overlapping groups + let utf8_col1 = StringArray::from(vec![ + Some("alpha"), + Some("beta"), + Some("alpha"), + None,
Some("beta"), + ]); + let dict_col1 = cast(&utf8_col1, &dict_type).unwrap(); + let int_col1: ArrayRef = Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + Some(1), + Some(3), + Some(2), + ])); + + gv.intern(&[dict_col1, int_col1], &mut groups).unwrap(); + // ("alpha", 1), ("beta", 2), (null, 3) => 3 distinct groups + assert_eq!(groups, vec![0, 1, 0, 2, 1]); + assert_eq!(gv.len(), 3); + + // Batch 2: mix of existing and new groups + let utf8_col2 = StringArray::from(vec![ + Some("alpha"), + Some("gamma"), + None, + ]); + let dict_col2 = cast(&utf8_col2, &dict_type).unwrap(); + let int_col2: ArrayRef = Arc::new(Int64Array::from(vec![ + Some(1), + Some(4), + Some(3), + ])); + + gv.intern(&[dict_col2, int_col2], &mut groups).unwrap(); + // ("alpha", 1) exists as group 0, ("gamma", 4) new, (null, 3) exists as group 2 + assert_eq!(groups, vec![0, 3, 2]); + assert_eq!(gv.len(), 4); + + // Emit and verify output types match schema (dictionary re-encoding) + let output = gv.emit(EmitTo::All).unwrap(); + assert_eq!(output.len(), 2); + assert_eq!(output[0].data_type(), &dict_type); + assert_eq!(output[1].data_type(), &DataType::Int64); + + // Verify the dictionary column values + let dict_out = output[0] + .as_any() + .downcast_ref::<DictionaryArray<Int32Type>>() + .unwrap(); + assert_eq!(dict_out.len(), 4); + } + + #[test] + fn test_dictionary_null_handling() { + use arrow::compute::cast; + + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let schema = Arc::new(Schema::new(vec![ + Field::new("key", dict_type.clone(), true), + ])); + + let mut gv = GroupValuesColumn::<false>::try_new(Arc::clone(&schema)).unwrap(); + let mut groups = Vec::new(); + + // Null keys and null dictionary values should both produce null groups + let utf8_col = StringArray::from(vec![ + None, + Some("a"), + None, + Some("b"), + None, + ]); + let dict_col = cast(&utf8_col, &dict_type).unwrap(); + + gv.intern(&[dict_col], &mut groups).unwrap(); + // null, "a", null, "b", null => 3
distinct: null(0), "a"(1), "b"(2) + assert_eq!(groups, vec![0, 1, 0, 2, 0]); + assert_eq!(gv.len(), 3); + + // Emit and verify dictionary output + let output = gv.emit(EmitTo::All).unwrap(); + assert_eq!(output[0].data_type(), &dict_type); + assert_eq!(output[0].null_count(), 1); + assert_eq!(output[0].len(), 3); + } }