Skip to content

Commit 10a1d4e

Browse files
authored
Remove UDAF manual Debug impls and simplify signatures (#19727)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Part of #18092 ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> The main value added here is ensuring that UDAFs encode the types they actually accept in their signatures, instead of declaring a wider signature and internally casting to the types they actually support. This PR also does some drive-by refactoring by removing manual `Debug` impls. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> See rationale. ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Covered by existing tests. ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> No. <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent e82dc21 commit 10a1d4e

11 files changed

Lines changed: 101 additions & 322 deletions

File tree

datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use std::any::Any;
19-
use std::fmt::{Debug, Formatter};
19+
use std::fmt::Debug;
2020
use std::hash::Hash;
2121
use std::mem::size_of_val;
2222
use std::sync::Arc;
@@ -111,20 +111,12 @@ An alternative syntax is also supported:
111111
description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
112112
)
113113
)]
114-
#[derive(PartialEq, Eq, Hash)]
114+
#[derive(PartialEq, Eq, Hash, Debug)]
115115
pub struct ApproxPercentileContWithWeight {
116116
signature: Signature,
117117
approx_percentile_cont: ApproxPercentileCont,
118118
}
119119

120-
impl Debug for ApproxPercentileContWithWeight {
121-
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122-
f.debug_struct("ApproxPercentileContWithWeight")
123-
.field("signature", &self.signature)
124-
.finish()
125-
}
126-
}
127-
128120
impl Default for ApproxPercentileContWithWeight {
129121
fn default() -> Self {
130122
Self::new()

datafusion/functions-aggregate/src/bool_and_or.rs

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,7 @@ pub struct BoolAnd {
114114
impl BoolAnd {
115115
fn new() -> Self {
116116
Self {
117-
signature: Signature::uniform(
118-
1,
119-
vec![DataType::Boolean],
120-
Volatility::Immutable,
121-
),
117+
signature: Signature::exact(vec![DataType::Boolean], Volatility::Immutable),
122118
}
123119
}
124120
}
@@ -251,11 +247,7 @@ pub struct BoolOr {
251247
impl BoolOr {
252248
fn new() -> Self {
253249
Self {
254-
signature: Signature::uniform(
255-
1,
256-
vec![DataType::Boolean],
257-
Volatility::Immutable,
258-
),
250+
signature: Signature::exact(vec![DataType::Boolean], Volatility::Immutable),
259251
}
260252
}
261253
}

datafusion/functions-aggregate/src/correlation.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ fn accumulate_correlation_states(
367367
/// where:
368368
/// n = number of observations
369369
/// sum_x = sum of x values
370-
/// sum_y = sum of y values
370+
/// sum_y = sum of y values
371371
/// sum_xy = sum of (x * y)
372372
/// sum_xx = sum of x^2 values
373373
/// sum_yy = sum of y^2 values

datafusion/functions-aggregate/src/count.rs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -147,20 +147,11 @@ pub fn count_all_window() -> Expr {
147147
```"#,
148148
standard_argument(name = "expression",)
149149
)]
150-
#[derive(PartialEq, Eq, Hash)]
150+
#[derive(PartialEq, Eq, Hash, Debug)]
151151
pub struct Count {
152152
signature: Signature,
153153
}
154154

155-
impl Debug for Count {
156-
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
157-
f.debug_struct("Count")
158-
.field("name", &self.name())
159-
.field("signature", &self.signature)
160-
.finish()
161-
}
162-
}
163-
164155
impl Default for Count {
165156
fn default() -> Self {
166157
Self::new()

datafusion/functions-aggregate/src/covariance.rs

Lines changed: 33 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,13 @@
1717

1818
//! [`CovarianceSample`]: covariance sample aggregations.
1919
20-
use arrow::datatypes::FieldRef;
21-
use arrow::{
22-
array::{ArrayRef, Float64Array, UInt64Array},
23-
compute::kernels::cast,
24-
datatypes::{DataType, Field},
25-
};
26-
use datafusion_common::{
27-
Result, ScalarValue, downcast_value, plan_err, unwrap_or_internal_err,
28-
};
20+
use arrow::array::ArrayRef;
21+
use arrow::datatypes::{DataType, Field, FieldRef};
22+
use datafusion_common::cast::{as_float64_array, as_uint64_array};
23+
use datafusion_common::{Result, ScalarValue};
2924
use datafusion_expr::{
3025
Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
3126
function::{AccumulatorArgs, StateFieldsArgs},
32-
type_coercion::aggregates::NUMERICS,
3327
utils::format_state_name,
3428
};
3529
use datafusion_functions_aggregate_common::stats::StatsType;
@@ -69,21 +63,12 @@ make_udaf_expr_and_func!(
6963
standard_argument(name = "expression1", prefix = "First"),
7064
standard_argument(name = "expression2", prefix = "Second")
7165
)]
72-
#[derive(PartialEq, Eq, Hash)]
66+
#[derive(PartialEq, Eq, Hash, Debug)]
7367
pub struct CovarianceSample {
7468
signature: Signature,
7569
aliases: Vec<String>,
7670
}
7771

78-
impl Debug for CovarianceSample {
79-
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
80-
f.debug_struct("CovarianceSample")
81-
.field("name", &self.name())
82-
.field("signature", &self.signature)
83-
.finish()
84-
}
85-
}
86-
8772
impl Default for CovarianceSample {
8873
fn default() -> Self {
8974
Self::new()
@@ -94,7 +79,10 @@ impl CovarianceSample {
9479
pub fn new() -> Self {
9580
Self {
9681
aliases: vec![String::from("covar")],
97-
signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
82+
signature: Signature::exact(
83+
vec![DataType::Float64, DataType::Float64],
84+
Volatility::Immutable,
85+
),
9886
}
9987
}
10088
}
@@ -112,11 +100,7 @@ impl AggregateUDFImpl for CovarianceSample {
112100
&self.signature
113101
}
114102

115-
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
116-
if !arg_types[0].is_numeric() {
117-
return plan_err!("Covariance requires numeric input types");
118-
}
119-
103+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
120104
Ok(DataType::Float64)
121105
}
122106

@@ -165,20 +149,11 @@ impl AggregateUDFImpl for CovarianceSample {
165149
standard_argument(name = "expression1", prefix = "First"),
166150
standard_argument(name = "expression2", prefix = "Second")
167151
)]
168-
#[derive(PartialEq, Eq, Hash)]
152+
#[derive(PartialEq, Eq, Hash, Debug)]
169153
pub struct CovariancePopulation {
170154
signature: Signature,
171155
}
172156

173-
impl Debug for CovariancePopulation {
174-
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
175-
f.debug_struct("CovariancePopulation")
176-
.field("name", &self.name())
177-
.field("signature", &self.signature)
178-
.finish()
179-
}
180-
}
181-
182157
impl Default for CovariancePopulation {
183158
fn default() -> Self {
184159
Self::new()
@@ -188,7 +163,10 @@ impl Default for CovariancePopulation {
188163
impl CovariancePopulation {
189164
pub fn new() -> Self {
190165
Self {
191-
signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
166+
signature: Signature::exact(
167+
vec![DataType::Float64, DataType::Float64],
168+
Volatility::Immutable,
169+
),
192170
}
193171
}
194172
}
@@ -206,11 +184,7 @@ impl AggregateUDFImpl for CovariancePopulation {
206184
&self.signature
207185
}
208186

209-
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
210-
if !arg_types[0].is_numeric() {
211-
return plan_err!("Covariance requires numeric input types");
212-
}
213-
187+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
214188
Ok(DataType::Float64)
215189
}
216190

@@ -304,30 +278,15 @@ impl Accumulator for CovarianceAccumulator {
304278
}
305279

306280
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
307-
let values1 = &cast(&values[0], &DataType::Float64)?;
308-
let values2 = &cast(&values[1], &DataType::Float64)?;
309-
310-
let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
311-
let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
281+
let values1 = as_float64_array(&values[0])?;
282+
let values2 = as_float64_array(&values[1])?;
312283

313-
for i in 0..values1.len() {
314-
let value1 = if values1.is_valid(i) {
315-
arr1.next()
316-
} else {
317-
None
318-
};
319-
let value2 = if values2.is_valid(i) {
320-
arr2.next()
321-
} else {
322-
None
284+
for (value1, value2) in values1.iter().zip(values2) {
285+
let (value1, value2) = match (value1, value2) {
286+
(Some(a), Some(b)) => (a, b),
287+
_ => continue,
323288
};
324289

325-
if value1.is_none() || value2.is_none() {
326-
continue;
327-
}
328-
329-
let value1 = unwrap_or_internal_err!(value1);
330-
let value2 = unwrap_or_internal_err!(value2);
331290
let new_count = self.count + 1;
332291
let delta1 = value1 - self.mean1;
333292
let new_mean1 = delta1 / new_count as f64 + self.mean1;
@@ -345,29 +304,14 @@ impl Accumulator for CovarianceAccumulator {
345304
}
346305

347306
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
348-
let values1 = &cast(&values[0], &DataType::Float64)?;
349-
let values2 = &cast(&values[1], &DataType::Float64)?;
350-
let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
351-
let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
352-
353-
for i in 0..values1.len() {
354-
let value1 = if values1.is_valid(i) {
355-
arr1.next()
356-
} else {
357-
None
358-
};
359-
let value2 = if values2.is_valid(i) {
360-
arr2.next()
361-
} else {
362-
None
363-
};
364-
365-
if value1.is_none() || value2.is_none() {
366-
continue;
367-
}
307+
let values1 = as_float64_array(&values[0])?;
308+
let values2 = as_float64_array(&values[1])?;
368309

369-
let value1 = unwrap_or_internal_err!(value1);
370-
let value2 = unwrap_or_internal_err!(value2);
310+
for (value1, value2) in values1.iter().zip(values2) {
311+
let (value1, value2) = match (value1, value2) {
312+
(Some(a), Some(b)) => (a, b),
313+
_ => continue,
314+
};
371315

372316
let new_count = self.count - 1;
373317
let delta1 = self.mean1 - value1;
@@ -386,10 +330,10 @@ impl Accumulator for CovarianceAccumulator {
386330
}
387331

388332
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
389-
let counts = downcast_value!(states[0], UInt64Array);
390-
let means1 = downcast_value!(states[1], Float64Array);
391-
let means2 = downcast_value!(states[2], Float64Array);
392-
let cs = downcast_value!(states[3], Float64Array);
333+
let counts = as_uint64_array(&states[0])?;
334+
let means1 = as_float64_array(&states[1])?;
335+
let means2 = as_float64_array(&states[2])?;
336+
let cs = as_float64_array(&states[3])?;
393337

394338
for i in 0..counts.len() {
395339
let c = counts.value(i);

datafusion/functions-aggregate/src/first_last.rs

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -90,22 +90,12 @@ pub fn last_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
9090
```"#,
9191
standard_argument(name = "expression",)
9292
)]
93-
#[derive(PartialEq, Eq, Hash)]
93+
#[derive(PartialEq, Eq, Hash, Debug)]
9494
pub struct FirstValue {
9595
signature: Signature,
9696
is_input_pre_ordered: bool,
9797
}
9898

99-
impl Debug for FirstValue {
100-
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
101-
f.debug_struct("FirstValue")
102-
.field("name", &self.name())
103-
.field("signature", &self.signature)
104-
.field("accumulator", &"<FUNC>")
105-
.finish()
106-
}
107-
}
108-
10999
impl Default for FirstValue {
110100
fn default() -> Self {
111101
Self::new()
@@ -1040,22 +1030,12 @@ impl Accumulator for FirstValueAccumulator {
10401030
```"#,
10411031
standard_argument(name = "expression",)
10421032
)]
1043-
#[derive(PartialEq, Eq, Hash)]
1033+
#[derive(PartialEq, Eq, Hash, Debug)]
10441034
pub struct LastValue {
10451035
signature: Signature,
10461036
is_input_pre_ordered: bool,
10471037
}
10481038

1049-
impl Debug for LastValue {
1050-
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1051-
f.debug_struct("LastValue")
1052-
.field("name", &self.name())
1053-
.field("signature", &self.signature)
1054-
.field("accumulator", &"<FUNC>")
1055-
.finish()
1056-
}
1057-
}
1058-
10591039
impl Default for LastValue {
10601040
fn default() -> Self {
10611041
Self::new()

datafusion/functions-aggregate/src/grouping.rs

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
//! Defines physical expressions that can be evaluated at runtime during query execution
1919
2020
use std::any::Any;
21-
use std::fmt;
2221

2322
use arrow::datatypes::Field;
2423
use arrow::datatypes::{DataType, FieldRef};
@@ -60,20 +59,11 @@ make_udaf_expr_and_func!(
6059
description = "Expression to evaluate whether data is aggregated across the specified column. Can be a constant, column, or function."
6160
)
6261
)]
63-
#[derive(PartialEq, Eq, Hash)]
62+
#[derive(PartialEq, Eq, Hash, Debug)]
6463
pub struct Grouping {
6564
signature: Signature,
6665
}
6766

68-
impl fmt::Debug for Grouping {
69-
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
70-
f.debug_struct("Grouping")
71-
.field("name", &self.name())
72-
.field("signature", &self.signature)
73-
.finish()
74-
}
75-
}
76-
7767
impl Default for Grouping {
7868
fn default() -> Self {
7969
Self::new()

datafusion/functions-aggregate/src/median.rs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,20 +85,11 @@ make_udaf_expr_and_func!(
8585
/// If using the distinct variation, the memory usage will be similarly high if the
8686
/// cardinality is high as it stores all distinct values in memory before computing the
8787
/// result, but if cardinality is low then memory usage will also be lower.
88-
#[derive(PartialEq, Eq, Hash)]
88+
#[derive(PartialEq, Eq, Hash, Debug)]
8989
pub struct Median {
9090
signature: Signature,
9191
}
9292

93-
impl Debug for Median {
94-
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
95-
f.debug_struct("Median")
96-
.field("name", &self.name())
97-
.field("signature", &self.signature)
98-
.finish()
99-
}
100-
}
101-
10293
impl Default for Median {
10394
fn default() -> Self {
10495
Self::new()

0 commit comments

Comments
 (0)