diff --git a/datafusion/spark/src/function/math/ceil.rs b/datafusion/spark/src/function/math/ceil.rs new file mode 100644 index 0000000000000..5826d39ae30c7 --- /dev/null +++ b/datafusion/spark/src/function/math/ceil.rs @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{AsArray, Decimal128Array}; +use arrow::compute::cast; +use arrow::datatypes::{DataType, Decimal128Type, Float32Type, Float64Type, Int64Type}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +/// Spark-compatible `ceil` expression +/// +/// +/// Differences with DataFusion ceil: +/// - Spark's ceil returns Int64 for float/integer types +/// - Spark's ceil adjusts precision for Decimal128 types +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkCeil { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkCeil { + fn default() -> Self { + Self::new() + } +} + +impl SparkCeil { + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + aliases: vec!["ceiling".to_string()], + } + } +} + +impl ScalarUDFImpl for SparkCeil { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ceil" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + DataType::Decimal128(p, s) if *s > 0 => { + let new_p = ((*p as i64) - (*s as i64) + 1).clamp(1, 38) as u8; + Ok(DataType::Decimal128(new_p, 0)) + } + DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)), + _ => Ok(DataType::Int64), + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let return_type = args.return_type().clone(); + spark_ceil(&args.args, &return_type) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +fn spark_ceil(args: &[ColumnarValue], return_type: &DataType) -> Result { + let input = match take_function_args("ceil", args)? 
{ + [ColumnarValue::Scalar(value)] => value.to_array()?, + [ColumnarValue::Array(arr)] => Arc::clone(arr), + }; + + let result = match input.data_type() { + DataType::Float32 => Arc::new( + input + .as_primitive::() + .unary::<_, Int64Type>(|x| x.ceil() as i64), + ) as _, + DataType::Float64 => Arc::new( + input + .as_primitive::() + .unary::<_, Int64Type>(|x| x.ceil() as i64), + ) as _, + dt if dt.is_integer() => cast(&input, &DataType::Int64)?, + DataType::Decimal128(_, s) if *s > 0 => { + let div = 10_i128.pow(*s as u32); + let result: Decimal128Array = + input.as_primitive::().unary(|x| { + let d = x / div; + let r = x % div; + if r > 0 { d + 1 } else { d } + }); + Arc::new(result.with_data_type(return_type.clone())) + } + DataType::Decimal128(_, _) => input, + other => return exec_err!("Unsupported data type {other:?} for function ceil"), + }; + + Ok(ColumnarValue::Array(result)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Decimal128Array, Float32Array, Float64Array, Int64Array}; + use datafusion_common::ScalarValue; + + #[test] + fn test_ceil_float64() { + let input = Float64Array::from(vec![Some(1.1), Some(-1.1), Some(0.0), None]); + let args = vec![ColumnarValue::Array(Arc::new(input))]; + let result = spark_ceil(&args, &DataType::Int64).unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + assert_eq!( + result, + &Int64Array::from(vec![Some(2), Some(-1), Some(0), None]) + ); + } + + #[test] + fn test_ceil_float32() { + let input = Float32Array::from(vec![Some(1.5f32), Some(-1.5f32)]); + let args = vec![ColumnarValue::Array(Arc::new(input))]; + let result = spark_ceil(&args, &DataType::Int64).unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + assert_eq!(result, &Int64Array::from(vec![Some(2), Some(-1)])); + } + + #[test] + fn 
test_ceil_int64() { + let input = Int64Array::from(vec![Some(1), Some(-1), None]); + let args = vec![ColumnarValue::Array(Arc::new(input))]; + let result = spark_ceil(&args, &DataType::Int64).unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + assert_eq!(result, &Int64Array::from(vec![Some(1), Some(-1), None])); + } + + #[test] + fn test_ceil_decimal128() { + // Decimal128(10, 2): 150 = 1.50, -150 = -1.50, 100 = 1.00 + let return_type = DataType::Decimal128(9, 0); + let input = Decimal128Array::from(vec![Some(150), Some(-150), Some(100), None]) + .with_data_type(DataType::Decimal128(10, 2)); + let args = vec![ColumnarValue::Array(Arc::new(input))]; + let result = spark_ceil(&args, &return_type).unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + let expected = Decimal128Array::from(vec![Some(2), Some(-1), Some(1), None]) + .with_data_type(return_type); + assert_eq!(result, &expected); + } + + #[test] + fn test_ceil_scalar() { + let input = ScalarValue::Float64(Some(1.1)); + let args = vec![ColumnarValue::Scalar(input)]; + let result = spark_ceil(&args, &DataType::Int64).unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + assert_eq!(result, &Int64Array::from(vec![Some(2)])); + } +} diff --git a/datafusion/spark/src/function/math/floor.rs b/datafusion/spark/src/function/math/floor.rs new file mode 100644 index 0000000000000..13b0e5bbfb236 --- /dev/null +++ b/datafusion/spark/src/function/math/floor.rs @@ -0,0 +1,198 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{AsArray, Decimal128Array}; +use arrow::compute::cast; +use arrow::datatypes::{DataType, Decimal128Type, Float32Type, Float64Type, Int64Type}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +/// Spark-compatible `floor` expression +/// +/// +/// Differences with DataFusion floor: +/// - Spark's floor returns Int64 for float/integer types +/// - Spark's floor adjusts precision for Decimal128 types +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkFloor { + signature: Signature, +} + +impl Default for SparkFloor { + fn default() -> Self { + Self::new() + } +} + +impl SparkFloor { + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkFloor { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "floor" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + DataType::Decimal128(p, s) if *s > 0 => { + let new_p = ((*p as i64) - (*s as i64) + 1).clamp(1, 38) as u8; + Ok(DataType::Decimal128(new_p, 0)) + } + DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)), + _ 
=> Ok(DataType::Int64), + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let return_type = args.return_type().clone(); + spark_floor(&args.args, &return_type) + } +} + +fn spark_floor(args: &[ColumnarValue], return_type: &DataType) -> Result { + let input = match take_function_args("floor", args)? { + [ColumnarValue::Scalar(value)] => value.to_array()?, + [ColumnarValue::Array(arr)] => Arc::clone(arr), + }; + + let result = match input.data_type() { + DataType::Float32 => Arc::new( + input + .as_primitive::() + .unary::<_, Int64Type>(|x| x.floor() as i64), + ) as _, + DataType::Float64 => Arc::new( + input + .as_primitive::() + .unary::<_, Int64Type>(|x| x.floor() as i64), + ) as _, + dt if dt.is_integer() => cast(&input, &DataType::Int64)?, + DataType::Decimal128(_, s) if *s > 0 => { + let div = 10_i128.pow(*s as u32); + let result: Decimal128Array = + input.as_primitive::().unary(|x| { + let d = x / div; + let r = x % div; + if r < 0 { d - 1 } else { d } + }); + Arc::new(result.with_data_type(return_type.clone())) + } + DataType::Decimal128(_, _) => input, + other => return exec_err!("Unsupported data type {other:?} for function floor"), + }; + + Ok(ColumnarValue::Array(result)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Decimal128Array, Float32Array, Float64Array, Int64Array}; + use datafusion_common::ScalarValue; + + #[test] + fn test_floor_float64() { + let input = Float64Array::from(vec![Some(1.9), Some(-1.1), Some(0.0), None]); + let args = vec![ColumnarValue::Array(Arc::new(input))]; + let result = spark_floor(&args, &DataType::Int64).unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + assert_eq!( + result, + &Int64Array::from(vec![Some(1), Some(-2), Some(0), None]) + ); + } + + #[test] + fn test_floor_float32() { + let input = Float32Array::from(vec![Some(1.5f32), Some(-1.5f32)]); + let args = 
vec![ColumnarValue::Array(Arc::new(input))]; + let result = spark_floor(&args, &DataType::Int64).unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + assert_eq!(result, &Int64Array::from(vec![Some(1), Some(-2)])); + } + + #[test] + fn test_floor_int64() { + let input = Int64Array::from(vec![Some(1), Some(-1), None]); + let args = vec![ColumnarValue::Array(Arc::new(input))]; + let result = spark_floor(&args, &DataType::Int64).unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + assert_eq!(result, &Int64Array::from(vec![Some(1), Some(-1), None])); + } + + #[test] + fn test_floor_decimal128() { + // Decimal128(10, 2): 150 = 1.50, -150 = -1.50, 100 = 1.00 + let return_type = DataType::Decimal128(9, 0); + let input = Decimal128Array::from(vec![Some(150), Some(-150), Some(100), None]) + .with_data_type(DataType::Decimal128(10, 2)); + let args = vec![ColumnarValue::Array(Arc::new(input))]; + let result = spark_floor(&args, &return_type).unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + let expected = Decimal128Array::from(vec![Some(1), Some(-2), Some(1), None]) + .with_data_type(return_type); + assert_eq!(result, &expected); + } + + #[test] + fn test_floor_scalar() { + let input = ScalarValue::Float64(Some(1.9)); + let args = vec![ColumnarValue::Scalar(input)]; + let result = spark_floor(&args, &DataType::Int64).unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + assert_eq!(result, &Int64Array::from(vec![Some(1)])); + } +} diff --git a/datafusion/spark/src/function/math/is_nan.rs b/datafusion/spark/src/function/math/is_nan.rs new file mode 100644 index 
0000000000000..31571f2b7dd12 --- /dev/null +++ b/datafusion/spark/src/function/math/is_nan.rs @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, BooleanArray, Float32Array, Float64Array}; +use arrow::datatypes::DataType; +use datafusion_common::utils::take_function_args; +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; + +/// Spark-compatible `isnan` expression +/// +/// +/// Differences with standard SQL: +/// - Returns `false` for NULL inputs (not NULL) +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkIsNaN { + signature: Signature, +} + +impl Default for SparkIsNaN { + fn default() -> Self { + Self::new() + } +} + +impl SparkIsNaN { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Float32]), + TypeSignature::Exact(vec![DataType::Float64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkIsNaN { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "isnan" + } + + fn 
signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_isnan(&args.args) + } +} + +fn spark_isnan(args: &[ColumnarValue]) -> Result { + let [value] = take_function_args("isnan", args)?; + + match value { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Float32 => { + let array = array.as_any().downcast_ref::().unwrap(); + Ok(ColumnarValue::Array(nulls_to_false( + BooleanArray::from_unary(array, |x| x.is_nan()), + ))) + } + DataType::Float64 => { + let array = array.as_any().downcast_ref::().unwrap(); + Ok(ColumnarValue::Array(nulls_to_false( + BooleanArray::from_unary(array, |x| x.is_nan()), + ))) + } + other => exec_err!("Unsupported data type {other:?} for function isnan"), + }, + ColumnarValue::Scalar(sv) => match sv { + ScalarValue::Float32(v) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean( + Some(v.is_some_and(|x| x.is_nan())), + ))), + ScalarValue::Float64(v) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean( + Some(v.is_some_and(|x| x.is_nan())), + ))), + _ => exec_err!( + "Unsupported data type {:?} for function isnan", + sv.data_type() + ), + }, + } +} + +/// Replaces null values with false in a BooleanArray. +/// +/// Spark's `isnan` returns `false` for NULL inputs rather than propagating NULL. 
fn nulls_to_false(is_nan: BooleanArray) -> ArrayRef {
    match is_nan.nulls() {
        Some(nulls) => {
            // The validity bitmap has a set bit for every non-null slot.
            // AND-ing it into the values forces null slots to `false`, and
            // the result is built without a null buffer.
            let is_not_null = nulls.inner();
            Arc::new(BooleanArray::new(is_nan.values() & is_not_null, None))
        }
        // No null buffer: the array can be returned unchanged.
        None => Arc::new(is_nan),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps the array variant and downcasts it to a `BooleanArray`.
    fn into_bool_array(value: ColumnarValue) -> BooleanArray {
        match value {
            ColumnarValue::Array(arr) => arr
                .as_any()
                .downcast_ref::<BooleanArray>()
                .unwrap()
                .clone(),
            _ => panic!("Expected array"),
        }
    }

    /// Unwraps a boolean scalar result without requiring `ColumnarValue`
    /// to be comparable.
    fn into_bool_scalar(value: ColumnarValue) -> Option<bool> {
        match value {
            ColumnarValue::Scalar(ScalarValue::Boolean(v)) => v,
            other => panic!("Expected boolean scalar, got {other:?}"),
        }
    }

    #[test]
    fn test_isnan_float64() {
        let input = Float64Array::from(vec![
            Some(1.0),
            Some(f64::NAN),
            None,
            Some(f64::INFINITY),
            Some(0.0),
        ]);
        let args = vec![ColumnarValue::Array(Arc::new(input))];
        let result = into_bool_array(spark_isnan(&args).unwrap());

        // NULL input produces false, not NULL
        assert!(!result.is_null(2));

        let expected = BooleanArray::from(vec![false, true, false, false, false]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_isnan_float32() {
        let input = Float32Array::from(vec![Some(f32::NAN), Some(1.0f32), None]);
        let args = vec![ColumnarValue::Array(Arc::new(input))];
        let result = into_bool_array(spark_isnan(&args).unwrap());
        let expected = BooleanArray::from(vec![true, false, false]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_isnan_scalar_nan() {
        let result =
            spark_isnan(&[ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::NAN)))])
                .unwrap();
        assert_eq!(into_bool_scalar(result), Some(true));
    }

    #[test]
    fn test_isnan_scalar_null() {
        let result =
            spark_isnan(&[ColumnarValue::Scalar(ScalarValue::Float64(None))]).unwrap();
        assert_eq!(into_bool_scalar(result), Some(false));
    }

    #[test]
    fn test_isnan_scalar_normal() {
        let result =
            spark_isnan(&[ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0)))])
                .unwrap();
        assert_eq!(into_bool_scalar(result), Some(false));
    }
}
diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs index 7f7d04e06b0be..3c76ef9a2a166 100644 --- a/datafusion/spark/src/function/math/mod.rs +++ b/datafusion/spark/src/function/math/mod.rs @@ -17,12 +17,16 @@ pub mod abs; pub mod bin; +pub mod ceil; pub mod expm1; pub mod factorial; +pub mod floor; pub mod hex; +pub mod is_nan; pub mod modulus; pub mod negative; pub mod rint; +pub mod round; pub mod trigonometry; pub mod unhex; pub mod width_bucket; @@ -32,30 +36,41 @@ use datafusion_functions::make_udf_function; use std::sync::Arc; make_udf_function!(abs::SparkAbs, abs); +make_udf_function!(ceil::SparkCeil, ceil); make_udf_function!(expm1::SparkExpm1, expm1); make_udf_function!(factorial::SparkFactorial, factorial); +make_udf_function!(floor::SparkFloor, floor); make_udf_function!(hex::SparkHex, hex); +make_udf_function!(is_nan::SparkIsNaN, isnan); make_udf_function!(modulus::SparkMod, modulus); make_udf_function!(modulus::SparkPmod, pmod); +make_udf_function!(negative::SparkNegative, negative); make_udf_function!(rint::SparkRint, rint); +make_udf_function!(round::SparkRound, round); make_udf_function!(unhex::SparkUnhex, unhex); make_udf_function!(width_bucket::SparkWidthBucket, width_bucket); make_udf_function!(trigonometry::SparkCsc, csc); make_udf_function!(trigonometry::SparkSec, sec); -make_udf_function!(negative::SparkNegative, negative); make_udf_function!(bin::SparkBin, bin); pub mod expr_fn { use datafusion_functions::export_functions; export_functions!((abs, "Returns abs(expr)", arg1)); + export_functions!((ceil, "Returns the smallest integer not less than expr.", arg1)); export_functions!((expm1, "Returns exp(expr) - 1 as a Float64.", arg1)); export_functions!(( factorial, "Returns the factorial of expr. expr is [0..20]. 
Otherwise, null.", arg1 )); + export_functions!(( + floor, + "Returns the largest integer not greater than expr.", + arg1 + )); export_functions!((hex, "Computes hex value of the given column.", arg1)); + export_functions!((isnan, "Returns true if expr is NaN, false otherwise (including for NULL).", arg1)); export_functions!((modulus, "Returns the remainder of division of the first argument by the second argument.", arg1 arg2)); export_functions!((pmod, "Returns the positive remainder of division of the first argument by the second argument.", arg1 arg2)); export_functions!(( @@ -63,6 +78,7 @@ pub mod expr_fn { "Returns the double value that is closest in value to the argument and is equal to a mathematical integer.", arg1 )); + export_functions!((round, "Rounds expr to d decimal places using HALF_UP rounding mode.", arg1 arg2)); export_functions!((unhex, "Converts hexadecimal string to binary.", arg1)); export_functions!((width_bucket, "Returns the bucket number into which the value of this expression would fall after being evaluated.", arg1 arg2 arg3 arg4)); export_functions!((csc, "Returns the cosecant of expr.", arg1)); @@ -82,17 +98,21 @@ pub mod expr_fn { pub fn functions() -> Vec> { vec![ abs(), + ceil(), expm1(), factorial(), + floor(), hex(), + isnan(), modulus(), pmod(), + negative(), rint(), + round(), unhex(), width_bucket(), csc(), sec(), - negative(), bin(), ] } diff --git a/datafusion/spark/src/function/math/round.rs b/datafusion/spark/src/function/math/round.rs new file mode 100644 index 0000000000000..9aaa12d9691f5 --- /dev/null +++ b/datafusion/spark/src/function/math/round.rs @@ -0,0 +1,577 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrowNativeTypeOp, AsArray, Decimal128Array, Int16Array, Int32Array, + Int64Array, Int8Array, +}; +use arrow::compute::kernels::arity::try_unary; +use arrow::datatypes::{ + DataType, Decimal128Type, Field, FieldRef, Float32Type, Float64Type, +}; +use arrow::error::ArrowError; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, +}; + +/// Spark-compatible `round` expression +/// +/// +/// Rounds the value to `d` decimal places using HALF_UP rounding mode. +/// When `d` is negative, rounds integer types with overflow checking in ANSI mode. 
+#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkRound { + signature: Signature, +} + +impl Default for SparkRound { + fn default() -> Self { + Self::new() + } +} + +impl SparkRound { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![TypeSignature::Numeric(1), TypeSignature::Any(2)], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkRound { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "round" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // Default return type when decimal places is unknown (assumed 0) + Ok(arg_types[0].clone()) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let input_type = args.arg_fields[0].data_type(); + let point = if args.scalar_arguments.len() >= 2 { + match args.scalar_arguments[1] { + Some(ScalarValue::Int64(Some(p))) => *p, + Some(ScalarValue::Int32(Some(p))) => *p as i64, + Some(ScalarValue::Int16(Some(p))) => *p as i64, + Some(ScalarValue::Int8(Some(p))) => *p as i64, + _ => 0, + } + } else { + 0 + }; + + let return_type = compute_round_return_type(input_type, point); + Ok(Arc::new(Field::new(self.name(), return_type, true))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let point = extract_point(&args.args)?; + let return_type = args.return_type().clone(); + let enable_ansi_mode = args.config_options.execution.enable_ansi_mode; + spark_round(&args.args[0], point, &return_type, enable_ansi_mode) + } +} + +/// Computes the return type for round based on input type and decimal places. 
///
/// Follows Spark's `RoundBase.dataType` logic:
/// - Float/Integer types: return same type
/// - Decimal128(p, s) with d < 0: Decimal128(p - s + |d| + 1, 0)
/// - Decimal128(p, s) with d >= 0: Decimal128(p - s + min(s, d) + 1, min(s, d))
///
/// The resulting precision is clamped to `1..=38`.
// NOTE(review): Spark's `RoundBase` appears to use
// `max(p - s + 1, |d| + 1)` for negative `d`; confirm against the Spark
// source before relying on the exact precision.
fn compute_round_return_type(input_type: &DataType, point: i64) -> DataType {
    match input_type {
        DataType::Decimal128(p, s) => {
            let integral_digits = i64::from(*p) - i64::from(*s);
            if point < 0 {
                // Saturate instead of overflowing: an absurdly large |d| then
                // simply hits the precision cap of 38.
                let abs_point = i64::try_from(point.unsigned_abs()).unwrap_or(i64::MAX);
                let new_p = integral_digits
                    .saturating_add(abs_point)
                    .saturating_add(1)
                    .clamp(1, 38) as u8;
                DataType::Decimal128(new_p, 0)
            } else {
                let new_s = i64::from(*s).min(point);
                let new_p = (integral_digits + new_s + 1).clamp(1, 38) as u8;
                DataType::Decimal128(new_p, new_s as i8)
            }
        }
        other => other.clone(),
    }
}

/// Extracts the decimal places parameter from function arguments.
///
/// Returns `0` when only one argument is present.
///
/// # Errors
/// Returns an error when the second argument is not a non-NULL integer
/// scalar (Int8/Int16/Int32/Int64), e.g. a column or a NULL literal.
fn extract_point(args: &[ColumnarValue]) -> Result<i64> {
    if args.len() < 2 {
        return Ok(0);
    }
    match &args[1] {
        ColumnarValue::Scalar(ScalarValue::Int64(Some(p))) => Ok(*p),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(p))) => Ok(i64::from(*p)),
        ColumnarValue::Scalar(ScalarValue::Int16(Some(p))) => Ok(i64::from(*p)),
        ColumnarValue::Scalar(ScalarValue::Int8(Some(p))) => Ok(i64::from(*p)),
        other => internal_err!(
            "round requires a constant integer for decimal places, got: {:?}",
            other
        ),
    }
}

/// Rounds an integer value to the nearest multiple of `div` using HALF_UP rounding.
///
/// When `enable_ansi_mode` is true, returns an error on arithmetic overflow.
/// When false, uses wrapping arithmetic.
///
/// Expands to an expression of type `Result<T, ArrowError>`.
macro_rules! integer_round {
    ($x:expr, $div:expr, $half:expr, $enable_ansi_mode:expr) => {{
        let rem = $x % $div;
        // `rem` has the sign of `$x` and a magnitude no larger than `$x`, so
        // truncating toward zero can never overflow.
        let truncated = $x - rem;
        if rem <= -$half {
            // At least half way toward the next multiple below: step down.
            if $enable_ansi_mode {
                truncated.sub_checked($div).map_err(|_| {
                    ArrowError::ComputeError(
                        "[ARITHMETIC_OVERFLOW] integer overflow in round".to_string(),
                    )
                })
            } else {
                Ok(truncated.sub_wrapping($div))
            }
        } else if rem >= $half {
            // At least half way toward the next multiple above: step up.
            if $enable_ansi_mode {
                truncated.add_checked($div).map_err(|_| {
                    ArrowError::ComputeError(
                        "[ARITHMETIC_OVERFLOW] integer overflow in round".to_string(),
                    )
                })
            } else {
                Ok(truncated.add_wrapping($div))
            }
        } else {
            Ok(truncated)
        }
    }};
}

/// Rounds every element of an integer array to `-$point` digits left of the
/// decimal point (`$point` is negative at every call site).
///
/// If `10^|point|` does not fit the native type (or `|point|` does not fit in
/// a `u32`), every non-null element becomes `0`; nulls are preserved.
// NOTE(review): for Int64 with point = -19 the divisor overflows but half of
// it does not, so values with |x| >= 5e18 are mapped to 0 here; Spark may
// instead overflow/wrap for those inputs. Confirm the intended behavior.
macro_rules! round_integer_array {
    ($array:expr, $point:expr, $array_type:ty, $native_type:ty, $enable_ansi_mode:expr) => {{
        let array = $array.as_any().downcast_ref::<$array_type>().unwrap();
        let ten: $native_type = 10;
        // `unsigned_abs` + `try_from` avoid the negation overflow at
        // `i64::MIN` and the silent truncation of `as u32`.
        let divisor = u32::try_from($point.unsigned_abs())
            .ok()
            .and_then(|exp| ten.checked_pow(exp));
        let result: $array_type = match divisor {
            Some(div) => {
                let half = div / 2;
                try_unary(array, |x| {
                    integer_round!(x, div, half, $enable_ansi_mode)
                })?
            }
            None => try_unary(array, |_| Ok(0 as $native_type))?,
        };
        Ok(ColumnarValue::Array(Arc::new(result)))
    }};
}

/// Scalar counterpart of `round_integer_array!`.
///
/// A NULL input always yields a NULL output, including when the divisor does
/// not fit the native type.
macro_rules! round_integer_scalar {
    ($scalar_opt:expr, $point:expr, $scalar_variant:ident, $native_type:ty, $enable_ansi_mode:expr) => {{
        let ten: $native_type = 10;
        let divisor = u32::try_from($point.unsigned_abs())
            .ok()
            .and_then(|exp| ten.checked_pow(exp));
        let result = match divisor {
            Some(div) => {
                let half = div / 2;
                $scalar_opt
                    .map(|x| integer_round!(x, div, half, $enable_ansi_mode))
                    .transpose()
                    .map_err(|e| {
                        datafusion_common::DataFusionError::ArrowError(Box::new(e), None)
                    })?
            }
            // Divisor too large for the type: every non-null value rounds to 0.
            None => $scalar_opt.map(|_| 0),
        };
        Ok(ColumnarValue::Scalar(ScalarValue::$scalar_variant(result)))
    }};
}

/// Rounds a float value to the specified number of decimal places.
///
/// Ties are rounded away from zero. NaN and infinities are returned
/// unchanged.
#[inline]
fn round_float(value: f64, point: i64) -> f64 {
    if !value.is_finite() {
        return value;
    }
    if point >= 0 {
        // Cap the exponent so the factor itself stays finite
        // (f64 can represent up to ~10^308).
        let factor = 10f64.powi(point.min(308) as i32);
        let scaled = value * factor;
        if scaled.is_finite() {
            scaled.round() / factor
        } else {
            // Scaling overflowed, so `value` is already far coarser than the
            // requested precision and rounding is a no-op.
            value
        }
    } else if point < -308 {
        // Every finite f64 is below half of 10^309, so the result is a
        // (sign-preserving) zero.
        value * 0.0
    } else {
        let factor = 10f64.powi(-(point as i32));
        (value / factor).round() * factor
    }
}

/// Rounds a Decimal128 value to the specified number of decimal places.
///
/// `x` is the unscaled value and `scale` its (non-negative) scale.
///
/// Uses Spark's BigDecimal-style rounding:
/// 1. Add half of the divisor (adjusted for sign)
/// 2. Truncate by division
/// 3. Adjust precision by multiplication (for negative point)
///
/// For `point >= 0` the result is expressed at scale `min(scale, point)`;
/// for `point < 0` it is expressed at scale 0. When the required power of ten
/// does not fit in an `i128`, the result is `0`.
#[inline]
fn round_decimal(x: i128, scale: i64, point: i64) -> i128 {
    if point < 0 {
        // Checked conversions instead of `as u32`: a huge |point| must not be
        // truncated into a small exponent.
        let powers = u32::try_from(point.unsigned_abs()).ok().and_then(|shift| {
            let exp = shift.checked_add(u32::try_from(scale).ok()?)?;
            // `shift <= exp`, so the second power fits whenever the first does.
            Some((10_i128.checked_pow(exp)?, 10_i128.pow(shift)))
        });
        match powers {
            Some((div, mul)) => {
                let half = div / 2;
                (x + x.signum() * half) / div * mul
            }
            None => 0,
        }
    } else {
        // Number of fractional digits to drop; zero when `point >= scale`.
        let dropped = u32::try_from(scale - point.min(scale)).unwrap_or(0);
        let div = 10_i128.pow(dropped);
        let half = div / 2;
        (x + x.signum() * half) / div
    }
}

/// Applies Spark's `round` to `value` with `point` decimal places.
///
/// - Float32/Float64: rounded via `round_float`.
/// - Int8..Int64 with `point < 0`: rounded to a multiple of `10^|point|`;
///   overflow is an error when `enable_ansi_mode` is set, wrapping otherwise.
/// - Any integer type with `point >= 0`: returned unchanged.
/// - Decimal128 with non-negative scale: rounded via `round_decimal` and
///   tagged with `return_type`.
///
/// Array inputs yield arrays and scalar inputs yield scalars.
///
/// # Errors
/// Returns an execution error for unsupported input types, an internal error
/// when a decimal input is paired with a non-decimal `return_type`, and an
/// Arrow compute error on integer overflow in ANSI mode.
fn spark_round(
    value: &ColumnarValue,
    point: i64,
    return_type: &DataType,
    enable_ansi_mode: bool,
) -> Result<ColumnarValue> {
    match value {
        ColumnarValue::Array(array) => match array.data_type() {
            DataType::Float32 => {
                let result = array
                    .as_primitive::<Float32Type>()
                    .unary::<_, Float32Type>(|x| round_float(x as f64, point) as f32);
                Ok(ColumnarValue::Array(Arc::new(result)))
            }
            DataType::Float64 => {
                let result = array
                    .as_primitive::<Float64Type>()
                    .unary::<_, Float64Type>(|x| round_float(x, point));
                Ok(ColumnarValue::Array(Arc::new(result)))
            }
            DataType::Int64 if point < 0 => {
                round_integer_array!(array, point, Int64Array, i64, enable_ansi_mode)
            }
            DataType::Int32 if point < 0 => {
                round_integer_array!(array, point, Int32Array, i32, enable_ansi_mode)
            }
            DataType::Int16 if point < 0 => {
                round_integer_array!(array, point, Int16Array, i16, enable_ansi_mode)
            }
            DataType::Int8 if point < 0 => {
                round_integer_array!(array, point, Int8Array, i8, enable_ansi_mode)
            }
            dt if dt.is_integer() => {
                // Rounding to >= 0 decimal places on integers is a no-op
                Ok(ColumnarValue::Array(Arc::clone(array)))
            }
            DataType::Decimal128(_, s) if *s >= 0 => {
                // Mirror the scalar path: report a bad return type as an error
                // rather than panicking inside `with_data_type`.
                if !matches!(return_type, DataType::Decimal128(_, _)) {
                    return internal_err!("Expected Decimal128 return type for round");
                }
                let scale = i64::from(*s);
                let result: Decimal128Array = array
                    .as_primitive::<Decimal128Type>()
                    .unary(|x| round_decimal(x, scale, point));
                let result = result.with_data_type(return_type.clone());
                Ok(ColumnarValue::Array(Arc::new(result)))
            }
            dt => exec_err!("Unsupported data type for round: {dt}"),
        },
        ColumnarValue::Scalar(sv) => match sv {
            ScalarValue::Float32(v) => Ok(ColumnarValue::Scalar(ScalarValue::Float32(
                v.map(|x| round_float(x as f64, point) as f32),
            ))),
            ScalarValue::Float64(v) => Ok(ColumnarValue::Scalar(ScalarValue::Float64(
                v.map(|x| round_float(x, point)),
            ))),
            ScalarValue::Int64(v) if point < 0 => {
                round_integer_scalar!(v, point, Int64, i64, enable_ansi_mode)
            }
            ScalarValue::Int32(v) if point < 0 => {
                round_integer_scalar!(v, point, Int32, i32, enable_ansi_mode)
            }
            ScalarValue::Int16(v) if point < 0 => {
                round_integer_scalar!(v, point, Int16, i16, enable_ansi_mode)
            }
            ScalarValue::Int8(v) if point < 0 => {
                round_integer_scalar!(v, point, Int8, i8, enable_ansi_mode)
            }
            sv if sv.data_type().is_integer() => {
                // Rounding to >= 0 decimal places on integers is a no-op
                Ok(ColumnarValue::Scalar(sv.clone()))
            }
            ScalarValue::Decimal128(v, _, s) if *s >= 0 => {
                let scale = i64::from(*s);
                let result = v.map(|x| round_decimal(x, scale, point));
                let DataType::Decimal128(p, s) = return_type else {
                    return internal_err!("Expected Decimal128 return type for round");
                };
                Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
                    result, *p, *s,
                )))
            }
            // Report the offending type, not the scalar's value.
            other => exec_err!(
                "Unsupported data type for round: {}",
                other.data_type()
            ),
        },
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array};

    #[test]
    fn test_round_float64() {
        let input = Float64Array::from(vec![Some(2.5), Some(3.5), Some(-2.5), None]);
        let result = spark_round(
            &ColumnarValue::Array(Arc::new(input)),
            0,
            &DataType::Float64,
            false,
        )
        .unwrap();
        let result = match result {
            ColumnarValue::Array(arr) => arr,
            _ => panic!("Expected array"),
        };
        let result = result.as_primitive::<Float64Type>();
        assert_eq!(
            result,
            &Float64Array::from(vec![Some(3.0), Some(4.0), Some(-3.0), None])
        );
    }

