diff --git a/Cargo.lock b/Cargo.lock index 5092a860e3c13..ea85d84f1319b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2592,8 +2592,10 @@ dependencies = [ "datafusion-functions-aggregate", "datafusion-functions-nested", "log", + "num-traits", "percent-encoding", "rand 0.9.2", + "regex", "serde_json", "sha1", "sha2", diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml index 162b6d814e804..e3d4c573195ec 100644 --- a/datafusion/spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -57,7 +57,9 @@ datafusion-functions = { workspace = true, features = ["crypto_expressions"] } datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true } log = { workspace = true } +num-traits = { workspace = true } percent-encoding = "2.3.2" +regex = { workspace = true } rand = { workspace = true } serde_json = { workspace = true } sha1 = "0.10" diff --git a/datafusion/spark/src/function/conversion/cast.rs b/datafusion/spark/src/function/conversion/cast.rs new file mode 100644 index 0000000000000..2bd662912c4a5 --- /dev/null +++ b/datafusion/spark/src/function/conversion/cast.rs @@ -0,0 +1,844 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::function::conversion::cast_boolean::{ + cast_boolean_to_decimal, is_df_cast_from_bool_spark_compatible, +}; +use crate::function::conversion::cast_complex::{ + cast_array_to_string, cast_binary_to_string, cast_int_to_binary, + cast_struct_to_struct, casts_struct_to_string, +}; +use crate::function::conversion::cast_datetime::{ + cast_date_to_timestamp, cast_int_to_timestamp, +}; +use crate::function::conversion::cast_numeric::{ + cast_float32_to_decimal128, cast_float64_to_decimal128, cast_int_to_decimal128, + spark_cast_decimal_to_boolean, spark_cast_float32_to_utf8, + spark_cast_float64_to_utf8, spark_cast_int_to_int, + spark_cast_nonintegral_numeric_to_integral, +}; +use crate::function::conversion::cast_string::{ + cast_string_to_date, cast_string_to_decimal, cast_string_to_float, + cast_string_to_int, cast_string_to_timestamp, + is_df_cast_from_string_spark_compatible, spark_cast_utf8_to_boolean, +}; +use crate::function::conversion::cast_utils::{ + EvalMode, SparkCastOptions, TIMESTAMP_FORMAT, array_with_timezone, + parse_spark_datatype, spark_cast_postprocess, +}; +use arrow::array::{Array, ArrayRef, AsArray, DictionaryArray, PrimitiveArray}; +use arrow::compute::{CastOptions, can_cast_types, cast_with_options, take}; +use arrow::datatypes::{ + ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Int32Type, +}; +use arrow::util::display::FormatOptions; +use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, +}; +use std::any::Any; +use std::sync::Arc; + +// --- SparkCast UDF --- + +/// Spark-compatible CAST function. 
+/// Usage: `spark_cast(expr, 'TYPE_NAME')` +/// +/// Behavior depends on the `enable_ansi_mode` config option: +/// - `enable_ansi_mode = false` (default): Legacy mode, returns NULL on errors +/// - `enable_ansi_mode = true`: ANSI mode, raises errors on invalid input +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkCast { + signature: Signature, +} + +impl Default for SparkCast { + fn default() -> Self { + Self::new() + } +} + +impl SparkCast { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![TypeSignature::Any(2)], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkCast { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "spark_cast" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be called instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result> { + if args.scalar_arguments.len() != 2 { + return internal_err!("spark_cast requires exactly 2 arguments"); + } + let type_str = match args.scalar_arguments[1] { + Some(ScalarValue::Utf8(Some(s))) => s, + Some(ScalarValue::LargeUtf8(Some(s))) => s, + _ => { + return internal_err!( + "spark_cast second argument must be a string literal type name" + ); + } + }; + let dt = parse_spark_datatype(type_str)?; + Ok(Arc::new(Field::new(self.name(), dt, true))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let target_type = resolve_target_type(&args)?; + let timezone = args + .config_options + .execution + .time_zone + .clone() + .unwrap_or_else(|| "UTC".to_string()); + + let eval_mode = if args.config_options.execution.enable_ansi_mode { + EvalMode::Ansi + } else { + EvalMode::Legacy + }; + + let cast_options = SparkCastOptions::new(eval_mode, &timezone); + spark_cast_inner( + args.args.into_iter().next().unwrap(), + &target_type, + &cast_options, + ) + } + + fn 
output_ordering(&self, input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) + } +} + +// --- SparkTryCast UDF --- + +/// Spark-compatible TRY_CAST function. +/// Usage: `spark_try_cast(expr, 'TYPE_NAME')` +/// +/// Always uses Try mode: returns NULL instead of raising errors on invalid input. +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkTryCast { + signature: Signature, +} + +impl Default for SparkTryCast { + fn default() -> Self { + Self::new() + } +} + +impl SparkTryCast { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![TypeSignature::Any(2)], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkTryCast { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "spark_try_cast" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be called instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result> { + if args.scalar_arguments.len() != 2 { + return internal_err!("spark_try_cast requires exactly 2 arguments"); + } + let type_str = match args.scalar_arguments[1] { + Some(ScalarValue::Utf8(Some(s))) => s, + Some(ScalarValue::LargeUtf8(Some(s))) => s, + _ => { + return internal_err!( + "spark_try_cast second argument must be a string literal type name" + ); + } + }; + let dt = parse_spark_datatype(type_str)?; + Ok(Arc::new(Field::new(self.name(), dt, true))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let target_type = resolve_target_type(&args)?; + let timezone = args + .config_options + .execution + .time_zone + .clone() + .unwrap_or_else(|| "UTC".to_string()); + + let cast_options = SparkCastOptions::new(EvalMode::Try, &timezone); + spark_cast_inner( + args.args.into_iter().next().unwrap(), + &target_type, + &cast_options, + ) + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + 
Ok(input[0].sort_properties) + } +} + +fn resolve_target_type(args: &ScalarFunctionArgs) -> Result { + match &args.args[1] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(type_str))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(type_str))) => { + parse_spark_datatype(type_str) + } + _ => internal_err!( + "spark_cast/spark_try_cast second argument must be a string literal type name" + ), + } +} + +// --- Core cast logic --- + +/// Entry point: cast a ColumnarValue using Spark semantics. +pub fn spark_cast_inner( + arg: ColumnarValue, + data_type: &DataType, + cast_options: &SparkCastOptions, +) -> Result { + match arg { + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array( + array, + data_type, + cast_options, + )?)), + ColumnarValue::Scalar(scalar) => { + let array = scalar.to_array()?; + let scalar = ScalarValue::try_from_array( + &cast_array(array, data_type, cast_options)?, + 0, + )?; + Ok(ColumnarValue::Scalar(scalar)) + } + } +} + +/// Helper to create DictionaryArray from values +fn dict_from_values( + values_array: ArrayRef, +) -> Result { + let key_array: PrimitiveArray = (0..values_array.len()) + .map(|index| { + if values_array.is_valid(index) { + let native_index = K::Native::from_usize(index).ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not create index of type {} from value {}", + K::DATA_TYPE, + index + )) + })?; + Ok(Some(native_index)) + } else { + Ok(None) + } + }) + .collect::>>()? + .into_iter() + .collect(); + + let dict_array = DictionaryArray::::try_new(key_array, values_array)?; + Ok(Arc::new(dict_array)) +} + +/// Main router: cast an Arrow array using Spark semantics. 
+pub(crate) fn cast_array( + array: ArrayRef, + to_type: &DataType, + cast_options: &SparkCastOptions, +) -> Result { + use DataType::*; + let array = array_with_timezone(array, &cast_options.timezone, Some(to_type))?; + let from_type = array.data_type().clone(); + + let native_cast_options: CastOptions = CastOptions { + safe: !matches!(cast_options.eval_mode, EvalMode::Ansi), + format_options: FormatOptions::new() + .with_timestamp_tz_format(TIMESTAMP_FORMAT) + .with_timestamp_format(TIMESTAMP_FORMAT), + }; + + // Handle dictionary input + let array = match &from_type { + Dictionary(key_type, value_type) + if key_type.as_ref() == &Int32 + && (value_type.as_ref() == &Utf8 + || value_type.as_ref() == &LargeUtf8 + || value_type.as_ref() == &Binary + || value_type.as_ref() == &LargeBinary) => + { + let dict_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a dictionary array"); + + let casted_result = match to_type { + Dictionary(_, to_value_type) => { + let casted_dictionary = DictionaryArray::::new( + dict_array.keys().clone(), + cast_array( + Arc::clone(dict_array.values()), + to_value_type, + cast_options, + )?, + ); + Arc::new(casted_dictionary.clone()) + } + _ => { + let casted_dictionary = DictionaryArray::::new( + dict_array.keys().clone(), + cast_array( + Arc::clone(dict_array.values()), + to_type, + cast_options, + )?, + ); + take(casted_dictionary.values().as_ref(), dict_array.keys(), None)? 
+ } + }; + return Ok(spark_cast_postprocess(casted_result, &from_type, to_type)); + } + _ => { + if let Dictionary(_, _) = to_type { + let dict_array = dict_from_values::(array)?; + let casted_result = cast_array(dict_array, to_type, cast_options)?; + return Ok(spark_cast_postprocess(casted_result, &from_type, to_type)); + } else { + array + } + } + }; + let from_type = array.data_type(); + let eval_mode = cast_options.eval_mode; + + let cast_result = match (from_type, to_type) { + // String -> Boolean + (Utf8, Boolean) => spark_cast_utf8_to_boolean::(&array, eval_mode), + (LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::(&array, eval_mode), + + // String -> Timestamp + (Utf8, Timestamp(_, _)) => { + let tz_str = if cast_options.timezone.is_empty() { + "UTC" + } else { + cast_options.timezone.as_str() + }; + let tz: arrow::array::timezone::Tz = tz_str.parse().map_err(|e| { + DataFusionError::Internal(format!("Failed to parse timezone: {e}")) + })?; + cast_string_to_timestamp(&array, to_type, eval_mode, &tz) + } + + // String -> Date + (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode), + + // Date -> Int (reinterpret days as i32) + (Date32, Int32) => { + let cast_opts = CastOptions { + safe: true, + format_options: FormatOptions::default(), + }; + Ok(cast_with_options(&array, to_type, &cast_opts)?) 
+ } + + // String -> Float + (Utf8, Float32 | Float64) => cast_string_to_float(&array, to_type, eval_mode), + + // String -> Decimal + (Utf8 | LargeUtf8, Decimal128(precision, scale)) => { + cast_string_to_decimal(&array, to_type, precision, scale, eval_mode) + } + (Utf8 | LargeUtf8, Decimal256(precision, scale)) => { + cast_string_to_decimal(&array, to_type, precision, scale, eval_mode) + } + + // Int -> Int narrowing (not Try mode) + (Int64, Int32) + | (Int64, Int16) + | (Int64, Int8) + | (Int32, Int16) + | (Int32, Int8) + | (Int16, Int8) + if eval_mode != EvalMode::Try => + { + spark_cast_int_to_int(&array, eval_mode, from_type, to_type) + } + + // Int -> Decimal + (Int8 | Int16 | Int32 | Int64, Decimal128(precision, scale)) => { + cast_int_to_decimal128( + &array, eval_mode, from_type, to_type, *precision, *scale, + ) + } + + // String -> Int + (Utf8, Int8 | Int16 | Int32 | Int64) => { + cast_string_to_int::(to_type, &array, eval_mode) + } + (LargeUtf8, Int8 | Int16 | Int32 | Int64) => { + cast_string_to_int::(to_type, &array, eval_mode) + } + + // Float -> String + (Float64, Utf8) => spark_cast_float64_to_utf8::(&array, eval_mode), + (Float64, LargeUtf8) => spark_cast_float64_to_utf8::(&array, eval_mode), + (Float32, Utf8) => spark_cast_float32_to_utf8::(&array, eval_mode), + (Float32, LargeUtf8) => spark_cast_float32_to_utf8::(&array, eval_mode), + + // Float -> Decimal + (Float32, Decimal128(precision, scale)) => { + cast_float32_to_decimal128(&array, *precision, *scale, eval_mode) + } + (Float64, Decimal128(precision, scale)) => { + cast_float64_to_decimal128(&array, *precision, *scale, eval_mode) + } + + // Float/Decimal -> Int (not Try mode) + (Float32, Int8) + | (Float32, Int16) + | (Float32, Int32) + | (Float32, Int64) + | (Float64, Int8) + | (Float64, Int16) + | (Float64, Int32) + | (Float64, Int64) + | (Decimal128(_, _), Int8) + | (Decimal128(_, _), Int16) + | (Decimal128(_, _), Int32) + | (Decimal128(_, _), Int64) + if eval_mode != EvalMode::Try => + 
{ + spark_cast_nonintegral_numeric_to_integral( + &array, eval_mode, from_type, to_type, + ) + } + + // Decimal -> Boolean + (Decimal128(_p, _s), Boolean) => spark_cast_decimal_to_boolean(&array), + + // Utf8View -> Utf8 + (Utf8View, Utf8) => { + let cast_opts = CastOptions { + safe: true, + format_options: FormatOptions::default(), + }; + Ok(cast_with_options(&array, to_type, &cast_opts)?) + } + + // Struct -> String + (Struct(_), Utf8) => casts_struct_to_string(array.as_struct(), cast_options), + + // Struct -> Struct + (Struct(_), Struct(_)) => { + cast_struct_to_struct(array.as_struct(), from_type, to_type, cast_options) + } + + // List -> String + (List(_), Utf8) => cast_array_to_string(array.as_list(), cast_options), + + // List -> List (delegate to Arrow if supported) + (List(_), List(_)) if can_cast_types(from_type, to_type) => { + let cast_opts = CastOptions { + safe: true, + format_options: FormatOptions::default(), + }; + Ok(cast_with_options(&array, to_type, &cast_opts)?) + } + + // Binary -> String + (Binary, Utf8) => Ok(cast_binary_to_string::(&array, cast_options)?), + + // Date -> Timestamp + (Date32, Timestamp(_, tz)) => cast_date_to_timestamp(&array, cast_options, tz), + + // Int -> Binary (Legacy mode only) + (Int8 | Int16 | Int32 | Int64, Binary) if eval_mode == EvalMode::Legacy => { + cast_int_to_binary(&array, from_type) + } + + // Boolean -> Decimal + (Boolean, Decimal128(precision, scale)) => { + cast_boolean_to_decimal(&array, *precision, *scale) + } + + // Int -> Timestamp + (Int8 | Int16 | Int32 | Int64, Timestamp(_, tz)) => { + cast_int_to_timestamp(&array, tz) + } + + // Fallback: use DataFusion cast when known to be compatible + _ if is_datafusion_spark_compatible(from_type, to_type) => { + Ok(cast_with_options(&array, to_type, &native_cast_options)?) 
+ } + + // Unsupported cast + _ => { + internal_err!( + "Spark cast invoked for unsupported cast from {from_type:?} to {to_type:?}" + ) + } + }; + Ok(spark_cast_postprocess(cast_result?, from_type, to_type)) +} + +/// Determines if DataFusion supports the given cast in a way that is +/// compatible with Spark. +fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool { + if from_type == to_type { + return true; + } + match from_type { + DataType::Null => { + matches!(to_type, DataType::List(_)) + } + DataType::Boolean => is_df_cast_from_bool_spark_compatible(to_type), + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + matches!( + to_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Utf8 + ) + } + DataType::Float32 | DataType::Float64 => matches!( + to_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ), + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!( + to_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Utf8 + ), + DataType::Utf8 => is_df_cast_from_string_spark_compatible(to_type), + DataType::Date32 => matches!(to_type, DataType::Int32 | DataType::Utf8), + DataType::Timestamp(_, _) => { + matches!( + to_type, + DataType::Int64 + | DataType::Date32 + | DataType::Utf8 + | DataType::Timestamp(_, _) + ) + } + DataType::Binary => { + matches!(to_type, DataType::Utf8) + } + _ => false, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ + BooleanArray, Decimal128Array, Float32Array, Float64Array, Int8Array, + Int16Array, Int32Array, Int64Array, PrimitiveArray, StringArray, StructArray, + }; + use 
arrow::datatypes::{Fields, TimestampMicrosecondType, TimeUnit}; + + fn test_input_bool_array() -> ArrayRef { + Arc::new(BooleanArray::from(vec![Some(true), Some(false), None])) + } + + #[test] + fn test_spark_cast_scalar_string_to_int() { + let opts = SparkCastOptions::new(EvalMode::Legacy, "UTC"); + let result = spark_cast_inner( + ColumnarValue::Scalar(ScalarValue::Utf8(Some("42".to_string()))), + &DataType::Int32, + &opts, + ) + .unwrap(); + match result { + ColumnarValue::Scalar(ScalarValue::Int32(Some(42))) => {} + other => panic!("Expected ScalarValue::Int32(42), got {other:?}"), + } + } + + #[test] + fn test_spark_cast_array_string_to_int() { + let array: ArrayRef = Arc::new(StringArray::from(vec![ + Some("1"), + Some("2"), + Some("invalid"), + None, + ])); + let opts = SparkCastOptions::new(EvalMode::Legacy, "UTC"); + let result = cast_array(array, &DataType::Int32, &opts).unwrap(); + let int_arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(int_arr.value(0), 1); + assert_eq!(int_arr.value(1), 2); + assert!(int_arr.is_null(2)); // "invalid" -> null in Legacy + assert!(int_arr.is_null(3)); + } + + #[test] + fn test_spark_cast_struct_to_struct() { + let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2)])); + let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"])); + let c: ArrayRef = Arc::new(StructArray::from(vec![ + (Arc::new(Field::new("a", DataType::Int32, true)), a), + (Arc::new(Field::new("b", DataType::Utf8, true)), b), + ])); + + let fields = Fields::from(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Utf8, true), + ]); + + let opts = SparkCastOptions::new(EvalMode::Legacy, "UTC"); + let result = + spark_cast_inner(ColumnarValue::Array(c), &DataType::Struct(fields), &opts) + .unwrap(); + if let ColumnarValue::Array(arr) = result { + assert_eq!(2, arr.len()); + let a = arr.as_struct().column(0).as_string::(); + assert_eq!("1", a.value(0)); + } else { + unreachable!() + } + } + + #[test] + fn 
test_is_datafusion_spark_compatible() { + assert!(is_datafusion_spark_compatible( + &DataType::Int32, + &DataType::Int32 + )); + assert!(is_datafusion_spark_compatible( + &DataType::Int32, + &DataType::Float64 + )); + assert!(is_datafusion_spark_compatible( + &DataType::Boolean, + &DataType::Int32 + )); + assert!(!is_datafusion_spark_compatible( + &DataType::Boolean, + &DataType::Decimal128(10, 2) + )); + } + + #[test] + fn test_cast_bool_to_int_types() { + let opts = SparkCastOptions::new(EvalMode::Legacy, "UTC"); + let input = test_input_bool_array(); + + // Bool -> Int8 + let result = cast_array(Arc::clone(&input), &DataType::Int8, &opts).unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1); + assert_eq!(arr.value(1), 0); + assert!(arr.is_null(2)); + + // Bool -> Int16 + let result = + cast_array(Arc::clone(&input), &DataType::Int16, &opts).unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1); + assert_eq!(arr.value(1), 0); + assert!(arr.is_null(2)); + + // Bool -> Int32 + let result = + cast_array(Arc::clone(&input), &DataType::Int32, &opts).unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1); + assert_eq!(arr.value(1), 0); + assert!(arr.is_null(2)); + + // Bool -> Int64 + let result = + cast_array(Arc::clone(&input), &DataType::Int64, &opts).unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1); + assert_eq!(arr.value(1), 0); + assert!(arr.is_null(2)); + } + + #[test] + fn test_cast_bool_to_float_types() { + let opts = SparkCastOptions::new(EvalMode::Legacy, "UTC"); + let input = test_input_bool_array(); + + // Bool -> Float32 + let result = + cast_array(Arc::clone(&input), &DataType::Float32, &opts).unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1.0); + assert_eq!(arr.value(1), 0.0); + assert!(arr.is_null(2)); + + // Bool -> Float64 + let result = + 
cast_array(Arc::clone(&input), &DataType::Float64, &opts).unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1.0); + assert_eq!(arr.value(1), 0.0); + assert!(arr.is_null(2)); + } + + #[test] + fn test_cast_bool_to_string() { + let opts = SparkCastOptions::new(EvalMode::Legacy, "UTC"); + let input = test_input_bool_array(); + + let result = cast_array(input, &DataType::Utf8, &opts).unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), "true"); + assert_eq!(arr.value(1), "false"); + assert!(arr.is_null(2)); + } + + #[test] + fn test_cast_string_to_int_ansi_error() { + let array: ArrayRef = + Arc::new(StringArray::from(vec![Some("not_a_number")])); + let opts = SparkCastOptions::new(EvalMode::Ansi, "UTC"); + let result = cast_array(array, &DataType::Int32, &opts); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!(err_msg.contains("[CAST_INVALID_INPUT]")); + } + + #[test] + fn test_cast_string_to_int_try_mode() { + let array: ArrayRef = Arc::new(StringArray::from(vec![ + Some("42"), + Some("invalid"), + Some(""), + None, + ])); + let opts = SparkCastOptions::new(EvalMode::Try, "UTC"); + let result = cast_array(array, &DataType::Int32, &opts).unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 42); + assert!(arr.is_null(1)); + assert!(arr.is_null(2)); + assert!(arr.is_null(3)); + } + + #[test] + fn test_cast_float_to_decimal_through_router() { + let array: ArrayRef = Arc::new(Float64Array::from(vec![ + Some(42.0), + Some(-1.5), + Some(f64::NAN), + None, + ])); + let opts = SparkCastOptions::new(EvalMode::Legacy, "UTC"); + let result = + cast_array(array, &DataType::Decimal128(10, 2), &opts).unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 4200); + assert_eq!(arr.value(1), -150); + assert!(arr.is_null(2)); // NaN -> null + assert!(arr.is_null(3)); + } + + #[test] + fn 
test_cast_string_to_timestamp_through_router() { + let array: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020-01-01"), + Some("2020-01-01T12:34:56"), + None, + ])); + let opts = SparkCastOptions::new(EvalMode::Legacy, "UTC"); + let to_type = + DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())); + let result = cast_array(array, &to_type, &opts).unwrap(); + let ts = result + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(ts.len(), 3); + assert_eq!(ts.value(0), 1577836800000000_i64); // 2020-01-01 00:00:00 + assert_eq!(ts.value(1), 1577882096000000_i64); // 2020-01-01T12:34:56 + assert!(ts.is_null(2)); + } + + #[test] + fn test_cast_invalid_timezone() { + let dates: ArrayRef = + Arc::new(arrow::array::Date32Array::from(vec![Some(0)])); + let opts = SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone"); + let result = cast_array( + dates, + &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), + &opts, + ); + assert!(result.is_err()); + } +} diff --git a/datafusion/spark/src/function/conversion/cast_boolean.rs b/datafusion/spark/src/function/conversion/cast_boolean.rs new file mode 100644 index 0000000000000..facf20f429eb9 --- /dev/null +++ b/datafusion/spark/src/function/conversion/cast_boolean.rs @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, AsArray, Decimal128Array}; +use arrow::datatypes::DataType; +use datafusion_common::Result; +use std::sync::Arc; + +/// Check if DataFusion's built-in cast from Boolean to the target type is +/// compatible with Spark behavior. +pub fn is_df_cast_from_bool_spark_compatible(to_type: &DataType) -> bool { + use DataType::*; + matches!( + to_type, + Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 + ) +} + +/// Cast a Boolean array to Decimal128 with the given precision and scale. +/// true -> 1 * 10^scale, false -> 0, null -> null +pub fn cast_boolean_to_decimal( + array: &ArrayRef, + precision: u8, + scale: i8, +) -> Result { + let bool_array = array.as_boolean(); + let scaled_val = 10_i128.pow(scale as u32); + let result: Decimal128Array = bool_array + .iter() + .map(|v| v.map(|b| if b { scaled_val } else { 0 })) + .collect(); + Ok(Arc::new(result.with_precision_and_scale(precision, scale)?)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Array, BooleanArray}; + + #[test] + fn test_is_df_cast_from_bool_spark_compatible() { + assert!(!is_df_cast_from_bool_spark_compatible(&DataType::Boolean)); + assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int8)); + assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int16)); + assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int32)); + assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int64)); + assert!(is_df_cast_from_bool_spark_compatible(&DataType::Float32)); + assert!(is_df_cast_from_bool_spark_compatible(&DataType::Float64)); + assert!(is_df_cast_from_bool_spark_compatible(&DataType::Utf8)); + assert!(!is_df_cast_from_bool_spark_compatible( + &DataType::Decimal128(10, 4) + )); + assert!(!is_df_cast_from_bool_spark_compatible(&DataType::Null)); + } + + #[test] + fn test_cast_boolean_to_decimal() { + let array: ArrayRef = + 
Arc::new(BooleanArray::from(vec![Some(true), Some(false), None])); + let result = cast_boolean_to_decimal(&array, 10, 4).unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 10000); + assert_eq!(arr.value(1), 0); + assert!(arr.is_null(2)); + } +} diff --git a/datafusion/spark/src/function/conversion/cast_complex.rs b/datafusion/spark/src/function/conversion/cast_complex.rs new file mode 100644 index 0000000000000..2d62ae67a4fca --- /dev/null +++ b/datafusion/spark/src/function/conversion/cast_complex.rs @@ -0,0 +1,354 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::function::conversion::cast::cast_array; +use crate::function::conversion::cast_utils::SparkCastOptions; +use arrow::array::builder::StringBuilder; +use arrow::array::{ + Array, ArrayRef, AsArray, BinaryBuilder, GenericByteArray, GenericStringArray, + Int8Array, Int16Array, Int32Array, Int64Array, ListArray, OffsetSizeTrait, + StringArray, StructArray, +}; +use arrow::datatypes::{DataType, GenericBinaryType}; +use arrow::error::ArrowError; +use datafusion_common::Result; +use datafusion_expr::ColumnarValue; +use std::sync::Arc; + +// --- Int → Binary --- + +macro_rules! 
cast_whole_num_to_binary { + ($array:expr, $primitive_type:ty, $byte_size:expr) => {{ + let input_arr = $array + .as_any() + .downcast_ref::<$primitive_type>() + .ok_or_else(|| { + datafusion_common::DataFusionError::Internal( + "Expected numeric array".to_string(), + ) + })?; + + let len = input_arr.len(); + let mut builder = BinaryBuilder::with_capacity(len, len * $byte_size); + + for i in 0..input_arr.len() { + if input_arr.is_null(i) { + builder.append_null(); + } else { + builder.append_value(input_arr.value(i).to_be_bytes()); + } + } + + Ok(Arc::new(builder.finish()) as ArrayRef) + }}; +} + +pub(crate) fn cast_int_to_binary( + array: &ArrayRef, + from_type: &DataType, +) -> Result { + match from_type { + DataType::Int8 => cast_whole_num_to_binary!(array, Int8Array, 1), + DataType::Int16 => cast_whole_num_to_binary!(array, Int16Array, 2), + DataType::Int32 => cast_whole_num_to_binary!(array, Int32Array, 4), + DataType::Int64 => cast_whole_num_to_binary!(array, Int64Array, 8), + _ => datafusion_common::internal_err!( + "Unsupported type for cast_int_to_binary: {:?}", + from_type + ), + } +} + +// --- Struct → Struct --- + +/// Cast between struct types based on logic in +/// `org.apache.spark.sql.catalyst.expressions.Cast#castStruct`. 
+pub(crate) fn cast_struct_to_struct( + array: &StructArray, + from_type: &DataType, + to_type: &DataType, + cast_options: &SparkCastOptions, +) -> Result { + match (from_type, to_type) { + (DataType::Struct(from_fields), DataType::Struct(to_fields)) => { + let cast_fields: Vec = from_fields + .iter() + .enumerate() + .zip(to_fields.iter()) + .map(|((idx, _from), to)| { + let from_field = Arc::clone(array.column(idx)); + let array_length = from_field.len(); + let cast_result = spark_cast_columnar( + ColumnarValue::from(from_field), + to.data_type(), + cast_options, + ) + .unwrap(); + cast_result.to_array(array_length).unwrap() + }) + .collect(); + + Ok(Arc::new(StructArray::new( + to_fields.clone(), + cast_fields, + array.nulls().cloned(), + ))) + } + _ => unreachable!(), + } +} + +// --- Struct → String --- + +pub(crate) fn casts_struct_to_string( + array: &StructArray, + spark_cast_options: &SparkCastOptions, +) -> Result { + let string_arrays: Vec = array + .columns() + .iter() + .map(|arr| { + spark_cast_columnar( + ColumnarValue::Array(Arc::clone(arr)), + &DataType::Utf8, + spark_cast_options, + ) + .and_then(|cv| cv.into_array(arr.len())) + }) + .collect::>>()?; + let string_arrays: Vec<&StringArray> = + string_arrays.iter().map(|arr| arr.as_string()).collect(); + + let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16); + let mut str = String::with_capacity(array.len() * 16); + for row_index in 0..array.len() { + if array.is_null(row_index) { + builder.append_null(); + } else { + str.clear(); + let mut any_fields_written = false; + str.push('{'); + for field in &string_arrays { + if any_fields_written { + str.push_str(", "); + } + if field.is_null(row_index) { + str.push_str("null"); + } else { + str.push_str(field.value(row_index)); + } + any_fields_written = true; + } + str.push('}'); + builder.append_value(&str); + } + } + Ok(Arc::new(builder.finish())) +} + +// --- List → String --- + +pub(crate) fn cast_array_to_string( + array: 
&ListArray, + spark_cast_options: &SparkCastOptions, +) -> Result { + let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16); + let mut str = String::with_capacity(array.len() * 16); + + let casted_values = cast_array( + Arc::clone(array.values()), + &DataType::Utf8, + spark_cast_options, + )?; + let string_values = casted_values + .as_any() + .downcast_ref::() + .expect("Casted values should be StringArray"); + + let offsets = array.offsets(); + for row_index in 0..array.len() { + if array.is_null(row_index) { + builder.append_null(); + } else { + str.clear(); + let start = offsets[row_index] as usize; + let end = offsets[row_index + 1] as usize; + + str.push('['); + let mut first = true; + for idx in start..end { + if !first { + str.push_str(", "); + } + if string_values.is_null(idx) { + str.push_str("null"); + } else { + str.push_str(string_values.value(idx)); + } + first = false; + } + str.push(']'); + builder.append_value(&str); + } + } + Ok(Arc::new(builder.finish())) +} + +// --- Binary → String --- + +pub(crate) fn cast_binary_to_string( + array: &dyn Array, + _spark_cast_options: &SparkCastOptions, +) -> std::result::Result { + let input = array + .as_any() + .downcast_ref::>>() + .unwrap(); + + let output_array = input + .iter() + .map(|value| match value { + Some(value) => Ok(Some(cast_binary_formatter(value))), + _ => Ok(None), + }) + .collect::, ArrowError>>()?; + Ok(Arc::new(output_array)) +} + +fn cast_binary_formatter(value: &[u8]) -> String { + match String::from_utf8(value.to_vec()) { + Ok(value) => value, + Err(_) => unsafe { String::from_utf8_unchecked(value.to_vec()) }, + } +} + +/// Helper: apply spark_cast to a ColumnarValue (used by struct/list casting). 
+///
+/// An array is cast directly. A scalar is converted to an array, cast, and
+/// element 0 of the result is read back as a scalar.
+fn spark_cast_columnar(
+    arg: ColumnarValue,
+    data_type: &DataType,
+    cast_options: &SparkCastOptions,
+) -> Result<ColumnarValue> {
+    match arg {
+        ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array(
+            array,
+            data_type,
+            cast_options,
+        )?)),
+        ColumnarValue::Scalar(scalar) => {
+            let array = scalar.to_array()?;
+            let scalar = datafusion_common::ScalarValue::try_from_array(
+                &cast_array(array, data_type, cast_options)?,
+                0,
+            )?;
+            Ok(ColumnarValue::Scalar(scalar))
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Int32Array, StringArray};
+    use arrow::buffer::OffsetBuffer;
+    use arrow::datatypes::Field;
+
+    #[test]
+    fn test_cast_int_to_binary() {
+        let array: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(256), None]));
+        let result = cast_int_to_binary(&array, &DataType::Int32).unwrap();
+        let binary = result
+            .as_any()
+            .downcast_ref::<arrow::array::BinaryArray>()
+            .unwrap();
+        // Values are written big-endian.
+        assert_eq!(binary.value(0), &[0, 0, 0, 1]);
+        assert_eq!(binary.value(1), &[0, 0, 1, 0]);
+        assert!(binary.is_null(2));
+    }
+
+    #[test]
+    fn test_cast_struct_to_utf8() {
+        let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None]));
+        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
+        let c: ArrayRef = Arc::new(StructArray::from(vec![
+            (Arc::new(Field::new("a", DataType::Int32, true)), a),
+            (Arc::new(Field::new("b", DataType::Utf8, true)), b),
+        ]));
+        let cast_opts = SparkCastOptions::new(
+            crate::function::conversion::cast_utils::EvalMode::Legacy,
+            "UTC",
+        );
+        let string_array = casts_struct_to_string(c.as_struct(), &cast_opts).unwrap();
+        let string_array = string_array.as_string::<i32>();
+        assert_eq!(string_array.value(0), "{1, a}");
+        assert_eq!(string_array.value(1), "{2, b}");
+        assert_eq!(string_array.value(2), "{null, c}");
+    }
+
+    #[test]
+    fn test_cast_list_to_string() {
+        let values_array = StringArray::from(vec![
+            Some("a"),
+            Some("b"),
+            Some("c"),
+            Some("a"),
+            None,
+            None,
+        ]);
+        let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
+        let item_field = Arc::new(Field::new("item", DataType::Utf8, true));
+        let list_array = Arc::new(ListArray::new(
+            item_field,
+            offsets_buffer,
+            Arc::new(values_array),
+            None,
+        ));
+        let cast_opts = SparkCastOptions::new(
+            crate::function::conversion::cast_utils::EvalMode::Legacy,
+            "UTC",
+        );
+        let string_array = cast_array_to_string(&list_array, &cast_opts).unwrap();
+        let string_array = string_array.as_string::<i32>();
+        assert_eq!(string_array.value(0), "[a, b, c]");
+        assert_eq!(string_array.value(1), "[a, null]");
+        assert_eq!(string_array.value(2), "[null]");
+        assert_eq!(string_array.value(3), "[]");
+    }
+
+    #[test]
+    fn test_cast_i32_list_to_string() {
+        let values_array =
+            Int32Array::from(vec![Some(1), Some(2), Some(3), Some(1), None, None]);
+        let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
+        let item_field = Arc::new(Field::new("item", DataType::Int32, true));
+        let list_array = Arc::new(ListArray::new(
+            item_field,
+            offsets_buffer,
+            Arc::new(values_array),
+            None,
+        ));
+        let cast_opts = SparkCastOptions::new(
+            crate::function::conversion::cast_utils::EvalMode::Legacy,
+            "UTC",
+        );
+        let string_array = cast_array_to_string(&list_array, &cast_opts).unwrap();
+        let string_array = string_array.as_string::<i32>();
+        assert_eq!(string_array.value(0), "[1, 2, 3]");
+        assert_eq!(string_array.value(1), "[1, null]");
+        assert_eq!(string_array.value(2), "[null]");
+        assert_eq!(string_array.value(3), "[]");
+    }
+}
diff --git a/datafusion/spark/src/function/conversion/cast_datetime.rs b/datafusion/spark/src/function/conversion/cast_datetime.rs
new file mode 100644
index 0000000000000..0121ece0ab14b
--- /dev/null
+++ b/datafusion/spark/src/function/conversion/cast_datetime.rs
@@ -0,0 +1,292 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::function::conversion::cast_utils::{MICROS_PER_SECOND, SparkCastOptions};
+use arrow::array::{Array, ArrayRef, AsArray, TimestampMicrosecondBuilder};
+use arrow::datatypes::{DataType, Date32Type, Int8Type, Int16Type, Int32Type, Int64Type};
+use chrono::{NaiveDate, TimeZone};
+use datafusion_common::Result;
+use std::sync::Arc;
+
+// Appends every element of `$array` (viewed as a primitive array of
+// `$primitive_type`) to `$builder`, converting seconds to microseconds.
+macro_rules! cast_int_to_timestamp_impl {
+    ($array:expr, $builder:expr, $primitive_type:ty) => {{
+        let arr = $array.as_primitive::<$primitive_type>();
+        for i in 0..arr.len() {
+            if arr.is_null(i) {
+                $builder.append_null();
+            } else {
+                // saturating_mul limits to i64::MIN/MAX on overflow instead of panicking,
+                // matching Spark behavior (irrespective of EvalMode)
+                let micros = (arr.value(i) as i64).saturating_mul(MICROS_PER_SECOND);
+                $builder.append_value(micros);
+            }
+        }
+    }};
+}
+
+/// Cast integer types (Int8/16/32/64) to Timestamp(Microsecond) by treating
+/// the integer as seconds since epoch.
+pub(crate) fn cast_int_to_timestamp( + array_ref: &ArrayRef, + target_tz: &Option>, +) -> Result { + let mut builder = TimestampMicrosecondBuilder::with_capacity(array_ref.len()); + + match array_ref.data_type() { + DataType::Int8 => cast_int_to_timestamp_impl!(array_ref, builder, Int8Type), + DataType::Int16 => cast_int_to_timestamp_impl!(array_ref, builder, Int16Type), + DataType::Int32 => cast_int_to_timestamp_impl!(array_ref, builder, Int32Type), + DataType::Int64 => cast_int_to_timestamp_impl!(array_ref, builder, Int64Type), + dt => { + return datafusion_common::internal_err!( + "Unsupported type for cast_int_to_timestamp: {:?}", + dt + ); + } + } + + Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef) +} + +/// Cast Date32 to Timestamp(Microsecond) with timezone awareness. +/// Converts a date (days since epoch) to a timestamp at midnight in the +/// session timezone, then stores as UTC microseconds. +pub(crate) fn cast_date_to_timestamp( + array_ref: &ArrayRef, + cast_options: &SparkCastOptions, + target_tz: &Option>, +) -> Result { + let tz_str = if cast_options.timezone.is_empty() { + "UTC" + } else { + cast_options.timezone.as_str() + }; + let tz: arrow::array::timezone::Tz = tz_str.parse().map_err(|e| { + datafusion_common::DataFusionError::Internal(format!( + "Failed to parse timezone: {e}" + )) + })?; + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + let date_array = array_ref.as_primitive::(); + + let mut builder = TimestampMicrosecondBuilder::with_capacity(date_array.len()); + + for date in date_array.iter() { + match date { + Some(date) => { + let naive_date = epoch + chrono::Duration::days(date as i64); + let local_midnight = naive_date.and_hms_opt(0, 0, 0).unwrap(); + let local_midnight_in_microsec = tz + .from_local_datetime(&local_midnight) + .earliest() + .map(|dt| dt.timestamp_micros()) + // fallback to UTC if DST ambiguity + .unwrap_or((date as i64) * 86_400 * 1_000_000); + 
builder.append_value(local_midnight_in_microsec); + } + None => { + builder.append_null(); + } + } + } + Ok(Arc::new( + builder.finish().with_timezone_opt(target_tz.clone()), + )) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ + Date32Array, Int8Array, Int16Array, Int32Array, Int64Array, + }; + use arrow::datatypes::TimestampMicrosecondType; + + use crate::function::conversion::cast_utils::EvalMode; + + #[test] + fn test_cast_int_to_timestamp() { + let timezones: [Option>; 3] = [ + Some(Arc::from("UTC")), + Some(Arc::from("America/New_York")), + None, + ]; + + for tz in &timezones { + let int8_array: ArrayRef = + Arc::new(Int8Array::from(vec![Some(0), Some(1), Some(-1), None])); + + let result = cast_int_to_timestamp(&int8_array, tz).unwrap(); + let ts_array = result.as_primitive::(); + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_000_000); + assert_eq!(ts_array.value(2), -1_000_000); + assert!(ts_array.is_null(3)); + } + } + + #[test] + fn test_cast_int64_overflow() { + let int64_array: ArrayRef = + Arc::new(Int64Array::from(vec![Some(i64::MAX), Some(i64::MIN)])); + let result = + cast_int_to_timestamp(&int64_array, &Some(Arc::from("UTC"))).unwrap(); + let ts_array = result.as_primitive::(); + // saturating_mul should cap at i64::MAX/MIN + assert_eq!(ts_array.value(0), i64::MAX); + assert_eq!(ts_array.value(1), i64::MIN); + } + + #[test] + fn test_cast_date_to_timestamp_utc() { + let dates: ArrayRef = Arc::new(Date32Array::from(vec![ + Some(0), // epoch + Some(19723), // 2024-01-01 + None, + ])); + let opts = SparkCastOptions::new(EvalMode::Legacy, "UTC"); + let result = + cast_date_to_timestamp(&dates, &opts, &Some(Arc::from("UTC"))).unwrap(); + let ts = result.as_primitive::(); + assert_eq!(ts.value(0), 0); + assert_eq!(ts.value(1), 1704067200000000i64); + assert!(ts.is_null(2)); + } + + #[test] + fn test_cast_int_to_timestamp_all_types() { + let timezones: [Option>; 6] = [ + Some(Arc::from("UTC")), + 
Some(Arc::from("America/New_York")), + Some(Arc::from("America/Los_Angeles")), + Some(Arc::from("Europe/London")), + Some(Arc::from("Asia/Tokyo")), + Some(Arc::from("Australia/Sydney")), + ]; + + for tz in &timezones { + // Int8 with boundary values + let int8_array: ArrayRef = Arc::new(Int8Array::from(vec![ + Some(0), + Some(1), + Some(-1), + Some(127), + Some(-128), + None, + ])); + let result = cast_int_to_timestamp(&int8_array, tz).unwrap(); + let ts_array = result.as_primitive::(); + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_000_000); + assert_eq!(ts_array.value(2), -1_000_000); + assert_eq!(ts_array.value(3), 127_000_000); + assert_eq!(ts_array.value(4), -128_000_000); + assert!(ts_array.is_null(5)); + assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref())); + + // Int16 with boundary values + let int16_array: ArrayRef = Arc::new(Int16Array::from(vec![ + Some(0), + Some(1), + Some(-1), + Some(32767), + Some(-32768), + None, + ])); + let result = cast_int_to_timestamp(&int16_array, tz).unwrap(); + let ts_array = result.as_primitive::(); + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_000_000); + assert_eq!(ts_array.value(2), -1_000_000); + assert_eq!(ts_array.value(3), 32_767_000_000_i64); + assert_eq!(ts_array.value(4), -32_768_000_000_i64); + assert!(ts_array.is_null(5)); + assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref())); + + // Int32 with a realistic epoch seconds value + let int32_array: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + Some(-1), + Some(1704067200), // 2024-01-01 + None, + ])); + let result = cast_int_to_timestamp(&int32_array, tz).unwrap(); + let ts_array = result.as_primitive::(); + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_000_000); + assert_eq!(ts_array.value(2), -1_000_000); + assert_eq!(ts_array.value(3), 1_704_067_200_000_000_i64); + assert!(ts_array.is_null(4)); + assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| 
s.as_ref())); + + // Int64 with MAX/MIN (overflow-saturating) + let int64_array: ArrayRef = Arc::new(Int64Array::from(vec![ + Some(0), + Some(1), + Some(-1), + Some(i64::MAX), + Some(i64::MIN), + ])); + let result = cast_int_to_timestamp(&int64_array, tz).unwrap(); + let ts_array = result.as_primitive::(); + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_000_000_i64); + assert_eq!(ts_array.value(2), -1_000_000_i64); + assert_eq!(ts_array.value(3), i64::MAX); + assert_eq!(ts_array.value(4), i64::MIN); + assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref())); + } + } + + #[test] + fn test_cast_date_to_timestamp_dst() { + // epoch, 2024-01-01, 2024-03-11 (DST spring-forward in US) + let dates: ArrayRef = Arc::new(Date32Array::from(vec![ + Some(0), + Some(19723), + Some(19793), + None, + ])); + + let non_dst_date = 1704067200000000_i64; + let dst_date = 1710115200000000_i64; + let seven_hours_ts = 25200000000_i64; + let eight_hours_ts = 28800000000_i64; + + // America/Los_Angeles: DST-aware + let opts = SparkCastOptions::new(EvalMode::Legacy, "America/Los_Angeles"); + let result = + cast_date_to_timestamp(&dates, &opts, &Some(Arc::from("UTC"))).unwrap(); + let ts = result.as_primitive::(); + assert_eq!(ts.value(0), eight_hours_ts); + assert_eq!(ts.value(1), non_dst_date + eight_hours_ts); + // DST: spring forward -> only 7 hours offset + assert_eq!(ts.value(2), dst_date + seven_hours_ts); + assert!(ts.is_null(3)); + + // America/Phoenix: no DST, always 7 hours offset + let opts = SparkCastOptions::new(EvalMode::Legacy, "America/Phoenix"); + let result = + cast_date_to_timestamp(&dates, &opts, &Some(Arc::from("UTC"))).unwrap(); + let ts = result.as_primitive::(); + assert_eq!(ts.value(0), seven_hours_ts); + assert_eq!(ts.value(1), non_dst_date + seven_hours_ts); + assert_eq!(ts.value(2), dst_date + seven_hours_ts); + assert!(ts.is_null(3)); + } +} diff --git a/datafusion/spark/src/function/conversion/cast_numeric.rs 
b/datafusion/spark/src/function/conversion/cast_numeric.rs new file mode 100644 index 0000000000000..0ad50337b522d --- /dev/null +++ b/datafusion/spark/src/function/conversion/cast_numeric.rs @@ -0,0 +1,1090 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::function::conversion::cast_utils::{ + EvalMode, cast_overflow, format_decimal_str, numeric_value_out_of_range, +}; +use arrow::array::{ + Array, ArrayRef, AsArray, BooleanBuilder, Decimal128Array, Decimal128Builder, + Float32Array, Float64Array, GenericStringArray, Int8Array, Int16Array, Int32Array, + Int64Array, OffsetSizeTrait, PrimitiveArray, +}; +use arrow::datatypes::{ + ArrowPrimitiveType, DataType, Decimal128Type, Float32Type, Float64Type, Int8Type, + Int16Type, Int32Type, Int64Type, is_validate_decimal_precision, +}; +use datafusion_common::Result; +use num_traits::{AsPrimitive, ToPrimitive, Zero}; +use std::sync::Arc; + +// --- Float → String macros --- + +macro_rules! 
cast_float_to_string { + ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{ + fn cast(from: &dyn Array, _eval_mode: EvalMode) -> Result + where + OffsetSize: OffsetSizeTrait, + { + let array = from.as_any().downcast_ref::<$output_type>().unwrap(); + + // If the absolute number is less than 10,000,000 and greater or equal than 0.001, + // the result is expressed without scientific notation with at least one digit on + // either side of the decimal point. Otherwise, Spark uses scientific notation. + const LOWER_SCIENTIFIC_BOUND: $type = 0.001; + const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0; + + let output_array = array + .iter() + .map(|value| match value { + Some(value) if value == <$type>::INFINITY => { + Ok(Some("Infinity".to_string())) + } + Some(value) if value == <$type>::NEG_INFINITY => { + Ok(Some("-Infinity".to_string())) + } + Some(value) + if (value.abs() < UPPER_SCIENTIFIC_BOUND + && value.abs() >= LOWER_SCIENTIFIC_BOUND) + || value.abs() == 0.0 => + { + let trailing_zero = if value.fract() == 0.0 { ".0" } else { "" }; + Ok(Some(format!("{value}{trailing_zero}"))) + } + Some(value) + if value.abs() >= UPPER_SCIENTIFIC_BOUND + || value.abs() < LOWER_SCIENTIFIC_BOUND => + { + let formatted = format!("{value:E}"); + if formatted.contains('.') { + Ok(Some(formatted)) + } else { + let prepare_number: Vec<&str> = formatted.split('E').collect(); + let coefficient = prepare_number[0]; + let exponent = prepare_number[1]; + Ok(Some(format!("{coefficient}.0E{exponent}"))) + } + } + Some(value) => Ok(Some(value.to_string())), + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(output_array)) + } + + cast::<$offset_type>($from, $eval_mode) + }}; +} + +// --- Int → Int narrowing macros --- + +macro_rules! 
cast_int_to_int_macro { + ( + $array:expr, + $eval_mode:expr, + $from_arrow_primitive_type:ty, + $to_arrow_primitive_type:ty, + $from_data_type:expr, + $to_native_type:ty, + $spark_from_data_type_name:expr, + $spark_to_data_type_name:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::>() + .unwrap(); + let spark_int_literal_suffix = match $from_data_type { + &DataType::Int64 => "L", + &DataType::Int16 => "S", + &DataType::Int8 => "T", + _ => "", + }; + + let output_array = match $eval_mode { + EvalMode::Legacy => cast_array + .iter() + .map(|value| match value { + Some(value) => Ok::< + Option<$to_native_type>, + datafusion_common::DataFusionError, + >(Some(value as $to_native_type)), + _ => Ok(None), + }) + .collect::>>(), + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let res = <$to_native_type>::try_from(value); + if res.is_err() { + Err(cast_overflow( + &(value.to_string() + spark_int_literal_suffix), + $spark_from_data_type_name, + $spark_to_data_type_name, + )) + } else { + Ok::< + Option<$to_native_type>, + datafusion_common::DataFusionError, + >(Some(res.unwrap())) + } + } + _ => Ok(None), + }) + .collect::>>(), + }?; + let result: Result = Ok(Arc::new(output_array) as ArrayRef); + result + }}; +} + +// --- Float/Decimal → Int macros --- + +// When Spark casts to Byte/Short types, it casts to Int first then to Byte/Short. +macro_rules! 
cast_float_to_int16_down { + ( + $array:expr, + $eval_mode:expr, + $src_array_type:ty, + $dest_array_type:ty, + $rust_src_type:ty, + $rust_dest_type:ty, + $src_type_str:expr, + $dest_type_str:expr, + $format_str:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::<$src_array_type>() + .expect(concat!("Expected a ", stringify!($src_array_type))); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let is_overflow = + value.is_nan() || value.abs() as i32 == i32::MAX; + if is_overflow { + return Err(cast_overflow( + &format!($format_str, value).replace("e", "E"), + $src_type_str, + $dest_type_str, + )); + } + let i32_value = value as i32; + <$rust_dest_type>::try_from(i32_value) + .map_err(|_| { + cast_overflow( + &format!($format_str, value).replace("e", "E"), + $src_type_str, + $dest_type_str, + ) + }) + .map(Some) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let i32_value = value as i32; + Ok::, datafusion_common::DataFusionError>( + Some(i32_value as $rust_dest_type), + ) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +macro_rules! 
cast_float_to_int32_up { + ( + $array:expr, + $eval_mode:expr, + $src_array_type:ty, + $dest_array_type:ty, + $rust_src_type:ty, + $rust_dest_type:ty, + $src_type_str:expr, + $dest_type_str:expr, + $max_dest_val:expr, + $format_str:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::<$src_array_type>() + .expect(concat!("Expected a ", stringify!($src_array_type))); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let is_overflow = value.is_nan() + || value.abs() as $rust_dest_type == $max_dest_val; + if is_overflow { + return Err(cast_overflow( + &format!($format_str, value).replace("e", "E"), + $src_type_str, + $dest_type_str, + )); + } + Ok(Some(value as $rust_dest_type)) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => Ok::< + Option<$rust_dest_type>, + datafusion_common::DataFusionError, + >(Some(value as $rust_dest_type)), + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +macro_rules! 
cast_decimal_to_int16_down { + ( + $array:expr, + $eval_mode:expr, + $dest_array_type:ty, + $rust_dest_type:ty, + $dest_type_str:expr, + $precision:expr, + $scale:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::() + .expect("Expected a Decimal128ArrayType"); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let truncated = value / divisor; + let is_overflow = truncated.abs() > i32::MAX.into(); + if is_overflow { + return Err(cast_overflow( + &format!( + "{}BD", + format_decimal_str( + &value.to_string(), + $precision as usize, + $scale + ) + ), + &format!("DECIMAL({},{})", $precision, $scale), + $dest_type_str, + )); + } + let i32_value = truncated as i32; + <$rust_dest_type>::try_from(i32_value) + .map_err(|_| { + cast_overflow( + &format!( + "{}BD", + format_decimal_str( + &value.to_string(), + $precision as usize, + $scale + ) + ), + &format!("DECIMAL({},{})", $precision, $scale), + $dest_type_str, + ) + }) + .map(Some) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let i32_value = (value / divisor) as i32; + Ok::, datafusion_common::DataFusionError>( + Some(i32_value as $rust_dest_type), + ) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +macro_rules! 
cast_decimal_to_int32_up { + ( + $array:expr, + $eval_mode:expr, + $dest_array_type:ty, + $rust_dest_type:ty, + $dest_type_str:expr, + $max_dest_val:expr, + $precision:expr, + $scale:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::() + .expect("Expected a Decimal128ArrayType"); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let truncated = value / divisor; + let is_overflow = truncated.abs() > $max_dest_val.into(); + if is_overflow { + return Err(cast_overflow( + &format!( + "{}BD", + format_decimal_str( + &value.to_string(), + $precision as usize, + $scale + ) + ), + &format!("DECIMAL({},{})", $precision, $scale), + $dest_type_str, + )); + } + Ok(Some(truncated as $rust_dest_type)) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let truncated = value / divisor; + Ok::, datafusion_common::DataFusionError>( + Some(truncated as $rust_dest_type), + ) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +// --- Public functions --- + +pub(crate) fn spark_cast_float64_to_utf8( + from: &dyn Array, + eval_mode: EvalMode, +) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + cast_float_to_string!(from, eval_mode, f64, Float64Array, OffsetSize) +} + +pub(crate) fn spark_cast_float32_to_utf8( + from: &dyn Array, + eval_mode: EvalMode, +) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + cast_float_to_string!(from, eval_mode, f32, Float32Array, OffsetSize) +} + +pub(crate) fn spark_cast_int_to_int( + array: &dyn Array, + eval_mode: EvalMode, + from_type: &DataType, + to_type: &DataType, +) -> Result { + match (from_type, to_type) { + (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!( + array, eval_mode, Int64Type, Int32Type, from_type, i32, 
"BIGINT", "INT" + ), + (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!( + array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT" + ), + (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!( + array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT" + ), + (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!( + array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT" + ), + (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!( + array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT" + ), + (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!( + array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT" + ), + _ => unreachable!( + "{}", + format!("invalid integer type {to_type} in cast from {from_type}") + ), + } +} + +pub(crate) fn spark_cast_nonintegral_numeric_to_integral( + array: &dyn Array, + eval_mode: EvalMode, + from_type: &DataType, + to_type: &DataType, +) -> Result { + match (from_type, to_type) { + (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!( + array, + eval_mode, + Float32Array, + Int8Array, + f32, + i8, + "FLOAT", + "TINYINT", + "{:e}" + ), + (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!( + array, + eval_mode, + Float32Array, + Int16Array, + f32, + i16, + "FLOAT", + "SMALLINT", + "{:e}" + ), + (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!( + array, + eval_mode, + Float32Array, + Int32Array, + f32, + i32, + "FLOAT", + "INT", + i32::MAX, + "{:e}" + ), + (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!( + array, + eval_mode, + Float32Array, + Int64Array, + f32, + i64, + "FLOAT", + "BIGINT", + i64::MAX, + "{:e}" + ), + (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!( + array, + eval_mode, + Float64Array, + Int8Array, + f64, + i8, + "DOUBLE", + "TINYINT", + "{:e}D" + ), + (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!( + 
array, + eval_mode, + Float64Array, + Int16Array, + f64, + i16, + "DOUBLE", + "SMALLINT", + "{:e}D" + ), + (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!( + array, + eval_mode, + Float64Array, + Int32Array, + f64, + i32, + "DOUBLE", + "INT", + i32::MAX, + "{:e}D" + ), + (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!( + array, + eval_mode, + Float64Array, + Int64Array, + f64, + i64, + "DOUBLE", + "BIGINT", + i64::MAX, + "{:e}D" + ), + (DataType::Decimal128(precision, scale), DataType::Int8) => { + cast_decimal_to_int16_down!( + array, eval_mode, Int8Array, i8, "TINYINT", *precision, *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int16) => { + cast_decimal_to_int16_down!( + array, eval_mode, Int16Array, i16, "SMALLINT", *precision, *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int32) => { + cast_decimal_to_int32_up!( + array, + eval_mode, + Int32Array, + i32, + "INT", + i32::MAX, + *precision, + *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int64) => { + cast_decimal_to_int32_up!( + array, + eval_mode, + Int64Array, + i64, + "BIGINT", + i64::MAX, + *precision, + *scale + ) + } + _ => unreachable!( + "{}", + format!( + "invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}" + ) + ), + } +} + +pub(crate) fn spark_cast_decimal_to_boolean(array: &dyn Array) -> Result { + let decimal_array = array.as_primitive::(); + let mut result = BooleanBuilder::with_capacity(decimal_array.len()); + for i in 0..decimal_array.len() { + if decimal_array.is_null(i) { + result.append_null() + } else { + result.append_value(!decimal_array.value(i).is_zero()); + } + } + Ok(Arc::new(result.finish())) +} + +pub(crate) fn cast_int_to_decimal128( + array: &dyn Array, + eval_mode: EvalMode, + from_type: &DataType, + to_type: &DataType, + precision: u8, + scale: i8, +) -> Result { + match (from_type, to_type) { + (DataType::Int8, DataType::Decimal128(_p, _s)) 
=> { + cast_int_to_decimal128_internal::( + array.as_primitive::(), + precision, + scale, + eval_mode, + ) + } + (DataType::Int16, DataType::Decimal128(_p, _s)) => { + cast_int_to_decimal128_internal::( + array.as_primitive::(), + precision, + scale, + eval_mode, + ) + } + (DataType::Int32, DataType::Decimal128(_p, _s)) => { + cast_int_to_decimal128_internal::( + array.as_primitive::(), + precision, + scale, + eval_mode, + ) + } + (DataType::Int64, DataType::Decimal128(_p, _s)) => { + cast_int_to_decimal128_internal::( + array.as_primitive::(), + precision, + scale, + eval_mode, + ) + } + _ => datafusion_common::internal_err!( + "Unsupported cast from datatype: {}", + from_type + ), + } +} + +fn cast_int_to_decimal128_internal( + array: &PrimitiveArray, + precision: u8, + scale: i8, + eval_mode: EvalMode, +) -> Result +where + T: ArrowPrimitiveType, + T::Native: Into, +{ + let mut builder = Decimal128Builder::with_capacity(array.len()); + let multiplier = 10_i128.pow(scale as u32); + + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + let v = array.value(i).into(); + let scaled = v.checked_mul(multiplier); + match scaled { + Some(scaled) => { + if !is_validate_decimal_precision(scaled, precision) { + match eval_mode { + EvalMode::Ansi => { + return Err(numeric_value_out_of_range( + &v.to_string(), + precision, + scale, + )); + } + EvalMode::Try | EvalMode::Legacy => builder.append_null(), + } + } else { + builder.append_value(scaled); + } + } + _ => match eval_mode { + EvalMode::Ansi => { + return Err(numeric_value_out_of_range( + &v.to_string(), + precision, + scale, + )); + } + EvalMode::Legacy | EvalMode::Try => builder.append_null(), + }, + } + } + } + Ok(Arc::new( + builder.with_precision_and_scale(precision, scale)?.finish(), + )) +} + +pub(crate) fn cast_float64_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, +) -> Result { + cast_floating_point_to_decimal128::(array, precision, 
scale, eval_mode) +} + +pub(crate) fn cast_float32_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, +) -> Result { + cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) +} + +fn cast_floating_point_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, +) -> Result +where + ::Native: AsPrimitive, +{ + let input = array.as_any().downcast_ref::>().unwrap(); + let mut cast_array = PrimitiveArray::::builder(input.len()); + + let mul = 10_f64.powi(scale as i32); + + for i in 0..input.len() { + if input.is_null(i) { + cast_array.append_null(); + continue; + } + + let input_value = input.value(i).as_(); + if let Some(v) = (input_value * mul).round().to_i128() + && is_validate_decimal_precision(v, precision) + { + cast_array.append_value(v); + continue; + } + + if eval_mode == EvalMode::Ansi { + return Err(numeric_value_out_of_range( + &input_value.to_string(), + precision, + scale, + )); + } + cast_array.append_null(); + } + + let res = Arc::new( + cast_array + .with_precision_and_scale(precision, scale)? 
+ .finish(), + ) as ArrayRef; + Ok(res) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cast_int64_to_int32_legacy() { + let array: ArrayRef = Arc::new(Int64Array::from(vec![ + Some(1), + Some(i64::MAX), + Some(i64::MIN), + None, + ])); + let result = spark_cast_int_to_int( + &array, + EvalMode::Legacy, + &DataType::Int64, + &DataType::Int32, + ) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1); + assert_eq!(arr.value(1), -1); // truncation + assert_eq!(arr.value(2), 0); // truncation + assert!(arr.is_null(3)); + } + + #[test] + fn test_cast_int64_to_int32_ansi_overflow() { + let array: ArrayRef = Arc::new(Int64Array::from(vec![Some(i64::MAX)])); + let result = spark_cast_int_to_int( + &array, + EvalMode::Ansi, + &DataType::Int64, + &DataType::Int32, + ); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("[CAST_OVERFLOW]")); + } + + #[test] + fn test_cast_decimal_to_boolean() { + let array: ArrayRef = Arc::new( + Decimal128Array::from(vec![Some(100), Some(0), None]) + .with_precision_and_scale(10, 2) + .unwrap(), + ); + let result = spark_cast_decimal_to_boolean(&array).unwrap(); + let bool_arr = result + .as_any() + .downcast_ref::() + .unwrap(); + assert!(bool_arr.value(0)); // 100 != 0 -> true + assert!(!bool_arr.value(1)); // 0 -> false + assert!(bool_arr.is_null(2)); + } + + #[test] + fn test_cast_float_to_string_f64() { + let array: ArrayRef = Arc::new(Float64Array::from(vec![ + Some(42.0), + Some(0.001), + Some(f64::INFINITY), + Some(f64::NEG_INFINITY), + None, + ])); + let result = spark_cast_float64_to_utf8::(&array, EvalMode::Legacy).unwrap(); + let str_arr = result + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(str_arr.value(0), "42.0"); + assert_eq!(str_arr.value(1), "0.001"); + assert_eq!(str_arr.value(2), "Infinity"); + assert_eq!(str_arr.value(3), "-Infinity"); + assert!(str_arr.is_null(4)); + } + + #[test] + fn 
test_cast_float_to_decimal() { + let a: ArrayRef = Arc::new(Float64Array::from(vec![ + Some(42.), + Some(-42.4242415), + Some(42e-314), + Some(0.), + Some(f64::INFINITY), + Some(f64::NAN), + None, + ])); + let b = + cast_floating_point_to_decimal128::(&a, 8, 6, EvalMode::Legacy) + .unwrap(); + assert_eq!(b.len(), a.len()); + let casted = b.as_primitive::(); + assert_eq!(casted.value(0), 42000000); + assert_eq!(casted.value(1), -42424242); + assert_eq!(casted.value(2), 0); + assert_eq!(casted.value(3), 0); + assert!(casted.is_null(4)); // Infinity + assert!(casted.is_null(5)); // NaN + assert!(casted.is_null(6)); // null + } + + #[test] + fn test_cast_float32_to_int_legacy() { + // Float32 -> Int8: truncation, NaN -> 0 + let array: ArrayRef = Arc::new(Float32Array::from(vec![ + Some(1.9_f32), + Some(-1.9_f32), + Some(f32::NAN), + Some(0.0_f32), + None, + ])); + let result = spark_cast_nonintegral_numeric_to_integral( + &array, + EvalMode::Legacy, + &DataType::Float32, + &DataType::Int8, + ) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1); // truncated + assert_eq!(arr.value(1), -1); // truncated + assert_eq!(arr.value(2), 0); // NaN -> 0 in legacy + assert_eq!(arr.value(3), 0); + assert!(arr.is_null(4)); + + // Float32 -> Int16 + let result = spark_cast_nonintegral_numeric_to_integral( + &array, + EvalMode::Legacy, + &DataType::Float32, + &DataType::Int16, + ) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1); + assert_eq!(arr.value(1), -1); + + // Float32 -> Int32 + let result = spark_cast_nonintegral_numeric_to_integral( + &array, + EvalMode::Legacy, + &DataType::Float32, + &DataType::Int32, + ) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1); + assert_eq!(arr.value(1), -1); + } + + #[test] + fn test_cast_float64_to_int_ansi_overflow() { + // NaN -> error + let array: ArrayRef = 
Arc::new(Float64Array::from(vec![Some(f64::NAN)])); + let result = spark_cast_nonintegral_numeric_to_integral( + &array, + EvalMode::Ansi, + &DataType::Float64, + &DataType::Int8, + ); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("[CAST_OVERFLOW]")); + + // Infinity -> error + let array: ArrayRef = Arc::new(Float64Array::from(vec![Some(f64::INFINITY)])); + let result = spark_cast_nonintegral_numeric_to_integral( + &array, + EvalMode::Ansi, + &DataType::Float64, + &DataType::Int16, + ); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("[CAST_OVERFLOW]")); + + // Value exceeding Int32 range + let array: ArrayRef = Arc::new(Float64Array::from(vec![Some(3e10)])); + let result = spark_cast_nonintegral_numeric_to_integral( + &array, + EvalMode::Ansi, + &DataType::Float64, + &DataType::Int32, + ); + assert!(result.is_err()); + + // Value exceeding Int64 range + let array: ArrayRef = Arc::new(Float64Array::from(vec![Some(1e19)])); + let result = spark_cast_nonintegral_numeric_to_integral( + &array, + EvalMode::Ansi, + &DataType::Float64, + &DataType::Int64, + ); + assert!(result.is_err()); + } + + #[test] + fn test_cast_decimal_to_int_legacy() { + // Decimal128(10, 2) -> Int8: truncation + let array: ArrayRef = Arc::new( + Decimal128Array::from(vec![Some(199), Some(-199), Some(0), None]) + .with_precision_and_scale(10, 2) + .unwrap(), + ); + let result = spark_cast_nonintegral_numeric_to_integral( + &array, + EvalMode::Legacy, + &DataType::Decimal128(10, 2), + &DataType::Int8, + ) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1); // 199/100=1 + assert_eq!(arr.value(1), -1); + assert_eq!(arr.value(2), 0); + assert!(arr.is_null(3)); + + // Decimal128(10, 2) -> Int32 + let result = spark_cast_nonintegral_numeric_to_integral( + &array, + EvalMode::Legacy, + &DataType::Decimal128(10, 2), + &DataType::Int32, + ) + .unwrap(); + let arr = 
result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1); + assert_eq!(arr.value(1), -1); + + // Decimal128(10, 2) -> Int64 + let result = spark_cast_nonintegral_numeric_to_integral( + &array, + EvalMode::Legacy, + &DataType::Decimal128(10, 2), + &DataType::Int64, + ) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1); + assert_eq!(arr.value(1), -1); + } + + #[test] + fn test_cast_int_to_decimal_overflow_ansi() { + // Int64 value that exceeds Decimal128(5,0) precision + let array: ArrayRef = Arc::new(Int64Array::from(vec![Some(999999)])); + let result = cast_int_to_decimal128( + &array, + EvalMode::Ansi, + &DataType::Int64, + &DataType::Decimal128(5, 0), + 5, + 0, + ); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("NUMERIC_VALUE_OUT_OF_RANGE") + ); + } + + #[test] + fn test_cast_int_to_decimal_overflow_legacy() { + // Same value in Legacy → null + let array: ArrayRef = Arc::new(Int64Array::from(vec![Some(999999), Some(42)])); + let result = cast_int_to_decimal128( + &array, + EvalMode::Legacy, + &DataType::Int64, + &DataType::Decimal128(5, 0), + 5, + 0, + ) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert!(arr.is_null(0)); // overflow → null + assert_eq!(arr.value(1), 42); // fits + } + + #[test] + fn test_cast_float32_to_string() { + let array: ArrayRef = Arc::new(Float32Array::from(vec![ + Some(1e7_f32), + Some(0.001_f32), + Some(0.0_f32), + Some(f32::INFINITY), + Some(f32::NEG_INFINITY), + Some(f32::NAN), + None, + ])); + let result = spark_cast_float32_to_utf8::(&array, EvalMode::Legacy).unwrap(); + let str_arr = result + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(str_arr.value(0), "1.0E7"); + assert_eq!(str_arr.value(1), "0.001"); + assert_eq!(str_arr.value(2), "0.0"); + assert_eq!(str_arr.value(3), "Infinity"); + assert_eq!(str_arr.value(4), "-Infinity"); + assert_eq!(str_arr.value(5), "NaN"); + 
assert!(str_arr.is_null(6)); + } +} diff --git a/datafusion/spark/src/function/conversion/cast_string.rs b/datafusion/spark/src/function/conversion/cast_string.rs new file mode 100644 index 0000000000000..9c86836d4d630 --- /dev/null +++ b/datafusion/spark/src/function/conversion/cast_string.rs @@ -0,0 +1,1818 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
/// Builds a UTC timestamp array by parsing every value of a string array.
///
/// Arguments:
/// * `$array` - string array to read (`len`, `is_null`, `value` are used).
/// * `$eval_mode` - the `EvalMode` forwarded to the parser.
/// * `$array_type` - arrow primitive timestamp type of the output.
/// * `$cast_method` - parser called as `$cast_method(trimmed, $eval_mode, $tz)`
///   and expected to return `Result<Option<_>>`.
/// * `$tz` - time zone handed to the parser.
///
/// Null inputs and values the parser reports as `Ok(None)` become null.
/// A parser error is propagated with `return Err(..)` from the *enclosing
/// function* when `$eval_mode` is `EvalMode::Ansi`; in every other mode it is
/// turned into null. Previously all parser errors were discarded, so ANSI
/// casts could never fail. The macro must be expanded inside a function that
/// returns a `Result` with a compatible error type.
macro_rules! cast_utf8_to_timestamp {
    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{
        let len = $array.len();
        let mut cast_array =
            PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
        for i in 0..len {
            if $array.is_null(i) {
                cast_array.append_null();
                continue;
            }
            match $cast_method($array.value(i).trim(), $eval_mode, $tz) {
                Ok(Some(cast_value)) => cast_array.append_value(cast_value),
                Ok(None) => cast_array.append_null(),
                // ANSI mode must surface invalid input instead of hiding it.
                Err(e) if $eval_mode == EvalMode::Ansi => return Err(e),
                Err(_) => cast_array.append_null(),
            }
        }
        let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
        result
    }};
}
/// Calendar date and wall-clock time collected while parsing a timestamp
/// string.
///
/// The `Default` value is `0001-01-01T00:00:00.000000`; the `with_*` setters
/// overwrite one component each and return `&mut Self` so calls can be
/// chained on a mutable value.
#[derive(Debug, Clone, PartialEq, Eq)]
struct TimeStampInfo {
    year: i32,
    month: u32,
    day: u32,
    hour: u32,
    minute: u32,
    second: u32,
    /// Fractional second expressed in microseconds.
    microsecond: u32,
}

impl Default for TimeStampInfo {
    /// Returns year 1, month 1, day 1 with every time-of-day field at zero.
    fn default() -> Self {
        TimeStampInfo {
            year: 1,
            month: 1,
            day: 1,
            hour: 0,
            minute: 0,
            second: 0,
            microsecond: 0,
        }
    }
}

impl TimeStampInfo {
    /// Sets the year and returns `self` for chaining.
    fn with_year(&mut self, year: i32) -> &mut Self {
        self.year = year;
        self
    }

    /// Sets the month and returns `self` for chaining.
    fn with_month(&mut self, month: u32) -> &mut Self {
        self.month = month;
        self
    }

    /// Sets the day of month and returns `self` for chaining.
    fn with_day(&mut self, day: u32) -> &mut Self {
        self.day = day;
        self
    }

    /// Sets the hour and returns `self` for chaining.
    fn with_hour(&mut self, hour: u32) -> &mut Self {
        self.hour = hour;
        self
    }

    /// Sets the minute and returns `self` for chaining.
    fn with_minute(&mut self, minute: u32) -> &mut Self {
        self.minute = minute;
        self
    }

    /// Sets the second and returns `self` for chaining.
    fn with_second(&mut self, second: u32) -> &mut Self {
        self.second = second;
        self
    }

    /// Sets the microsecond fraction and returns `self` for chaining.
    fn with_microsecond(&mut self, microsecond: u32) -> &mut Self {
        self.microsecond = microsecond;
        self
    }
}
"Expected string array".to_string(), + ) + })?; + + let mut builder = PrimitiveBuilder::::with_capacity(arr.len()); + + for i in 0..arr.len() { + if arr.is_null(i) { + builder.append_null(); + } else { + let str_value = arr.value(i).trim(); + match parse_string_to_float(str_value) { + Some(v) => builder.append_value(v), + None => { + if eval_mode == EvalMode::Ansi { + return Err(cast_invalid_input( + arr.value(i), + "STRING", + type_name, + )); + } + builder.append_null(); + } + } + } + } + + Ok(Arc::new(builder.finish())) +} + +fn parse_string_to_float(s: &str) -> Option +where + F: FromStr + Float, +{ + if s.eq_ignore_ascii_case("inf") + || s.eq_ignore_ascii_case("+inf") + || s.eq_ignore_ascii_case("infinity") + || s.eq_ignore_ascii_case("+infinity") + { + return Some(F::infinity()); + } + if s.eq_ignore_ascii_case("-inf") || s.eq_ignore_ascii_case("-infinity") { + return Some(F::neg_infinity()); + } + if s.eq_ignore_ascii_case("nan") { + return Some(F::nan()); + } + // Remove D/F suffix if present + let pruned_float_str = + if s.ends_with('d') || s.ends_with('D') || s.ends_with('f') || s.ends_with('F') { + &s[..s.len() - 1] + } else { + s + }; + pruned_float_str.parse::().ok() +} + +// --- String → Boolean --- + +pub(crate) fn spark_cast_utf8_to_boolean( + from: &dyn Array, + eval_mode: EvalMode, +) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + let array = from + .as_any() + .downcast_ref::>() + .unwrap(); + + let output_array = array + .iter() + .map(|value| match value { + Some(value) => match value.to_ascii_lowercase().trim() { + "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)), + "f" | "false" | "n" | "no" | "0" => Ok(Some(false)), + _ if eval_mode == EvalMode::Ansi => { + Err(cast_invalid_input(value, "STRING", "BOOLEAN")) + } + _ => Ok(None), + }, + _ => Ok(None), + }) + .collect::>()?; + + Ok(Arc::new(output_array)) +} + +// --- String → Decimal --- + +pub(crate) fn cast_string_to_decimal( + array: &ArrayRef, + to_type: &DataType, + precision: 
&u8, + scale: &i8, + eval_mode: EvalMode, +) -> Result { + match to_type { + DataType::Decimal128(_, _) => { + cast_string_to_decimal128_impl(array, eval_mode, *precision, *scale) + } + DataType::Decimal256(_, _) => { + cast_string_to_decimal256_impl(array, eval_mode, *precision, *scale) + } + _ => datafusion_common::internal_err!( + "Unexpected type in cast_string_to_decimal: {:?}", + to_type + ), + } +} + +fn cast_string_to_decimal128_impl( + array: &ArrayRef, + eval_mode: EvalMode, + precision: u8, + scale: i8, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion_common::DataFusionError::Internal( + "Expected string array".to_string(), + ) + })?; + + let mut decimal_builder = Decimal128Builder::with_capacity(string_array.len()); + + for i in 0..string_array.len() { + if string_array.is_null(i) { + decimal_builder.append_null(); + } else { + let str_value = string_array.value(i); + match parse_string_to_decimal(str_value, precision, scale) { + Ok(Some(decimal_value)) => { + decimal_builder.append_value(decimal_value); + } + Ok(None) => { + if eval_mode == EvalMode::Ansi { + return Err(cast_invalid_input( + string_array.value(i), + "STRING", + &format!("DECIMAL({precision},{scale})"), + )); + } + decimal_builder.append_null(); + } + Err(e) => { + if eval_mode == EvalMode::Ansi { + return Err(e); + } + decimal_builder.append_null(); + } + } + } + } + + Ok(Arc::new( + decimal_builder + .with_precision_and_scale(precision, scale)? 
+ .finish(), + )) +} + +fn cast_string_to_decimal256_impl( + array: &ArrayRef, + eval_mode: EvalMode, + precision: u8, + scale: i8, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion_common::DataFusionError::Internal( + "Expected string array".to_string(), + ) + })?; + + let mut decimal_builder = + PrimitiveBuilder::::with_capacity(string_array.len()); + + for i in 0..string_array.len() { + if string_array.is_null(i) { + decimal_builder.append_null(); + } else { + let str_value = string_array.value(i); + match parse_string_to_decimal(str_value, precision, scale) { + Ok(Some(decimal_value)) => { + let i256_value = i256::from_i128(decimal_value); + decimal_builder.append_value(i256_value); + } + Ok(None) => { + if eval_mode == EvalMode::Ansi { + return Err(cast_invalid_input( + str_value, + "STRING", + &format!("DECIMAL({precision},{scale})"), + )); + } + decimal_builder.append_null(); + } + Err(e) => { + if eval_mode == EvalMode::Ansi { + return Err(e); + } + decimal_builder.append_null(); + } + } + } + } + + Ok(Arc::new( + decimal_builder + .with_precision_and_scale(precision, scale)? 
+ .finish(), + )) +} + +fn parse_string_to_decimal( + input_str: &str, + precision: u8, + scale: i8, +) -> Result> { + let string_bytes = input_str.as_bytes(); + let mut start = 0; + let mut end = string_bytes.len(); + + while start < end && string_bytes[start].is_ascii_whitespace() { + start += 1; + } + while end > start && string_bytes[end - 1].is_ascii_whitespace() { + end -= 1; + } + + let trimmed = &input_str[start..end]; + + if trimmed.is_empty() { + return Ok(None); + } + + if trimmed.eq_ignore_ascii_case("inf") + || trimmed.eq_ignore_ascii_case("+inf") + || trimmed.eq_ignore_ascii_case("infinity") + || trimmed.eq_ignore_ascii_case("+infinity") + || trimmed.eq_ignore_ascii_case("-inf") + || trimmed.eq_ignore_ascii_case("-infinity") + || trimmed.eq_ignore_ascii_case("nan") + { + return Ok(None); + } + + let (mantissa, exponent) = parse_decimal_str(trimmed, input_str, precision, scale)?; + + if mantissa == 0 { + if exponent < -37 { + return Err(numeric_out_of_range(input_str)); + } + return Ok(Some(0)); + } + + let target_scale = scale as i32; + let scale_adjustment = target_scale - exponent; + + let scaled_value = if scale_adjustment >= 0 { + if scale_adjustment > 38 { + return Ok(None); + } + mantissa.checked_mul(10_i128.pow(scale_adjustment as u32)) + } else { + let abs_scale_adjustment = (-scale_adjustment) as u32; + if abs_scale_adjustment > 38 { + return Ok(Some(0)); + } + + let divisor = 10_i128.pow(abs_scale_adjustment); + let quotient_opt = mantissa.checked_div(divisor); + if quotient_opt.is_none() { + return Ok(None); + } + let quotient = quotient_opt.unwrap(); + let remainder = mantissa % divisor; + + let half_divisor = divisor / 2; + let rounded = if remainder.abs() >= half_divisor { + if mantissa >= 0 { + quotient + 1 + } else { + quotient - 1 + } + } else { + quotient + }; + Some(rounded) + }; + + match scaled_value { + Some(value) => { + if is_validate_decimal_precision(value, precision) { + Ok(Some(value)) + } else { + 
Err(numeric_value_out_of_range(trimmed, precision, scale)) + } + } + None => Err(numeric_value_out_of_range(trimmed, precision, scale)), + } +} + +fn invalid_decimal_cast( + value: &str, + precision: u8, + scale: i8, +) -> datafusion_common::DataFusionError { + cast_invalid_input(value, "STRING", &format!("DECIMAL({precision},{scale})")) +} + +fn parse_decimal_str( + s: &str, + original_str: &str, + precision: u8, + scale: i8, +) -> Result<(i128, i32)> { + if s.is_empty() { + return Err(invalid_decimal_cast(original_str, precision, scale)); + } + + let (mantissa_str, exponent) = + if let Some(e_pos) = s.find(|c| ['e', 'E'].contains(&c)) { + let mantissa_part = &s[..e_pos]; + let exponent_part = &s[e_pos + 1..]; + let exp: i32 = exponent_part + .parse() + .map_err(|_| invalid_decimal_cast(original_str, precision, scale))?; + (mantissa_part, exp) + } else { + (s, 0) + }; + + let negative = mantissa_str.starts_with('-'); + let mantissa_str = if negative || mantissa_str.starts_with('+') { + &mantissa_str[1..] 
+ } else { + mantissa_str + }; + + if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') { + return Err(invalid_decimal_cast(original_str, precision, scale)); + } + + let (integral_part, fractional_part) = match mantissa_str.find('.') { + Some(dot_pos) => { + if mantissa_str[dot_pos + 1..].contains('.') { + return Err(invalid_decimal_cast(original_str, precision, scale)); + } + (&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..]) + } + None => (mantissa_str, ""), + }; + + if integral_part.is_empty() && fractional_part.is_empty() { + return Err(invalid_decimal_cast(original_str, precision, scale)); + } + + if !integral_part.is_empty() && !integral_part.bytes().all(|b| b.is_ascii_digit()) { + return Err(invalid_decimal_cast(original_str, precision, scale)); + } + + if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| b.is_ascii_digit()) + { + return Err(invalid_decimal_cast(original_str, precision, scale)); + } + + let integral_value: i128 = if integral_part.is_empty() { + 0 + } else { + integral_part + .parse() + .map_err(|_| invalid_decimal_cast(original_str, precision, scale))? + }; + + let fractional_scale = fractional_part.len() as i32; + let fractional_value: i128 = if fractional_part.is_empty() { + 0 + } else { + fractional_part + .parse() + .map_err(|_| invalid_decimal_cast(original_str, precision, scale))? 
+ }; + + let mantissa = integral_value + .checked_mul(10_i128.pow(fractional_scale as u32)) + .and_then(|v| v.checked_add(fractional_value)) + .ok_or_else(|| invalid_decimal_cast(original_str, precision, scale))?; + + let final_mantissa = if negative { -mantissa } else { mantissa }; + let final_scale = fractional_scale - exponent; + Ok((final_mantissa, final_scale)) +} + +// --- String → Date --- + +pub(crate) fn cast_string_to_date( + array: &ArrayRef, + to_type: &DataType, + eval_mode: EvalMode, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + if to_type != &DataType::Date32 { + unreachable!("Invalid data type {:?} in cast from string", to_type); + } + + let len = string_array.len(); + let mut cast_array = PrimitiveArray::::builder(len); + + for i in 0..len { + let value = if string_array.is_null(i) { + None + } else { + match date_parser(string_array.value(i), eval_mode) { + Ok(Some(cast_value)) => Some(cast_value), + Ok(None) => None, + Err(e) => return Err(e), + } + }; + + match value { + Some(cast_value) => cast_array.append_value(cast_value), + None => cast_array.append_null(), + } + } + + Ok(Arc::new(cast_array.finish()) as ArrayRef) +} + +fn date_parser(date_str: &str, eval_mode: EvalMode) -> Result> { + fn get_trimmed_start(bytes: &[u8]) -> usize { + let mut start = 0; + while start < bytes.len() && is_whitespace_or_iso_control(bytes[start]) { + start += 1; + } + start + } + + fn get_trimmed_end(start: usize, bytes: &[u8]) -> usize { + let mut end = bytes.len() - 1; + while end > start && is_whitespace_or_iso_control(bytes[end]) { + end -= 1; + } + end + 1 + } + + fn is_whitespace_or_iso_control(byte: u8) -> bool { + byte.is_ascii_whitespace() || byte.is_ascii_control() + } + + fn is_valid_digits(segment: i32, digits: usize) -> bool { + let max_digits_year = 7; + (segment == 0 && digits >= 4 && digits <= max_digits_year) + || (segment != 0 && digits > 0 && digits <= 2) + } + + fn 
return_result(date_str: &str, eval_mode: EvalMode) -> Result> { + if eval_mode == EvalMode::Ansi { + Err(cast_invalid_input(date_str, "STRING", "DATE")) + } else { + Ok(None) + } + } + + if date_str.is_empty() { + return return_result(date_str, eval_mode); + } + + let mut date_segments = [1, 1, 1]; + let mut sign = 1; + let mut current_segment = 0; + let mut current_segment_value = Wrapping(0); + let mut current_segment_digits = 0; + let bytes = date_str.as_bytes(); + + let mut j = get_trimmed_start(bytes); + let str_end_trimmed = get_trimmed_end(j, bytes); + + if j == str_end_trimmed { + return return_result(date_str, eval_mode); + } + + if bytes[j] == b'-' || bytes[j] == b'+' { + sign = if bytes[j] == b'-' { -1 } else { 1 }; + j += 1; + } + + while j < str_end_trimmed + && (current_segment < 3 && !(bytes[j] == b' ' || bytes[j] == b'T')) + { + let b = bytes[j]; + if current_segment < 2 && b == b'-' { + if !is_valid_digits(current_segment, current_segment_digits) { + return return_result(date_str, eval_mode); + } + date_segments[current_segment as usize] = current_segment_value.0; + current_segment_value = Wrapping(0); + current_segment_digits = 0; + current_segment += 1; + } else if !b.is_ascii_digit() { + return return_result(date_str, eval_mode); + } else { + let parsed_value = Wrapping((b - b'0') as i32); + current_segment_value = current_segment_value * Wrapping(10) + parsed_value; + current_segment_digits += 1; + } + j += 1; + } + + if !is_valid_digits(current_segment, current_segment_digits) { + return return_result(date_str, eval_mode); + } + + if current_segment < 2 && j < str_end_trimmed { + return return_result(date_str, eval_mode); + } + + date_segments[current_segment as usize] = current_segment_value.0; + + match NaiveDate::from_ymd_opt( + sign * date_segments[0], + date_segments[1] as u32, + date_segments[2] as u32, + ) { + Some(date) => { + let duration_since_epoch = date + .signed_duration_since(DateTime::UNIX_EPOCH.naive_utc().date()) + 
.num_days(); + Ok(Some(duration_since_epoch.try_into().unwrap_or(i32::MAX))) + } + None => Ok(None), + } +} + +// --- String → Timestamp --- + +pub(crate) fn cast_string_to_timestamp( + array: &ArrayRef, + to_type: &DataType, + eval_mode: EvalMode, + tz: &T, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + let cast_array: ArrayRef = match to_type { + DataType::Timestamp(_, _) => { + cast_utf8_to_timestamp!( + string_array, + eval_mode, + TimestampMicrosecondType, + timestamp_parser, + tz + ) + } + _ => unreachable!("Invalid data type {:?} in cast from string", to_type), + }; + Ok(cast_array) +} + +fn get_timestamp_values( + value: &str, + timestamp_type: &str, + tz: &T, +) -> Result> { + let values: Vec<_> = value.split(['T', '-', ':', '.']).collect(); + let year = values[0].parse::().unwrap_or_default(); + let month = values.get(1).map_or(1, |m| m.parse::().unwrap_or(1)); + let day = values.get(2).map_or(1, |d| d.parse::().unwrap_or(1)); + let hour = values.get(3).map_or(0, |h| h.parse::().unwrap_or(0)); + let minute = values.get(4).map_or(0, |m| m.parse::().unwrap_or(0)); + let second = values.get(5).map_or(0, |s| s.parse::().unwrap_or(0)); + let microsecond = values.get(6).map_or(0, |ms| ms.parse::().unwrap_or(0)); + + let mut timestamp_info = TimeStampInfo::default(); + + let timestamp_info = match timestamp_type { + "year" => timestamp_info.with_year(year), + "month" => timestamp_info.with_year(year).with_month(month), + "day" => timestamp_info + .with_year(year) + .with_month(month) + .with_day(day), + "hour" => timestamp_info + .with_year(year) + .with_month(month) + .with_day(day) + .with_hour(hour), + "minute" => timestamp_info + .with_year(year) + .with_month(month) + .with_day(day) + .with_hour(hour) + .with_minute(minute), + "second" => timestamp_info + .with_year(year) + .with_month(month) + .with_day(day) + .with_hour(hour) + .with_minute(minute) + .with_second(second), + 
"microsecond" => timestamp_info + .with_year(year) + .with_month(month) + .with_day(day) + .with_hour(hour) + .with_minute(minute) + .with_second(second) + .with_microsecond(microsecond), + _ => { + return Err(cast_invalid_input(value, "STRING", "TIMESTAMP")); + } + }; + parse_timestamp_to_micros(timestamp_info, tz) +} + +fn parse_timestamp_to_micros( + timestamp_info: &TimeStampInfo, + tz: &T, +) -> Result> { + let datetime = tz.with_ymd_and_hms( + timestamp_info.year, + timestamp_info.month, + timestamp_info.day, + timestamp_info.hour, + timestamp_info.minute, + timestamp_info.second, + ); + + let tz_datetime = match datetime.single() { + Some(dt) => dt + .with_timezone(tz) + .with_nanosecond(timestamp_info.microsecond * 1000), + None => { + return Err(datafusion_common::DataFusionError::Internal( + "Failed to parse timestamp".to_string(), + )); + } + }; + + let result = match tz_datetime { + Some(dt) => dt.timestamp_micros(), + None => { + return Err(datafusion_common::DataFusionError::Internal( + "Failed to parse timestamp".to_string(), + )); + } + }; + + Ok(Some(result)) +} + +fn parse_str_to_year_timestamp(value: &str, tz: &T) -> Result> { + get_timestamp_values(value, "year", tz) +} + +fn parse_str_to_month_timestamp(value: &str, tz: &T) -> Result> { + get_timestamp_values(value, "month", tz) +} + +fn parse_str_to_day_timestamp(value: &str, tz: &T) -> Result> { + get_timestamp_values(value, "day", tz) +} + +fn parse_str_to_hour_timestamp(value: &str, tz: &T) -> Result> { + get_timestamp_values(value, "hour", tz) +} + +fn parse_str_to_minute_timestamp( + value: &str, + tz: &T, +) -> Result> { + get_timestamp_values(value, "minute", tz) +} + +fn parse_str_to_second_timestamp( + value: &str, + tz: &T, +) -> Result> { + get_timestamp_values(value, "second", tz) +} + +fn parse_str_to_microsecond_timestamp( + value: &str, + tz: &T, +) -> Result> { + get_timestamp_values(value, "microsecond", tz) +} + +type TimestampPattern = ( + &'static LazyLock, + fn(&str, &T) 
-> Result>, +); + +fn timestamp_parser( + value: &str, + eval_mode: EvalMode, + tz: &T, +) -> Result> { + let value = value.trim(); + if value.is_empty() { + return Ok(None); + } + + let patterns: &[TimestampPattern] = &[ + (&RE_YEAR, parse_str_to_year_timestamp), + (&RE_MONTH, parse_str_to_month_timestamp), + (&RE_DAY, parse_str_to_day_timestamp), + (&RE_HOUR, parse_str_to_hour_timestamp), + (&RE_MINUTE, parse_str_to_minute_timestamp), + (&RE_SECOND, parse_str_to_second_timestamp), + (&RE_MICROSECOND, parse_str_to_microsecond_timestamp), + (&RE_TIME_ONLY, parse_str_to_time_only_timestamp), + ]; + + let mut timestamp = None; + + for (pattern, parse_func) in patterns { + if pattern.is_match(value) { + timestamp = parse_func(value, tz)?; + break; + } + } + + if timestamp.is_none() { + return if eval_mode == EvalMode::Ansi { + Err(cast_invalid_input(value, "STRING", "TIMESTAMP")) + } else { + Ok(None) + }; + } + + match timestamp { + Some(ts) => Ok(Some(ts)), + None => Err(datafusion_common::DataFusionError::Internal( + "Failed to parse timestamp".to_string(), + )), + } +} + +fn parse_str_to_time_only_timestamp( + value: &str, + tz: &T, +) -> Result> { + let values: Vec<&str> = value.split('T').collect(); + let time_values: Vec = values[1] + .split(':') + .map(|v| v.parse::().unwrap_or(0)) + .collect(); + + let datetime = tz.from_utc_datetime(&chrono::Utc::now().naive_utc()); + let timestamp = datetime + .with_timezone(tz) + .with_hour(time_values.first().copied().unwrap_or_default()) + .and_then(|dt| dt.with_minute(*time_values.get(1).unwrap_or(&0))) + .and_then(|dt| dt.with_second(*time_values.get(2).unwrap_or(&0))) + .and_then(|dt| dt.with_nanosecond(*time_values.get(3).unwrap_or(&0) * 1_000)) + .map(|dt| dt.timestamp_micros()) + .unwrap_or_default(); + + Ok(Some(timestamp)) +} + +// --- String → Int --- + +pub(crate) fn cast_string_to_int( + to_type: &DataType, + array: &ArrayRef, + eval_mode: EvalMode, +) -> Result { + let string_array = array + .as_any() + 
.downcast_ref::>() + .expect("cast_string_to_int expected a string array"); + + let cast_array: ArrayRef = match (to_type, eval_mode) { + (DataType::Int8, EvalMode::Legacy) => { + cast_utf8_to_int!(string_array, Int8Type, parse_string_to_i8_legacy)? + } + (DataType::Int8, EvalMode::Ansi) => { + cast_utf8_to_int!(string_array, Int8Type, parse_string_to_i8_ansi)? + } + (DataType::Int8, EvalMode::Try) => { + cast_utf8_to_int!(string_array, Int8Type, parse_string_to_i8_try)? + } + (DataType::Int16, EvalMode::Legacy) => { + cast_utf8_to_int!(string_array, Int16Type, parse_string_to_i16_legacy)? + } + (DataType::Int16, EvalMode::Ansi) => { + cast_utf8_to_int!(string_array, Int16Type, parse_string_to_i16_ansi)? + } + (DataType::Int16, EvalMode::Try) => { + cast_utf8_to_int!(string_array, Int16Type, parse_string_to_i16_try)? + } + (DataType::Int32, EvalMode::Legacy) => cast_utf8_to_int!( + string_array, + Int32Type, + |s| do_parse_string_to_int_legacy::(s, i32::MIN) + )?, + (DataType::Int32, EvalMode::Ansi) => cast_utf8_to_int!( + string_array, + Int32Type, + |s| do_parse_string_to_int_ansi::(s, "INT", i32::MIN) + )?, + (DataType::Int32, EvalMode::Try) => { + cast_utf8_to_int!(string_array, Int32Type, |s| do_parse_string_to_int_try::< + i32, + >(s, i32::MIN))? + } + (DataType::Int64, EvalMode::Legacy) => cast_utf8_to_int!( + string_array, + Int64Type, + |s| do_parse_string_to_int_legacy::(s, i64::MIN) + )?, + (DataType::Int64, EvalMode::Ansi) => cast_utf8_to_int!( + string_array, + Int64Type, + |s| do_parse_string_to_int_ansi::(s, "BIGINT", i64::MIN) + )?, + (DataType::Int64, EvalMode::Try) => { + cast_utf8_to_int!(string_array, Int64Type, |s| do_parse_string_to_int_try::< + i64, + >(s, i64::MIN))? 
+ } + (dt, _) => unreachable!( + "{}", + format!("invalid integer type {dt} in cast from string") + ), + }; + Ok(cast_array) +} + +fn finalize_int_result(result: T, negative: bool) -> Option +where + T: num_traits::CheckedNeg + num_traits::Zero + Copy + PartialOrd, +{ + if negative { + Some(result) + } else { + result.checked_neg().filter(|&n| n >= T::zero()) + } +} + +fn do_parse_string_to_int_legacy(str: &str, min_value: T) -> Result> +where + T: num_traits::CheckedNeg + + num_traits::Zero + + Copy + + PartialOrd + + std::ops::Mul + + std::ops::Div + + From, + T: num_traits::CheckedSub, +{ + let trimmed_bytes = str.as_bytes().trim_ascii(); + + let (negative, digits) = match parse_sign(trimmed_bytes) { + Some(result) => result, + None => return Ok(None), + }; + + let mut result: T = T::zero(); + let radix = T::from(10_u8); + let stop_value = min_value / radix; + + let mut iter = digits.iter(); + + for &ch in iter.by_ref() { + if ch == b'.' { + break; + } + + if !ch.is_ascii_digit() { + return Ok(None); + } + + if result < stop_value { + return Ok(None); + } + let v = result * radix; + let digit: T = T::from(ch - b'0'); + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => return Ok(None), + } + } + + for &ch in iter { + if !ch.is_ascii_digit() { + return Ok(None); + } + } + + Ok(finalize_int_result(result, negative)) +} + +fn do_parse_string_to_int_ansi( + str: &str, + type_name: &str, + min_value: T, +) -> Result> +where + T: num_traits::CheckedNeg + + num_traits::Zero + + Copy + + PartialOrd + + std::ops::Mul + + std::ops::Div + + From, + T: num_traits::CheckedSub, +{ + let error = || Err(cast_invalid_input(str, "STRING", type_name)); + + let trimmed_bytes = str.as_bytes().trim_ascii(); + + let (negative, digits) = match parse_sign(trimmed_bytes) { + Some(result) => result, + None => return error(), + }; + + let mut result: T = T::zero(); + let radix = T::from(10_u8); + let stop_value = min_value / radix; + + for &ch in digits { + if 
ch == b'.' || !ch.is_ascii_digit() { + return error(); + } + + if result < stop_value { + return error(); + } + let v = result * radix; + let digit: T = T::from(ch - b'0'); + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => return error(), + } + } + + finalize_int_result(result, negative) + .map(Some) + .ok_or_else(|| cast_invalid_input(str, "STRING", type_name)) +} + +fn do_parse_string_to_int_try(str: &str, min_value: T) -> Result> +where + T: num_traits::CheckedNeg + + num_traits::Zero + + Copy + + PartialOrd + + std::ops::Mul + + std::ops::Div + + From, + T: num_traits::CheckedSub, +{ + let trimmed_bytes = str.as_bytes().trim_ascii(); + + let (negative, digits) = match parse_sign(trimmed_bytes) { + Some(result) => result, + None => return Ok(None), + }; + + let mut result: T = T::zero(); + let radix = T::from(10_u8); + let stop_value = min_value / radix; + + for &ch in digits { + if ch == b'.' || !ch.is_ascii_digit() { + return Ok(None); + } + + if result < stop_value { + return Ok(None); + } + let v = result * radix; + let digit: T = T::from(ch - b'0'); + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => return Ok(None), + } + } + + Ok(finalize_int_result(result, negative)) +} + +fn parse_string_to_i8_legacy(str: &str) -> Result> { + match do_parse_string_to_int_legacy::(str, i32::MIN)? { + Some(v) if v >= i8::MIN as i32 && v <= i8::MAX as i32 => Ok(Some(v as i8)), + _ => Ok(None), + } +} + +fn parse_string_to_i8_ansi(str: &str) -> Result> { + match do_parse_string_to_int_ansi::(str, "TINYINT", i32::MIN)? { + Some(v) if v >= i8::MIN as i32 && v <= i8::MAX as i32 => Ok(Some(v as i8)), + _ => Err(cast_invalid_input(str, "STRING", "TINYINT")), + } +} + +fn parse_string_to_i8_try(str: &str) -> Result> { + match do_parse_string_to_int_try::(str, i32::MIN)? 
/// Splits an optional leading sign off a byte string.
///
/// Returns `None` for empty input. Otherwise returns `(is_negative, rest)`:
///
/// * `-` followed by at least one byte gives `(true, bytes after the sign)`;
/// * `+` followed by at least one byte gives `(false, bytes after the sign)`;
/// * anything else, including a lone `+` or `-`, gives `(false, bytes)`
///   unchanged, leaving it to the caller's digit check to reject the sign.
fn parse_sign(bytes: &[u8]) -> Option<(bool, &[u8])> {
    match bytes {
        [] => None,
        [b'-', digits @ ..] if !digits.is_empty() => Some((true, digits)),
        [b'+', digits @ ..] if !digits.is_empty() => Some((false, digits)),
        _ => Some((false, bytes)),
    }
}
/// Legacy-mode string → boolean: every accepted spelling maps to its value,
/// while unrecognised text and null input both produce a null slot.
#[test]
fn test_cast_string_to_boolean() {
    // (input, expected) pairs; an expected `None` means the slot must be null.
    let cases: [(Option<&str>, Option<bool>); 12] = [
        (Some("true"), Some(true)),
        (Some("false"), Some(false)),
        (Some("t"), Some(true)),
        (Some("f"), Some(false)),
        (Some("yes"), Some(true)),
        (Some("no"), Some(false)),
        (Some("y"), Some(true)),
        (Some("n"), Some(false)),
        (Some("1"), Some(true)),
        (Some("0"), Some(false)),
        (Some("invalid"), None), // "invalid" -> null in Legacy
        (None, None),            // null -> null
    ];

    let inputs: Vec<Option<&str>> = cases.iter().map(|(input, _)| *input).collect();
    let array: ArrayRef = Arc::new(StringArray::from(inputs));

    let result = spark_cast_utf8_to_boolean::<i32>(&array, EvalMode::Legacy).unwrap();
    let bool_arr = result.as_any().downcast_ref::<BooleanArray>().unwrap();

    for (idx, (_, expected)) in cases.iter().enumerate() {
        match expected {
            Some(want) => assert_eq!(bool_arr.value(idx), *want),
            None => assert!(bool_arr.is_null(idx)),
        }
    }
}
/// Checks every supported timestamp granularity (year through microsecond)
/// in UTC for a four-digit year, a year with a leading zero, and a five-digit
/// year.
#[test]
fn test_timestamp_parser() {
    let tz: arrow::array::timezone::Tz = "UTC".parse().unwrap();

    // (input, expected microseconds since the Unix epoch)
    let cases: &[(&str, i64)] = &[
        // 2020 dates — all 7 timestamp format levels
        ("2020", 1577836800000000),
        ("2020-01", 1577836800000000),
        ("2020-01-01", 1577836800000000),
        ("2020-01-01T12", 1577880000000000),
        ("2020-01-01T12:34", 1577882040000000),
        ("2020-01-01T12:34:56", 1577882096000000),
        ("2020-01-01T12:34:56.123456", 1577882096123456),
        // 0100 dates
        ("0100", -59011459200000000),
        ("0100-01", -59011459200000000),
        ("0100-01-01", -59011459200000000),
        ("0100-01-01T12", -59011416000000000),
        ("0100-01-01T12:34", -59011413960000000),
        ("0100-01-01T12:34:56", -59011413904000000),
        ("0100-01-01T12:34:56.123456", -59011413903876544),
        // 10000 dates
        ("10000", 253402300800000000),
        ("10000-01", 253402300800000000),
        ("10000-01-01", 253402300800000000),
        ("10000-01-01T12", 253402344000000000),
        ("10000-01-01T12:34", 253402346040000000),
        ("10000-01-01T12:34:56", 253402346096000000),
        ("10000-01-01T12:34:56.123456", 253402346096123456),
    ];

    for (input, expected) in cases {
        assert_eq!(
            timestamp_parser(input, EvalMode::Legacy, &tz).unwrap(),
            Some(*expected),
            "input: {input}"
        );
    }
}
/// String → Float32: plain numbers, the special values, `D`/`F` suffixes,
/// blank input in Legacy mode, and malformed input in ANSI mode.
#[test]
fn test_cast_string_to_float32() {
    type Float32Array = arrow::array::PrimitiveArray<arrow::datatypes::Float32Type>;

    // Casts `inputs` to Float32 in Legacy mode and returns the typed result.
    fn cast_legacy(inputs: Vec<Option<&str>>) -> Float32Array {
        let array: ArrayRef = Arc::new(StringArray::from(inputs));
        let result =
            cast_string_to_float(&array, &DataType::Float32, EvalMode::Legacy).unwrap();
        result
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap()
            .clone()
    }

    // Normal values
    let arr = cast_legacy(vec![Some("1.5"), Some("-2.75")]);
    assert!((arr.value(0) - 1.5_f32).abs() < f32::EPSILON);
    assert!((arr.value(1) - (-2.75_f32)).abs() < f32::EPSILON);

    // NaN, Infinity
    let arr = cast_legacy(vec![Some("NaN"), Some("Infinity"), Some("-Infinity")]);
    assert!(arr.value(0).is_nan());
    assert!(arr.value(1).is_infinite() && arr.value(1) > 0.0);
    assert!(arr.value(2).is_infinite() && arr.value(2) < 0.0);

    // D/F suffixes
    let arr = cast_legacy(vec![Some("1.5D"), Some("2.5F")]);
    assert!((arr.value(0) - 1.5_f32).abs() < f32::EPSILON);
    assert!((arr.value(1) - 2.5_f32).abs() < f32::EPSILON);

    // Empty string and whitespace → null in Legacy
    let arr = cast_legacy(vec![Some(""), Some(" ")]);
    assert!(arr.is_null(0));
    assert!(arr.is_null(1));

    // Invalid string → error in Ansi
    let array: ArrayRef = Arc::new(StringArray::from(vec![Some("abc")]));
    let result = cast_string_to_float(&array, &DataType::Float32, EvalMode::Ansi);
    assert!(result.is_err());
    assert!(result.unwrap_err().to_string().contains("[CAST_INVALID_INPUT]"));
}
/// String → Decimal128: integer, fractional and scientific input, empty input
/// in Legacy and ANSI modes, and a value that exceeds the target precision.
#[test]
fn test_cast_string_to_decimal128() {
    // Wraps one string in a single-element string array.
    fn single(input: &str) -> ArrayRef {
        Arc::new(StringArray::from(vec![Some(input)]))
    }

    // Integer, decimal and scientific notation, all scaled to Decimal(10, 2):
    // 42 -> 4200, 3.14 -> 314, 1.5E2 = 150.00 -> 15000.
    for (input, expected) in [("42", 4200_i128), ("3.14", 314), ("1.5E2", 15000)] {
        let result = cast_string_to_decimal(
            &single(input),
            &DataType::Decimal128(10, 2),
            &10,
            &2,
            EvalMode::Legacy,
        )
        .unwrap();
        let arr = result.as_any().downcast_ref::<Decimal128Array>().unwrap();
        assert_eq!(arr.value(0), expected);
    }

    // Empty string → null in Legacy
    let result = cast_string_to_decimal(
        &single(""),
        &DataType::Decimal128(10, 2),
        &10,
        &2,
        EvalMode::Legacy,
    )
    .unwrap();
    let arr = result.as_any().downcast_ref::<Decimal128Array>().unwrap();
    assert!(arr.is_null(0));

    // Empty string → error in Ansi
    let result = cast_string_to_decimal(
        &single(""),
        &DataType::Decimal128(10, 2),
        &10,
        &2,
        EvalMode::Ansi,
    );
    assert!(result.is_err());

    // Value exceeding precision → error
    let result = cast_string_to_decimal(
        &single("99999999999"),
        &DataType::Decimal128(5, 0),
        &5,
        &0,
        EvalMode::Ansi,
    );
    assert!(result.is_err());
    assert!(result
        .unwrap_err()
        .to_string()
        .contains("NUMERIC_VALUE_OUT_OF_RANGE"));
}
/// How a Spark cast reacts to input it cannot convert.
///
/// The three variants mirror Spark's three cast behaviors.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub enum EvalMode {
    /// The default before Spark 4.0: errors silently become NULL.
    Legacy,
    /// Strict ANSI SQL mode: invalid input raises an error.
    Ansi,
    /// Same rules as [`EvalMode::Ansi`], except that errors are converted to
    /// NULL instead of failing.
    Try,
}
+#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct SparkCastOptions { + pub eval_mode: EvalMode, + pub timezone: String, +} + +impl SparkCastOptions { + pub fn new(eval_mode: EvalMode, timezone: &str) -> Self { + Self { + eval_mode, + timezone: timezone.to_string(), + } + } +} + +pub(crate) const MICROS_PER_SECOND: i64 = 1_000_000; + +pub(crate) static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); + +/// Parse a Spark SQL type name string into an Arrow DataType. +pub fn parse_spark_datatype(s: &str) -> Result { + let s = s.trim().to_uppercase(); + match s.as_str() { + "BOOLEAN" | "BOOL" => Ok(DataType::Boolean), + "TINYINT" | "BYTE" | "INT1" => Ok(DataType::Int8), + "SMALLINT" | "SHORT" | "INT2" => Ok(DataType::Int16), + "INT" | "INTEGER" | "INT4" => Ok(DataType::Int32), + "BIGINT" | "LONG" | "INT8" => Ok(DataType::Int64), + "FLOAT" | "REAL" => Ok(DataType::Float32), + "DOUBLE" => Ok(DataType::Float64), + "STRING" => Ok(DataType::Utf8), + "BINARY" => Ok(DataType::Binary), + "DATE" => Ok(DataType::Date32), + "TIMESTAMP" => Ok(DataType::Timestamp( + TimeUnit::Microsecond, + Some("UTC".into()), + )), + "TIMESTAMP_NTZ" => Ok(DataType::Timestamp(TimeUnit::Microsecond, None)), + _ if s.starts_with("DECIMAL") + || s.starts_with("DEC") + || s.starts_with("NUMERIC") => + { + parse_decimal_type(&s) + } + _ => exec_err!("Unsupported Spark SQL type: {s}"), + } +} + +fn parse_decimal_type(s: &str) -> Result { + // DECIMAL, DECIMAL(p), DECIMAL(p, s) + if let Some(paren_start) = s.find('(') { + let paren_end = s.find(')').ok_or_else(|| { + DataFusionError::Execution(format!("Invalid decimal type: {s}")) + })?; + let inner = &s[paren_start + 1..paren_end]; + let parts: Vec<&str> = inner.split(',').map(|p| p.trim()).collect(); + match parts.len() { + 1 => { + let precision: u8 = parts[0].parse().map_err(|_| { + DataFusionError::Execution(format!("Invalid precision: {s}")) + })?; + Ok(DataType::Decimal128(precision, 0)) + } + 2 => { + let precision: u8 = 
parts[0].parse().map_err(|_| { + DataFusionError::Execution(format!("Invalid precision: {s}")) + })?; + let scale: i8 = parts[1].parse().map_err(|_| { + DataFusionError::Execution(format!("Invalid scale: {s}")) + })?; + Ok(DataType::Decimal128(precision, scale)) + } + _ => exec_err!("Invalid decimal type: {s}"), + } + } else { + // DECIMAL without parameters defaults to DECIMAL(10, 0) in Spark + Ok(DataType::Decimal128(10, 0)) + } +} + +// --- Error helpers --- + +/// Creates a DataFusionError with Spark's [CAST_INVALID_INPUT] message format. +#[inline] +pub fn cast_invalid_input( + value: &str, + from_type: &str, + to_type: &str, +) -> DataFusionError { + DataFusionError::Execution(format!( + "[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + because it is malformed. Correct the value as per the syntax, or change its target type. \ + Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error." + )) +} + +/// Creates a DataFusionError with Spark's [CAST_OVERFLOW] message format. +#[inline] +pub fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> DataFusionError { + DataFusionError::Execution(format!( + "[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \ + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error." + )) +} + +/// Creates a DataFusionError with Spark's [NUMERIC_VALUE_OUT_OF_RANGE] message format. +#[inline] +pub fn numeric_value_out_of_range( + value: &str, + precision: u8, + scale: i8, +) -> DataFusionError { + DataFusionError::Execution(format!( + "[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}). 
\ + If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead." + )) +} + +/// Creates a DataFusionError for values that exceed the supported numeric range. +#[inline] +pub fn numeric_out_of_range(value: &str) -> DataFusionError { + DataFusionError::Execution(format!( + "[NUMERIC_OUT_OF_SUPPORTED_RANGE] The value {value} cannot be interpreted as a numeric \ + since it has more than 38 digits." + )) +} + +// --- Post-processing --- + +/// A fork & modified version of Arrow's `unary_dyn` +pub(crate) fn unary_dyn( + array: &ArrayRef, + op: F, +) -> std::result::Result +where + T: ArrowPrimitiveType, + F: Fn(T::Native) -> T::Native, +{ + if let Some(d) = array.as_any_dictionary_opt() { + let new_values = unary_dyn::(d.values(), op)?; + return Ok(Arc::new(d.with_values(Arc::new(new_values)))); + } + + match array.as_primitive_opt::() { + Some(a) if PrimitiveArray::::is_compatible(a.data_type()) => { + Ok(Arc::new(unary::( + array.as_any().downcast_ref::>().unwrap(), + op, + ))) + } + _ => Err(ArrowError::NotYetImplemented(format!( + "Cannot perform unary operation of type {} on array of type {}", + T::DATA_TYPE, + array.data_type() + ))), + } +} + +/// Spark-specific post-processing after cast. 
/// Integer division that rounds toward negative infinity.
///
/// Rust's `/` truncates toward zero, so the truncated quotient is lowered by
/// one whenever the division is inexact and the remainder's sign differs from
/// the divisor's.
fn div_floor(a: i64, b: i64) -> i64 {
    let quotient = a / b;
    let remainder = a % b;
    let signs_differ = (remainder < 0) != (b < 0);
    if remainder != 0 && signs_differ {
        quotient - 1
    } else {
        quotient
    }
}
+pub(crate) fn array_with_timezone( + array: ArrayRef, + timezone: &str, + to_type: Option<&DataType>, +) -> std::result::Result { + match array.data_type() { + DataType::Timestamp(_, None) => { + // TimestampNTZ: no timezone info + match to_type { + Some(DataType::Utf8) | Some(DataType::Date32) => Ok(array), + Some(DataType::Timestamp(_, Some(_))) => { + // Convert NTZ to TZ by adding the session timezone + timestamp_ntz_to_timestamp(&array, timezone, Some(timezone)) + } + _ => Ok(array), + } + } + DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => { + if !timezone.is_empty() { + let ts_array = array.as_primitive::(); + let array_with_tz = ts_array.clone().with_timezone(timezone.to_string()); + let array = Arc::new(array_with_tz) as ArrayRef; + match to_type { + Some(DataType::Utf8) | Some(DataType::Date32) => { + pre_timestamp_cast(&array, timezone) + } + _ => Ok(array), + } + } else { + Ok(array) + } + } + _ => Ok(array), + } +} + +/// Convert a timestamp without timezone to one with timezone. +fn timestamp_ntz_to_timestamp( + array: &ArrayRef, + _from_tz: &str, + to_tz: Option<&str>, +) -> std::result::Result { + let ts_array = array.as_primitive::(); + let result = ts_array + .clone() + .with_timezone_opt(to_tz.map(|s| Arc::from(s) as Arc)); + Ok(Arc::new(result) as ArrayRef) +} + +/// Pre-process timestamp for cast to string/date by setting UTC timezone. +fn pre_timestamp_cast( + array: &ArrayRef, + _timezone: &str, +) -> std::result::Result { + let ts_array = array.as_primitive::(); + let result = ts_array.clone().with_timezone("UTC"); + Ok(Arc::new(result) as ArrayRef) +} + +/// Format a decimal value string with proper decimal point placement. 
/// Formats the unscaled digits of a decimal (`value_str`, optionally prefixed
/// with `-`) as a decimal string with the point placed according to `scale`.
///
/// At most `precision` digits are kept from the front of the input before
/// the point is placed.
///
/// * `scale == 0`: the digits are returned as they are.
/// * `scale < 0`: `|scale|` zeros are appended on the right.
/// * `scale > 0` with more digits than `scale`: the point is inserted `scale`
///   digits from the right.
/// * otherwise: the digits are left-padded with zeros to `scale` places after
///   a leading `0.`.
pub(crate) fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
    let (sign, rest) = match value_str.strip_prefix('-') {
        Some(stripped) => ("-", stripped),
        None => ("", value_str),
    };
    // Keep at most `precision` digits, plus the sign when there is one.
    let bound = precision.min(rest.len()) + sign.len();
    let value_str = &value_str[0..bound];

    if scale == 0 {
        value_str.to_string()
    } else if scale < 0 {
        // Left-align within the wider field so the zero fill lands on the right.
        let padding = value_str.len() + scale.unsigned_abs() as usize;
        format!("{value_str:0<padding$}")
    } else if rest.len() > scale as usize {
        // The decimal point falls inside the digit string.
        let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
        format!("{whole}.{decimal}")
    } else {
        // Too few digits: zero-fill between the point and the digits.
        format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
    }
}
DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())) + ); + assert_eq!( + parse_spark_datatype("TIMESTAMP_NTZ").unwrap(), + DataType::Timestamp(TimeUnit::Microsecond, None) + ); + } + + #[test] + fn test_parse_spark_datatype_decimal() { + assert_eq!( + parse_spark_datatype("DECIMAL").unwrap(), + DataType::Decimal128(10, 0) + ); + assert_eq!( + parse_spark_datatype("DECIMAL(18)").unwrap(), + DataType::Decimal128(18, 0) + ); + assert_eq!( + parse_spark_datatype("DECIMAL(10, 2)").unwrap(), + DataType::Decimal128(10, 2) + ); + assert_eq!( + parse_spark_datatype("DEC(38, 18)").unwrap(), + DataType::Decimal128(38, 18) + ); + assert_eq!( + parse_spark_datatype("NUMERIC(20, 5)").unwrap(), + DataType::Decimal128(20, 5) + ); + } + + #[test] + fn test_parse_spark_datatype_case_insensitive() { + assert_eq!(parse_spark_datatype("int").unwrap(), DataType::Int32); + assert_eq!(parse_spark_datatype(" INT ").unwrap(), DataType::Int32); + assert_eq!(parse_spark_datatype("Double").unwrap(), DataType::Float64); + } + + #[test] + fn test_parse_spark_datatype_invalid() { + assert!(parse_spark_datatype("UNKNOWN").is_err()); + assert!(parse_spark_datatype("").is_err()); + } + + #[test] + fn test_eval_mode_eq() { + assert_eq!(EvalMode::Legacy, EvalMode::Legacy); + assert_ne!(EvalMode::Legacy, EvalMode::Ansi); + assert_ne!(EvalMode::Ansi, EvalMode::Try); + } + + #[test] + fn test_div_floor() { + assert_eq!(div_floor(7, 2), 3); + assert_eq!(div_floor(-7, 2), -4); + assert_eq!(div_floor(7, -2), -4); + assert_eq!(div_floor(-7, -2), 3); + assert_eq!(div_floor(6, 2), 3); + assert_eq!(div_floor(0, 1), 0); + } + + #[test] + fn test_format_decimal_str() { + assert_eq!(format_decimal_str("12345", 5, 2), "123.45"); + assert_eq!(format_decimal_str("-1", 1, 3), "-0.001"); + assert_eq!(format_decimal_str("42", 2, 0), "42"); + } +} diff --git a/datafusion/spark/src/function/conversion/mod.rs b/datafusion/spark/src/function/conversion/mod.rs index a87df9a2c87a0..eba9b5460bc51 100644 --- 
a/datafusion/spark/src/function/conversion/mod.rs +++ b/datafusion/spark/src/function/conversion/mod.rs @@ -15,11 +15,36 @@ // specific language governing permissions and limitations // under the License. +pub mod cast; +pub mod cast_boolean; +pub mod cast_complex; +pub mod cast_datetime; +pub mod cast_numeric; +pub mod cast_string; +pub mod cast_utils; + use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; use std::sync::Arc; -pub mod expr_fn {} +make_udf_function!(cast::SparkCast, spark_cast); +make_udf_function!(cast::SparkTryCast, spark_try_cast); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!(( + spark_cast, + "Casts expr to the target type using Spark-compatible semantics.", + arg1 arg2 + )); + export_functions!(( + spark_try_cast, + "Casts expr to the target type using Spark TRY_CAST semantics (returns NULL on error).", + arg1 arg2 + )); +} pub fn functions() -> Vec> { - vec![] + vec![spark_cast(), spark_try_cast()] }