Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 43 additions & 19 deletions datafusion/functions/src/unicode/reverse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ use std::sync::Arc;
use crate::utils::{make_scalar_function, utf8_to_str_type};
use DataType::{LargeUtf8, Utf8, Utf8View};
use arrow::array::{
Array, ArrayRef, AsArray, GenericStringBuilder, OffsetSizeTrait, StringArrayType,
Array, ArrayRef, AsArray, GenericStringArray, LargeStringBuilder, OffsetSizeTrait,
StringArrayType, StringBuilder, StringLikeArrayBuilder, StringViewArray,
StringViewBuilder,
};
use arrow::datatypes::DataType;
use datafusion_common::{Result, exec_err};
Expand Down Expand Up @@ -82,7 +84,11 @@ impl ScalarUDFImpl for ReverseFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
utf8_to_str_type(&arg_types[0], "reverse")
if arg_types[0] == Utf8View {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once all the functions that use `utf8_to_str_type` have been updated to emit Utf8View where appropriate, a follow-up PR that makes that function map Utf8View -> Utf8View would clean this up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively just

Ok(arg_types[0].clone())

Ok(Utf8View)
} else {
utf8_to_str_type(&arg_types[0], "reverse")
}
}

fn invoke_with_args(
Expand All @@ -107,20 +113,38 @@ impl ScalarUDFImpl for ReverseFunc {
/// Reverses the order of the characters in the string `reverse('abcde') = 'edcba'`.
/// The implementation uses UTF-8 code points as characters
fn reverse<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably take this opportunity to refactor the `T` generic out of `reverse`: it is confusing to match on the data type in `invoke_with_args` to determine the `T` for `reverse`, only to match on the data type again inside `reverse`.

if args[0].data_type() == &Utf8View {
reverse_impl::<T, _>(&args[0].as_string_view())
} else {
reverse_impl::<T, _>(&args[0].as_string::<T>())
let len = args[0].len();

match args[0].data_type() {
Utf8View => reverse_impl::<&StringViewArray, StringViewBuilder>(
&args[0].as_string_view(),
StringViewBuilder::with_capacity(len),
),
LargeUtf8 => reverse_impl::<&GenericStringArray<T>, LargeStringBuilder>(
&args[0].as_string::<T>(),
LargeStringBuilder::with_capacity(len, 1024),
),
Utf8 => reverse_impl::<&GenericStringArray<T>, StringBuilder>(
&args[0].as_string::<T>(),
StringBuilder::with_capacity(len, 1024),
),
_ => unreachable!(
"Reverse can only be applied to Utf8View, Utf8 and LargeUtf8 types"
),
}
}

fn reverse_impl<'a, T: OffsetSizeTrait, V: StringArrayType<'a>>(
string_array: &V,
) -> Result<ArrayRef> {
let mut builder = GenericStringBuilder::<T>::with_capacity(string_array.len(), 1024);

fn reverse_impl<'a, StringArrType, StringBuilderType>(
string_array: &StringArrType,
mut array_builder: StringBuilderType,
) -> Result<ArrayRef>
where
StringArrType: StringArrayType<'a>,
StringBuilderType: StringLikeArrayBuilder,
{
let mut string_buf = String::new();
let mut byte_buf = Vec::<u8>::new();

for string in string_array.iter() {
if let Some(s) = string {
if s.is_ascii() {
Expand All @@ -129,25 +153,25 @@ fn reverse_impl<'a, T: OffsetSizeTrait, V: StringArrayType<'a>>(
byte_buf.reverse();
// SAFETY: Since the original string was ASCII, reversing the bytes still results in valid UTF-8.
let reversed = unsafe { std::str::from_utf8_unchecked(&byte_buf) };
builder.append_value(reversed);
array_builder.append_value(reversed);
byte_buf.clear();
} else {
string_buf.extend(s.chars().rev());
builder.append_value(&string_buf);
array_builder.append_value(&string_buf);
string_buf.clear();
}
} else {
builder.append_null();
array_builder.append_null();
}
}

Ok(Arc::new(builder.finish()) as ArrayRef)
Ok(Arc::new(array_builder.finish()) as ArrayRef)
}

#[cfg(test)]
mod tests {
use arrow::array::{Array, LargeStringArray, StringArray};
use arrow::datatypes::DataType::{LargeUtf8, Utf8};
use arrow::array::{Array, LargeStringArray, StringArray, StringViewArray};
use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};

use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
Expand Down Expand Up @@ -180,8 +204,8 @@ mod tests {
vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))],
$EXPECTED,
&str,
Utf8,
StringArray
Utf8View,
StringViewArray
);
};
}
Expand Down
15 changes: 15 additions & 0 deletions datafusion/sqllogictest/test_files/string/string_literal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,21 @@ SELECT reverse(arrow_cast('abcde', 'Utf8View'))
----
edcba

query T
SELECT arrow_typeof(reverse('abcde'))
----
Utf8

query T
SELECT arrow_typeof(reverse(arrow_cast('abcde', 'LargeUtf8')))
----
LargeUtf8

query T
SELECT arrow_typeof(reverse(arrow_cast('abcde', 'Utf8View')))
----
Utf8View

query T
SELECT reverse(arrow_cast('abcde', 'Dictionary(Int32, Utf8)'))
----
Expand Down