#[test] + fn test_round_float64_with_decimal_places() { + let input = Float64Array::from(vec![Some(1.2345), Some(-1.2345)]); + let result = spark_round( + &ColumnarValue::Array(Arc::new(input)), + 2, + &DataType::Float64, + false, + ) + .unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + assert_eq!( + result, + &Float64Array::from(vec![Some(1.23), Some(-1.23)]) + ); + } + + #[test] + fn test_round_int64_negative_point() { + let input = Int64Array::from(vec![Some(1234), Some(-1234), Some(1250), None]); + let result = spark_round( + &ColumnarValue::Array(Arc::new(input)), + -2, + &DataType::Int64, + false, + ) + .unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + assert_eq!( + result, + &Int64Array::from(vec![Some(1200), Some(-1200), Some(1300), None]) + ); + } + + #[test] + fn test_round_int64_noop() { + // Rounding with non-negative decimal places is a no-op for integers + let input = Int64Array::from(vec![Some(42), Some(-42), None]); + let result = spark_round( + &ColumnarValue::Array(Arc::new(input.clone())), + 0, + &DataType::Int64, + false, + ) + .unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + assert_eq!(result, &input); + } + + #[test] + fn test_round_decimal128() { + // Decimal128(10, 3): values 1235 = 1.235, -1235 = -1.235 + let return_type = DataType::Decimal128(9, 2); + let input = Decimal128Array::from(vec![Some(1235), Some(-1235), None]) + .with_data_type(DataType::Decimal128(10, 3)); + let result = spark_round( + &ColumnarValue::Array(Arc::new(input)), + 2, + &return_type, + false, + ) + .unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); 
+ // 1.235 rounded to 2 decimal places = 1.24, -1.235 rounded = -1.24 + let expected = Decimal128Array::from(vec![Some(124), Some(-124), None]) + .with_data_type(return_type); + assert_eq!(result, &expected); + } + + #[test] + fn test_round_scalar_float64() { + let result = spark_round( + &ColumnarValue::Scalar(ScalarValue::Float64(Some(2.5))), + 0, + &DataType::Float64, + false, + ) + .unwrap(); + assert_eq!( + result, + ColumnarValue::Scalar(ScalarValue::Float64(Some(3.0))) + ); + } + + #[test] + fn test_round_scalar_int64_negative_point() { + let result = spark_round( + &ColumnarValue::Scalar(ScalarValue::Int64(Some(1234))), + -2, + &DataType::Int64, + false, + ) + .unwrap(); + assert_eq!( + result, + ColumnarValue::Scalar(ScalarValue::Int64(Some(1200))) + ); + } + + #[test] + fn test_round_return_type() { + // Float/integer types return same type + assert_eq!( + compute_round_return_type(&DataType::Float64, 2), + DataType::Float64 + ); + assert_eq!( + compute_round_return_type(&DataType::Int64, -2), + DataType::Int64 + ); + + // Decimal128(10, 3) rounded to 2 places: Decimal128(10-3+2+1=10, 2) -> Decimal128(10, 2) + assert_eq!( + compute_round_return_type(&DataType::Decimal128(10, 3), 2), + DataType::Decimal128(10, 2) + ); + + // Decimal128(10, 3) rounded to -1 places: Decimal128(10-3+1+1=9, 0) -> Decimal128(9, 0) + assert_eq!( + compute_round_return_type(&DataType::Decimal128(10, 3), -1), + DataType::Decimal128(9, 0) + ); + + // Decimal128(10, 3) rounded to 5 places (more than scale): min(3,5)=3 + // Decimal128(10-3+3+1=11, 3) -> Decimal128(11, 3) + assert_eq!( + compute_round_return_type(&DataType::Decimal128(10, 3), 5), + DataType::Decimal128(11, 3) + ); + } + + #[test] + fn test_round_integer_overflow_ansi_mode() { + // In ANSI mode, rounding i64::MAX with negative point should overflow + let result = spark_round( + &ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MAX))), + -1, + &DataType::Int64, + true, + ); + assert!(result.is_err()); + assert!(result 
+ .unwrap_err() + .to_string() + .contains("ARITHMETIC_OVERFLOW")); + } + + #[test] + fn test_round_integer_overflow_legacy_mode() { + // In legacy mode, wrapping arithmetic is used (no error) + let result = spark_round( + &ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MAX))), + -1, + &DataType::Int64, + false, + ); + assert!(result.is_ok()); + } +}