diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs
index 71a1e6842f8f3..7b501c0a6542d 100644
--- a/datafusion/functions/src/unicode/reverse.rs
+++ b/datafusion/functions/src/unicode/reverse.rs
@@ -18,10 +18,11 @@
 use std::any::Any;
 use std::sync::Arc;
 
-use crate::utils::{make_scalar_function, utf8_to_str_type};
+use crate::utils::make_scalar_function;
 use DataType::{LargeUtf8, Utf8, Utf8View};
 use arrow::array::{
-    Array, ArrayRef, AsArray, GenericStringBuilder, OffsetSizeTrait, StringArrayType,
+    Array, ArrayRef, AsArray, LargeStringBuilder, StringArrayType, StringBuilder,
+    StringLikeArrayBuilder, StringViewBuilder,
 };
 use arrow::datatypes::DataType;
 use datafusion_common::{Result, exec_err};
@@ -82,7 +83,7 @@ impl ScalarUDFImpl for ReverseFunc {
     }
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        utf8_to_str_type(&arg_types[0], "reverse")
+        Ok(arg_types[0].clone())
     }
 
     fn invoke_with_args(
@@ -91,8 +92,7 @@ impl ScalarUDFImpl for ReverseFunc {
     ) -> Result<ColumnarValue> {
         let args = &args.args;
         match args[0].data_type() {
-            Utf8 | Utf8View => make_scalar_function(reverse::<i32>, vec![])(args),
-            LargeUtf8 => make_scalar_function(reverse::<i64>, vec![])(args),
+            Utf8 | Utf8View | LargeUtf8 => make_scalar_function(reverse, vec![])(args),
             other => {
                 exec_err!("Unsupported data type {other:?} for function reverse")
             }
@@ -106,21 +106,39 @@ impl ScalarUDFImpl for ReverseFunc {
 
 /// Reverses the order of the characters in the string `reverse('abcde') = 'edcba'`.
 /// The implementation uses UTF-8 code points as characters
-fn reverse<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
-    if args[0].data_type() == &Utf8View {
-        reverse_impl::<T, _>(&args[0].as_string_view())
-    } else {
-        reverse_impl::<T, _>(&args[0].as_string::<T>())
+fn reverse(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let len = args[0].len();
+
+    match args[0].data_type() {
+        Utf8 => reverse_impl(
+            &args[0].as_string::<i32>(),
+            StringBuilder::with_capacity(len, 1024),
+        ),
+        Utf8View => reverse_impl(
+            &args[0].as_string_view(),
+            StringViewBuilder::with_capacity(len),
+        ),
+        LargeUtf8 => reverse_impl(
+            &args[0].as_string::<i64>(),
+            LargeStringBuilder::with_capacity(len, 1024),
+        ),
+        _ => unreachable!(
+            "Reverse can only be applied to Utf8View, Utf8 and LargeUtf8 types"
+        ),
     }
 }
 
-fn reverse_impl<'a, T: OffsetSizeTrait, V: StringArrayType<'a>>(
-    string_array: &V,
-) -> Result<ArrayRef> {
-    let mut builder = GenericStringBuilder::<T>::with_capacity(string_array.len(), 1024);
-
+fn reverse_impl<'a, StringArrType, StringBuilderType>(
+    string_array: &StringArrType,
+    mut array_builder: StringBuilderType,
+) -> Result<ArrayRef>
+where
+    StringArrType: StringArrayType<'a>,
+    StringBuilderType: StringLikeArrayBuilder,
+{
     let mut string_buf = String::new();
     let mut byte_buf = Vec::<u8>::new();
+
     for string in string_array.iter() {
         if let Some(s) = string {
             if s.is_ascii() {
@@ -129,25 +147,25 @@ fn reverse_impl<'a, T: OffsetSizeTrait, V: StringArrayType<'a>>(
                 byte_buf.reverse();
                 // SAFETY: Since the original string was ASCII, reversing the bytes still results in valid UTF-8.
                 let reversed = unsafe { std::str::from_utf8_unchecked(&byte_buf) };
-                builder.append_value(reversed);
+                array_builder.append_value(reversed);
                 byte_buf.clear();
             } else {
                 string_buf.extend(s.chars().rev());
-                builder.append_value(&string_buf);
+                array_builder.append_value(&string_buf);
                 string_buf.clear();
             }
         } else {
-            builder.append_null();
+            array_builder.append_null();
         }
     }
 
-    Ok(Arc::new(builder.finish()) as ArrayRef)
+    Ok(Arc::new(array_builder.finish()) as ArrayRef)
 }
 
 #[cfg(test)]
 mod tests {
-    use arrow::array::{Array, LargeStringArray, StringArray};
-    use arrow::datatypes::DataType::{LargeUtf8, Utf8};
+    use arrow::array::{Array, LargeStringArray, StringArray, StringViewArray};
+    use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
     use datafusion_common::{Result, ScalarValue};
     use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
 
@@ -180,8 +198,8 @@ mod tests {
                 vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))],
                 $EXPECTED,
                 &str,
-                Utf8,
-                StringArray
+                Utf8View,
+                StringViewArray
             );
         };
     }
diff --git a/datafusion/sqllogictest/test_files/string/string_literal.slt b/datafusion/sqllogictest/test_files/string/string_literal.slt
index 6cf02218872df..2272de74cf7f2 100644
--- a/datafusion/sqllogictest/test_files/string/string_literal.slt
+++ b/datafusion/sqllogictest/test_files/string/string_literal.slt
@@ -389,6 +389,21 @@ SELECT reverse(arrow_cast('abcde', 'Utf8View'))
 ----
 edcba
 
+query T
+SELECT arrow_typeof(reverse('abcde'))
+----
+Utf8
+
+query T
+SELECT arrow_typeof(reverse(arrow_cast('abcde', 'LargeUtf8')))
+----
+LargeUtf8
+
+query T
+SELECT arrow_typeof(reverse(arrow_cast('abcde', 'Utf8View')))
+----
+Utf8View
+
 query T
 SELECT reverse(arrow_cast('abcde', 'Dictionary(Int32, Utf8)'))
 ----