Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use arrow::array::{
};
use arrow::datatypes::{DataType, i256};
use datafusion_common::Result;
use datafusion_execution::memory_pool::proxy::VecAllocExt;
use datafusion_expr::EmitTo;
use half::f16;
use hashbrown::hash_table::HashTable;
Expand Down Expand Up @@ -80,17 +79,11 @@ hash_float!(f16, f32, f64);
pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
/// The data type of the output array
data_type: DataType,
/// Stores the `(group_index, hash)` based on the hash of its value
///
/// We also store `hash` is for reducing cost of rehashing. Such cost
/// is obvious in high cardinality group by situation.
/// More details can see:
/// <https://github.com/apache/datafusion/issues/15961>
map: HashTable<(usize, u64)>,
/// Stores (value, group_index) pairs directly.
/// Values are compared directly during probing without indirection.
map: HashTable<(T::Native, usize)>,
/// The group index of the null value if any
null_group: Option<usize>,
/// The values for each group index
values: Vec<T::Native>,
/// The random state used to generate hashes
random_state: RandomState,
}
Expand All @@ -101,7 +94,6 @@ impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> {
Self {
data_type,
map: HashTable::with_capacity(128),
values: Vec::with_capacity(128),
null_group: None,
random_state: crate::aggregates::AGGREGATION_HASH_SEED,
}
Expand All @@ -118,28 +110,25 @@ where

for v in cols[0].as_primitive::<T>() {
let group_id = match v {
None => *self.null_group.get_or_insert_with(|| {
let group_id = self.values.len();
self.values.push(Default::default());
group_id
}),
None => match self.null_group {
Some(idx) => idx,
None => {
let g = self.len();
self.null_group = Some(g);
g
}
},
Some(key) => {
let state = &self.random_state;
let hash = key.hash(state);
let insert = self.map.entry(
hash,
|&(g, h)| unsafe {
hash == h && self.values.get_unchecked(g).is_eq(key)
},
|&(_, h)| h,
);

match insert {
hashbrown::hash_table::Entry::Occupied(o) => o.get().0,
hashbrown::hash_table::Entry::Vacant(v) => {
let g = self.values.len();
v.insert((g, hash));
self.values.push(key);
match self.map.find_entry(hash, |entry| entry.0.is_eq(key)) {
Ok(occupied) => occupied.get().1,
Err(absent) => {
let table = absent.into_table();
let g = table.len() + self.null_group.is_some() as usize;
table.insert_unique(hash, (key, g), |entry| {
entry.0.hash(state)
});
g
}
}
Expand All @@ -151,15 +140,15 @@ where
}

fn size(&self) -> usize {
self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size()
self.map.capacity() * size_of::<(T::Native, usize)>()
}

fn is_empty(&self) -> bool {
self.values.is_empty()
self.map.is_empty() && self.null_group.is_none()
}

fn len(&self) -> usize {
self.values.len()
self.map.len() + self.null_group.is_some() as usize
}

fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
Expand All @@ -180,21 +169,23 @@ where

let array: PrimitiveArray<T> = match emit_to {
EmitTo::All => {
let mut values = vec![T::Native::default(); self.len()];
for &(value, group_idx) in self.map.iter() {
values[group_idx] = value;
}
self.map.clear();
build_primitive(std::mem::take(&mut self.values), self.null_group.take())
let null_group = self.null_group.take();
build_primitive(values, null_group)
}
EmitTo::First(n) => {
let mut values = vec![T::Native::default(); n];
self.map.retain(|entry| {
// Decrement group index by n
let group_idx = entry.0;
match group_idx.checked_sub(n) {
// Group index was >= n, shift value down
Some(sub) => {
entry.0 = sub;
true
}
// Group index was < n, so remove from table
None => false,
if entry.1 < n {
values[entry.1] = entry.0;
false
} else {
entry.1 -= n;
true
}
});
let null_group = match &mut self.null_group {
Expand All @@ -205,19 +196,16 @@ where
Some(_) => self.null_group.take(),
None => None,
};
let mut split = self.values.split_off(n);
std::mem::swap(&mut self.values, &mut split);
build_primitive(split, null_group)
build_primitive(values, null_group)
}
};

Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))])
}

fn clear_shrink(&mut self, num_rows: usize) {
self.values.clear();
self.values.shrink_to(num_rows);
fn clear_shrink(&mut self, count: usize) {
self.null_group = None;
self.map.clear();
self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since the map is cleared
self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
}
}
2 changes: 1 addition & 1 deletion datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3951,7 +3951,7 @@ mod tests {

// Pool must be large enough for accumulation to start but too small for
// sort_memory after clearing.
let task_ctx = new_spill_ctx(1, 500);
let task_ctx = new_spill_ctx(1, 300);
let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await;

match &result {
Expand Down
Loading