diff --git a/datafusion/functions-nested/src/array_scale.rs b/datafusion/functions-nested/src/array_scale.rs
new file mode 100644
index 0000000000000..24750ade8a775
--- /dev/null
+++ b/datafusion/functions-nested/src/array_scale.rs
@@ -0,0 +1,249 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`ScalarUDFImpl`] definitions for array_scale function.
+
+use crate::utils::make_scalar_function;
+use arrow::array::{
+    Array, ArrayRef, Float64Array, GenericListArray, OffsetBufferBuilder, OffsetSizeTrait,
+};
+use arrow::buffer::NullBuffer;
+use arrow::datatypes::{
+    DataType,
+    DataType::{FixedSizeList, LargeList, List, Null},
+    Field,
+};
+use datafusion_common::cast::{as_float64_array, as_generic_list_array};
+use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
+use datafusion_common::{Result, internal_err, plan_err, utils::take_function_args};
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+    Volatility,
+};
+use datafusion_macros::user_doc;
+use std::sync::Arc;
+
+make_udf_expr_and_func!(
+    ArrayScale,
+    array_scale,
+    array scalar,
+    "scales each element of a numeric array by a scalar.",
+    array_scale_udf
+);
+
+#[user_doc(
+    doc_section(label = "Array Functions"),
+    description = "Returns a new array with each element of the input array multiplied by a scalar value, computed as `array[i] * scalar`. Returns NULL if the input row is NULL or the scalar is NULL. If a NULL element appears in the input array at position `i`, the result element at position `i` is NULL. Returns an empty array for an empty input array.",
+    syntax_example = "array_scale(array, scalar)",
+    sql_example = r#"```sql
+> select array_scale([1.0, 2.0, 3.0], 2.0);
++----------------------------------+
+| array_scale(List([1.0,2.0,3.0]),Float64(2.0)) |
++----------------------------------+
+| [2.0, 4.0, 6.0] |
++----------------------------------+
+```"#,
+    argument(
+        name = "array",
+        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
+    ),
+    argument(
+        name = "scalar",
+        description = "Numeric scalar to multiply each element by. Can be a constant or column expression."
+    )
+)]
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct ArrayScale {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for ArrayScale {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ArrayScale {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            aliases: vec!["list_scale".to_string()],
+        }
+    }
+}
+
+impl ScalarUDFImpl for ArrayScale {
+    fn name(&self) -> &str {
+        "array_scale"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        // After `coerce_types`, the first argument is List(Float64) or
+        // LargeList(Float64). Destructuring (rather than indexing) turns an
+        // unexpected argument count into an error instead of a panic.
+        let [array_type, _scalar_type] = take_function_args(self.name(), arg_types)?;
+        Ok(array_type.clone())
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        let [array_type, scalar_type] = take_function_args(self.name(), arg_types)?;
+        let coercion = Some(&ListCoercion::FixedSizedListToList);
+
+        if !matches!(
+            array_type,
+            Null | List(_) | LargeList(_) | FixedSizeList(..)
+        ) {
+            return plan_err!(
+                "{} first argument must be a list type, got {array_type}",
+                self.name()
+            );
+        }
+
+        if !scalar_type.is_numeric() && !matches!(scalar_type, Null) {
+            return plan_err!(
+                "{} second argument must be numeric, got {scalar_type}",
+                self.name()
+            );
+        }
+
+        let coerced_array = if matches!(array_type, Null) {
+            List(Arc::new(Field::new_list_field(DataType::Float64, true)))
+        } else {
+            coerced_type_with_base_type_only(array_type, &DataType::Float64, coercion)
+        };
+
+        // The kernel reads the list's child values as a Float64Array. If the
+        // coerced element type is anything else (for example a list of lists),
+        // fail here with a plan error rather than with an internal downcast
+        // error at execution time.
+        let element_is_float64 = matches!(
+            &coerced_array,
+            List(field) | LargeList(field) if field.data_type() == &DataType::Float64
+        );
+        if !element_is_float64 {
+            return plan_err!(
+                "{} first argument must be a list of numeric values, got {array_type}",
+                self.name()
+            );
+        }
+
+        Ok(vec![coerced_array, DataType::Float64])
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        make_scalar_function(array_scale_inner)(&args.args)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        self.doc()
+    }
+}
+
+/// Dispatches to the kernel matching the list's offset width.
+fn array_scale_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let [array, scalar] = take_function_args("array_scale", args)?;
+    match array.data_type() {
+        List(_) => general_array_scale::<i32>(array, scalar),
+        LargeList(_) => general_array_scale::<i64>(array, scalar),
+        arg_type => internal_err!(
+            "array_scale received unexpected type after coercion: {arg_type}"
+        ),
+    }
+}
+
+/// Multiplies every element of each list row by that row's scalar.
+///
+/// `array` must be a list of Float64 with offset type `O` and `scalar` a
+/// Float64 array; `scalar` is assumed to hold one value per list row.
+/// An output row is NULL when either the list row or its scalar is NULL, and
+/// a NULL element inside a non-NULL row stays NULL at the same position.
+///
+/// # Errors
+///
+/// Returns an error if an input cannot be downcast to the expected array
+/// type or if the output list array cannot be assembled.
+fn general_array_scale<O: OffsetSizeTrait>(
+    array: &ArrayRef,
+    scalar: &ArrayRef,
+) -> Result<ArrayRef> {
+    let list_array = as_generic_list_array::<O>(array)?;
+    let scalar_array = as_float64_array(scalar)?;
+
+    let values = as_float64_array(list_array.values())?;
+    let offsets = list_array.value_offsets();
+
+    // A row is null whenever either input row is null. The scalar applies
+    // uniformly across the array, so a null scalar makes the whole row
+    // undefined; union the two row-level null buffers in a single pass
+    // rather than tracking row nulls inside the value loop.
+    let row_nulls = NullBuffer::union(list_array.nulls(), scalar_array.nulls());
+
+    let mut value_builder = Float64Array::builder(values.len());
+    let mut new_offsets = OffsetBufferBuilder::<O>::new(list_array.len());
+
+    for row in 0..list_array.len() {
+        if row_nulls.as_ref().is_some_and(|nb| nb.is_null(row)) {
+            // Null rows contribute no child values.
+            new_offsets.push_length(0);
+            continue;
+        }
+
+        let start = offsets[row].as_usize();
+        let end = offsets[row + 1].as_usize();
+        let scalar_val = scalar_array.value(row);
+
+        // Read the child values in place instead of building a sliced array
+        // for every row. NULL elements stay NULL at their position.
+        for idx in start..end {
+            if values.is_null(idx) {
+                value_builder.append_null();
+            } else {
+                value_builder.append_value(values.value(idx) * scalar_val);
+            }
+        }
+
+        new_offsets.push_length(end - start);
+    }
+
+    let values_array = Arc::new(value_builder.finish());
+
+    // Preserve the inner field from the input array (including any user
+    // metadata). After `coerce_types` the inner type is Float64, but the
+    // input may still carry field-level annotations worth keeping.
+    let field = match list_array.data_type() {
+        List(f) | LargeList(f) => Arc::clone(f),
+        other => {
+            return internal_err!("array_scale unexpected list type: {other}");
+        }
+    };
+
+    Ok(Arc::new(GenericListArray::<O>::try_new(
+        field,
+        new_offsets.finish(),
+        values_array,
+        row_nulls,
+    )?))
+}
diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs
index 1e6dc68cb23ae..aacc4dbd3d481 100644
--- a/datafusion/functions-nested/src/lib.rs
+++ b/datafusion/functions-nested/src/lib.rs
@@ -47,6 +47,7 @@ pub mod array_compact;
 pub mod array_filter;
 pub mod array_has;
 pub mod array_normalize;
+pub mod array_scale;
 pub mod array_transform;
 pub mod arrays_zip;
 pub mod cardinality;
@@ -96,6 +97,7 @@ pub mod expr_fn {
     pub use super::array_has::array_has_all;
     pub use super::array_has::array_has_any;
     pub use super::array_normalize::array_normalize;
+    pub use super::array_scale::array_scale;
     pub use super::array_transform::array_transform;
     pub use super::arrays_zip::arrays_zip;
     pub use super::cardinality::cardinality;
@@ -171,6 +173,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
         empty::array_empty_udf(),
         length::array_length_udf(),
         array_normalize::array_normalize_udf(),
+        array_scale::array_scale_udf(),
         cosine_distance::cosine_distance_udf(),
         inner_product::inner_product_udf(),
         distance::array_distance_udf(),
diff --git a/datafusion/sqllogictest/test_files/array_scale.slt b/datafusion/sqllogictest/test_files/array_scale.slt
new file mode 100644
index 0000000000000..15d6cd6d98f68
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/array_scale.slt
@@ -0,0 +1,192 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +## array_scale + +# General case: scale vector by positive scalar +query ? +select array_scale([1.0, 2.0, 3.0], 2.0); +---- +[2.0, 4.0, 6.0] + +# Scale by 1 returns the same array +query ? +select array_scale([1.0, 2.0, 3.0], 1.0); +---- +[1.0, 2.0, 3.0] + +# Scale by 0 returns zeros +query ? +select array_scale([1.0, 2.0, 3.0], 0.0); +---- +[0.0, 0.0, 0.0] + +# Scale by negative scalar +query ? +select array_scale([1.0, 2.0, 3.0], -1.0); +---- +[-1.0, -2.0, -3.0] + +# Scale by fractional scalar +query ? +select array_scale([2.0, 4.0, 6.0], 0.5); +---- +[1.0, 2.0, 3.0] + +# Single-element array +query ? +select array_scale([5.0], 3.0); +---- +[15.0] + +# Bare NULL array returns NULL +query ? +select array_scale(NULL, 2.0); +---- +NULL + +# NULL scalar returns NULL row (whole-row null because the scalar applies uniformly) +query ? +select array_scale([1.0, 2.0, 3.0], NULL); +---- +NULL + +# Both NULL returns NULL +query ? +select array_scale(NULL, NULL); +---- +NULL + +# NULL element in array propagates only to that position +query ? +select array_scale([1.0, NULL, 3.0], 2.0); +---- +[2.0, NULL, 6.0] + +# All-NULL elements with valid scalar: each position remains NULL +query ? +select array_scale([NULL, NULL], 5.0); +---- +[NULL, NULL] + +# LargeList support +query ? +select array_scale(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), 2.0); +---- +[2.0, 4.0, 6.0] + +# FixedSizeList input (coerced to List) +query ? 
+select array_scale(arrow_cast([1.0, 2.0, 3.0], 'FixedSizeList(3, Float64)'), 2.0); +---- +[2.0, 4.0, 6.0] + +# Float32 inner type (coerced to Float64) +query ? +select array_scale(arrow_cast([1.0, 2.0, 3.0], 'List(Float32)'), 2.0); +---- +[2.0, 4.0, 6.0] + +# Int64 inner type (coerced to Float64) +query ? +select array_scale(arrow_cast([1, 2, 3], 'List(Int64)'), 2); +---- +[2.0, 4.0, 6.0] + +# Integer literals on both sides (coerced to Float64) +query ? +select array_scale([1, 2, 3], 2); +---- +[2.0, 4.0, 6.0] + +# Integer scalar with Float64 list +query ? +select array_scale([1.0, 2.0, 3.0], 3); +---- +[3.0, 6.0, 9.0] + +# Unsupported non-numeric scalar (plan error) +query error array_scale second argument must be numeric +select array_scale([1.0, 2.0, 3.0], 'foo'); + +# Unsupported non-list first argument (plan error) +query error array_scale first argument must be a list type +select array_scale(1.0, 2.0); + +# Multi-row query: constant scalar broadcast across rows +query ? +select array_scale(column1, 2.0) from (values + (make_array(1.0, 2.0, 3.0)), + (make_array(0.0, 0.0)), + (make_array(1.0, NULL, 3.0)), + (NULL) +) as t(column1); +---- +[2.0, 4.0, 6.0] +[0.0, 0.0] +[2.0, NULL, 6.0] +NULL + +# Multi-row query: scalar from a column (varies per row) +query ? +select array_scale(column1, column2) from (values + (make_array(1.0, 2.0, 3.0), 2.0), + (make_array(1.0, 2.0), 0.5), + (make_array(1.0, 2.0), arrow_cast(NULL, 'Float64')), + (NULL, 3.0) +) as t(column1, column2); +---- +[2.0, 4.0, 6.0] +[0.5, 1.0] +NULL +NULL + +# Empty array: array_scale of an empty array yields an empty array +query ? 
+select array_scale(arrow_cast(make_array(), 'List(Float64)'), 2.0); +---- +[] + +# Wrong arity (zero args) +query error array_scale function requires 2 arguments, got 0 +select array_scale(); + +# Wrong arity (one arg) +query error array_scale function requires 2 arguments, got 1 +select array_scale([1.0, 2.0]); + +# Return type matches input list shape: List(Float64) input yields List(Float64) output +query ?T +select array_scale([1.0, 2.0], 3.0), arrow_typeof(array_scale([1.0, 2.0], 3.0)); +---- +[3.0, 6.0] List(Float64) + +# list_scale alias produces the same result +query ? +select list_scale([1.0, 2.0, 3.0], 2.0); +---- +[2.0, 4.0, 6.0] + +# list_scale alias with NULL scalar propagates correctly +query ? +select list_scale(column1, column2) from (values + (make_array(1.0, 2.0), 2.0), + (make_array(1.0, 2.0), arrow_cast(NULL, 'Float64')) +) as t(column1, column2); +---- +[2.0, 4.0] +NULL diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 6bf61391eb10e..955654d80e688 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -3286,6 +3286,7 @@ _Alias of [current_date](#current_date)._ - [array_replace_n](#array_replace_n) - [array_resize](#array_resize) - [array_reverse](#array_reverse) +- [array_scale](#array_scale) - [array_slice](#array_slice) - [array_sort](#array_sort) - [array_to_string](#array_to_string) @@ -3341,6 +3342,7 @@ _Alias of [current_date](#current_date)._ - [list_replace_n](#list_replace_n) - [list_resize](#list_resize) - [list_reverse](#list_reverse) +- [list_scale](#list_scale) - [list_slice](#list_slice) - [list_sort](#list_sort) - [list_to_string](#list_to_string) @@ -4394,6 +4396,34 @@ array_reverse(array) - list_reverse +### `array_scale` + +Returns a new array with each element of the input array multiplied by a scalar value, computed as `array[i] * scalar`. 
Returns NULL if the input row is NULL or the scalar is NULL. If a NULL element appears in the input array at position `i`, the result element at position `i` is NULL. Returns an empty array for an empty input array. + +```sql +array_scale(array, scalar) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **scalar**: Numeric scalar to multiply each element by. Can be a constant or column expression. + +#### Example + +```sql +> select array_scale([1.0, 2.0, 3.0], 2.0); ++----------------------------------+ +| array_scale(List([1.0,2.0,3.0]),Float64(2.0)) | ++----------------------------------+ +| [2.0, 4.0, 6.0] | ++----------------------------------+ +``` + +#### Aliases + +- list_scale + ### `array_slice` Returns a slice of the array based on 1-indexed start and end positions. @@ -4909,6 +4939,10 @@ _Alias of [array_resize](#array_resize)._ _Alias of [array_reverse](#array_reverse)._ +### `list_scale` + +_Alias of [array_scale](#array_scale)._ + ### `list_slice` _Alias of [array_slice](#array_slice)._