Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions native/spark-expr/src/conversion_funcs/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ use crate::conversion_funcs::boolean::{
cast_boolean_to_decimal, cast_boolean_to_timestamp, is_df_cast_from_bool_spark_compatible,
};
use crate::conversion_funcs::numeric::{
cast_decimal_to_timestamp, cast_float32_to_decimal128, cast_float64_to_decimal128,
cast_float_to_timestamp, cast_int_to_decimal128, cast_int_to_timestamp,
is_df_cast_from_decimal_spark_compatible, is_df_cast_from_float_spark_compatible,
is_df_cast_from_int_spark_compatible, spark_cast_decimal_to_boolean,
spark_cast_float32_to_utf8, spark_cast_float64_to_utf8, spark_cast_int_to_int,
spark_cast_nonintegral_numeric_to_integral,
cast_decimal128_to_utf8, cast_decimal_to_timestamp, cast_float32_to_decimal128,
cast_float64_to_decimal128, cast_float_to_timestamp, cast_int_to_decimal128,
cast_int_to_timestamp, is_df_cast_from_decimal_spark_compatible,
is_df_cast_from_float_spark_compatible, is_df_cast_from_int_spark_compatible,
spark_cast_decimal_to_boolean, spark_cast_float32_to_utf8, spark_cast_float64_to_utf8,
spark_cast_int_to_int, spark_cast_nonintegral_numeric_to_integral,
};
use crate::conversion_funcs::string::{
cast_string_to_date, cast_string_to_decimal, cast_string_to_float, cast_string_to_int,
Expand Down Expand Up @@ -378,6 +378,12 @@ pub(crate) fn cast_array(
spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, &from_type, to_type)
}
(Decimal128(_p, _s), Boolean) => spark_cast_decimal_to_boolean(&array),
// Spark LEGACY cast uses Java BigDecimal.toString() which produces scientific notation
// when adjusted_exponent < -6 (e.g. "0E-18" for zero with scale=18).
// TRY and ANSI use plain notation ("0.000000000000000000") so DataFusion handles those.
(Decimal128(_, scale), Utf8) if eval_mode == EvalMode::Legacy => {
cast_decimal128_to_utf8(&array, *scale)
}
(Utf8View, Utf8) => Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?),
(Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?),
(Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
Expand Down
89 changes: 87 additions & 2 deletions native/spark-expr/src/conversion_funcs/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::{EvalMode, SparkError, SparkResult};
use arrow::array::{
Array, ArrayRef, AsArray, BooleanBuilder, Decimal128Array, Decimal128Builder, Float32Array,
Float64Array, GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array,
OffsetSizeTrait, PrimitiveArray, TimestampMicrosecondBuilder,
OffsetSizeTrait, PrimitiveArray, StringArray, TimestampMicrosecondBuilder,
};
use arrow::datatypes::{
i256, is_validate_decimal_precision, ArrowPrimitiveType, DataType, Decimal128Type, Float32Type,
Expand Down Expand Up @@ -71,7 +71,11 @@ pub(crate) fn is_df_cast_from_decimal_spark_compatible(to_type: &DataType) -> bo
| DataType::Float64
| DataType::Decimal128(_, _)
| DataType::Decimal256(_, _)
| DataType::Utf8 // note that there can be formatting differences
// DataFusion's Decimal128→Utf8 cast uses plain notation (toPlainString semantics),
// matching Spark's TRY and ANSI modes. LEGACY mode is handled by a separate match
// arm in cast_array that applies Java BigDecimal.toString() (scientific notation
// for values where adjusted_exponent < -6, e.g. "0E-18" for zero with scale=18).
| DataType::Utf8
)
}

Expand Down Expand Up @@ -569,6 +573,62 @@ pub(crate) fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -
}
}

/// Casts a Decimal128 array to string using Java's BigDecimal.toString() semantics,
/// which is Spark's LEGACY eval mode behavior. Plain notation when scale >= 0 and
/// adjusted_exponent >= -6, otherwise scientific notation (e.g. "0E-18" for zero
/// with scale=18, since adjusted_exponent = -18 + 0 = -18 < -6).
///
/// TRY and ANSI modes produce plain notation via DataFusion's cast instead.
pub(crate) fn cast_decimal128_to_utf8(array: &ArrayRef, scale: i8) -> SparkResult<ArrayRef> {
let decimal_array = array
.as_any()
.downcast_ref::<Decimal128Array>()
.expect("Expected a Decimal128Array");
let output: StringArray = decimal_array
.iter()
.map(|opt_val| opt_val.map(|unscaled| decimal128_to_java_string(unscaled, scale)))
.collect();
Ok(Arc::new(output))
}

/// Formats a Decimal128 unscaled value as a string matching Java's BigDecimal.toString():
/// - Plain notation when scale >= 0 and adjusted_exponent >= -6
/// - Scientific notation otherwise
///
/// adjusted_exponent = -scale + (numDigits - 1)
fn decimal128_to_java_string(unscaled: i128, scale: i8) -> String {
    let sign = if unscaled < 0 { "-" } else { "" };
    // unsigned_abs is well-defined for i128::MIN, unlike abs() which would overflow.
    let coeff = unscaled.unsigned_abs().to_string();
    let num_digits = coeff.len() as i64;
    // Java: adjusted = -scale + (precision - 1), where precision is the digit count.
    let adj_exp = -(scale as i64) + (num_digits - 1);

    if scale >= 0 && adj_exp >= -6 {
        // Plain notation.
        let scale_u = scale as usize;
        let num_digits_u = num_digits as usize;
        if scale_u == 0 {
            format!("{sign}{coeff}")
        } else if num_digits_u > scale_u {
            // Enough digits for an integer part: split coefficient at the decimal point.
            let (int_part, frac_part) = coeff.split_at(num_digits_u - scale_u);
            format!("{sign}{int_part}.{frac_part}")
        } else {
            // Fewer digits than the scale: pad with leading zeros, e.g. 1/scale=6 → "0.000001".
            let leading = scale_u - num_digits_u;
            format!("{sign}0.{}{coeff}", "0".repeat(leading))
        }
    } else {
        // Scientific notation: one leading digit, optional fractional remainder, exponent.
        // Build the mantissa in place rather than cloning the coefficient string.
        let mut mantissa = coeff;
        if mantissa.len() > 1 {
            mantissa.insert(1, '.');
        }
        // Java prints an explicit '+' only for positive exponents ("1.23E+4" vs "1E-7").
        if adj_exp > 0 {
            format!("{sign}{mantissa}E+{adj_exp}")
        } else {
            format!("{sign}{mantissa}E{adj_exp}")
        }
    }
}

pub(crate) fn spark_cast_float64_to_utf8<OffsetSize>(
from: &dyn Array,
_eval_mode: EvalMode,
Expand Down Expand Up @@ -1310,4 +1370,29 @@ mod tests {
let f64_inf: ArrayRef = Arc::new(Float64Array::from(vec![Some(f64::INFINITY)]));
assert!(cast_float_to_timestamp(&f64_inf, tz, EvalMode::Ansi).is_err());
}

#[test]
fn test_decimal128_to_java_string() {
    // (unscaled, scale, expected Java BigDecimal.toString() output)
    let cases: &[(i128, i8, &str)] = &[
        // scale >= 0, adj_exp >= -6 → plain notation
        (0, 0, "0"),
        (0, 2, "0.00"),
        (12345, 2, "123.45"),
        (-12345, 2, "-123.45"),
        (1, 2, "0.01"),
        (42, 0, "42"),
        (-42, 0, "-42"),
        (1, 6, "0.000001"), // adj_exp = -6 (boundary)
        // scale >= 0, adj_exp < -6 → scientific notation (Spark LEGACY mode)
        (0, 18, "0E-18"), // adj_exp = -18
        (0, 7, "0E-7"),   // adj_exp = -7
        (1, 7, "1E-7"),
        (1, 18, "1E-18"),
        // scale < 0 → scientific notation
        (0, -2, "0E+2"),
        (1, -2, "1E+2"),
        (123, -2, "1.23E+4"),
        (-123, -2, "-1.23E+4"),
    ];
    for &(unscaled, scale, expected) in cases {
        assert_eq!(decimal128_to_java_string(unscaled, scale), expected);
    }
}
}
16 changes: 11 additions & 5 deletions spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,18 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
"There can be differences in precision. " +
"For example, the input \"1.4E-45\" will produce 1.0E-45 " +
"instead of 1.4E-45"))
case d: DecimalType if d.scale < 0 =>
// Negative-scale decimals require spark.sql.legacy.allowNegativeScaleOfDecimal=true.
// When that config is enabled, Spark formats them using Java BigDecimal.toString()
// which produces scientific notation (e.g. "1.23E+4"). Comet matches this behavior.
// When the config is disabled, negative-scale decimals cannot be created in Spark,
// so we mark this as incompatible to avoid native execution on unexpected inputs.
val allowNegativeScale = SQLConf.get
.getConfString("spark.sql.legacy.allowNegativeScaleOfDecimal", "false")
.toBoolean
if (allowNegativeScale) Compatible() else Incompatible()
case _: DecimalType =>
// https://github.com/apache/datafusion-comet/issues/1068
Compatible(
Some(
"There can be formatting differences in some case due to Spark using " +
"scientific notation where Comet does not"))
Compatible()
case DataTypes.BinaryType =>
Compatible()
case StructType(fields) =>
Expand Down
148 changes: 139 additions & 9 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, D

import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.rules.CometScanTypeChecker
import org.apache.comet.serde.Compatible
import org.apache.comet.serde.{Compatible, Incompatible}

class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {

Expand Down Expand Up @@ -641,6 +641,73 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateDecimalsPrecision10Scale2(), DataTypes.StringType)
}

// Precision 38 / scale 18 is the widest Decimal128 shape; with scale=18, zero values
// have adjusted_exponent = -18 < -6, so LEGACY mode must render them in scientific
// notation ("0E-18") — the case the native cast_decimal128_to_utf8 path handles.
test("cast DecimalType(38,18) to StringType") {
  castTest(generateDecimalsPrecision38Scale18(), DataTypes.StringType)
}

test("cast DecimalType with negative scale to StringType") {
  // Negative-scale decimals are a legacy Spark feature gated on
  // spark.sql.legacy.allowNegativeScaleOfDecimal=true. Spark LEGACY cast uses Java's
  // BigDecimal.toString() which produces scientific notation for negative-scale values
  // (e.g. 12300 stored as Decimal(7,-2) with unscaled=123 → "1.23E+4").
  // CometCast.canCastToString checks that config and returns Incompatible when it is
  // false.
  //
  // Parquet does not support negative-scale decimals, so the data is kept in-memory
  // rather than going through the usual castTest parquet round-trip.

  // With the config enabled, also enable localTableScan so Comet can take over the
  // full plan and execute the cast natively on the in-memory data.
  withSQLConf(
    "spark.sql.legacy.allowNegativeScaleOfDecimal" -> "true",
    "spark.comet.exec.localTableScan.enabled" -> "true") {
    // scale = -2: values are multiples of 100; unscaled = value / 100.
    val dfNeg2 = Seq(
      Some(BigDecimal("0")),
      Some(BigDecimal("100")),
      Some(BigDecimal("12300")),
      Some(BigDecimal("-99900")),
      Some(BigDecimal("9999900")),
      None)
      .toDF("b")
      .withColumn("a", col("b").cast(DecimalType(7, -2)))
      .drop("b")
      .select(col("a").cast(DataTypes.StringType).as("result"))
    checkSparkAnswerAndOperator(dfNeg2)

    // scale = -4: values are multiples of 10000.
    val dfNeg4 = Seq(
      Some(BigDecimal("0")),
      Some(BigDecimal("10000")),
      Some(BigDecimal("120000")),
      Some(BigDecimal("-9990000")),
      None)
      .toDF("b")
      .withColumn("a", col("b").cast(DecimalType(7, -4)))
      .drop("b")
      .select(col("a").cast(DataTypes.StringType).as("result"))
    checkSparkAnswerAndOperator(dfNeg4)
  }

  // With the config disabled (default): the SQL parser rejects negative scale, so
  // negative-scale decimals cannot be created through normal SQL paths.
  // CometCast.isSupported returns Incompatible for this case, ensuring Comet does
  // not attempt native execution if such a value ever reaches the planner.
  // Note: DecimalType(7, -2) must be constructed while the config is true, because
  // the constructor itself checks the config and throws if negative scale is
  // disallowed — hence the two-step var dance below.
  var negScaleType: DecimalType = null
  withSQLConf("spark.sql.legacy.allowNegativeScaleOfDecimal" -> "true") {
    negScaleType = DecimalType(7, -2)
  }
  withSQLConf("spark.sql.legacy.allowNegativeScaleOfDecimal" -> "false") {
    assert(
      CometCast.isSupported(
        negScaleType,
        DataTypes.StringType,
        None,
        CometEvalMode.LEGACY) == Incompatible())
  }
}

test("cast DecimalType(10,2) to TimestampType") {
  // NOTE(review): presumably the decimal is interpreted as (fractional) seconds since
  // the epoch, mirroring Spark's numeric→timestamp cast — confirm against
  // cast_decimal_to_timestamp in the native code.
  castTest(generateDecimalsPrecision10Scale2(), DataTypes.TimestampType)
}
Expand Down Expand Up @@ -1173,6 +1240,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

test("cast DateType to StringType") {
  // generateDates() covers: 1970-2027 sampled monthly, DST transition dates, and edge
  // cases including "999-01-01" (year < 1000, zero-padded to "0999-01-01") and
  // "12345-01-01" (year > 9999, no truncation). Date→String is timezone-independent,
  // so no per-timezone loop is needed here (unlike the TimestampType→String test).
  castTest(generateDates(), DataTypes.StringType)
}

Expand Down Expand Up @@ -1247,7 +1317,41 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

test("cast TimestampType to StringType") {
  // UTC baseline — also exercises fractional-second trailing-zero stripping
  // and pre-epoch values via generateTimestamps().
  withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
    castTest(generateTimestamps(), DataTypes.StringType)
  }
  // Spark formats timestamps in the session timezone without a tz suffix.
  // pre_timestamp_cast shifts the UTC value by the session tz offset before
  // passing to DataFusion, so DST-sensitive timezones must also be correct.
  // NOTE(review): compatibleTimezones is defined elsewhere in this suite — it is
  // assumed to include at least one DST-observing zone; verify.
  compatibleTimezones.foreach { tz =>
    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
      castTest(generateTimestamps(), DataTypes.StringType)
    }
  }
}

test("cast TimestampType to StringType - ancient timestamps") {
  // Pre-1900 timestamps cannot go through Parquet (INT96 rejects them) so we create
  // the data in-memory via microseconds-since-epoch cast to TimestampType.
  // checkSparkAnswer (rather than checkSparkAnswerAndOperator) only verifies result
  // parity with Spark; it does not require Comet native execution for the whole plan.
  withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
    // Epoch-micros for a few three-digit-year dates:
    // 0100-03-01 00:00:00 UTC = -59,006,361,600,000,000 µs from epoch
    // 0500-06-15 12:30:00 UTC = -46,374,377,400,000,000 µs from epoch
    // 0999-12-31 23:59:59 UTC = -30,610,224,001,000,000 µs from epoch
    val ancientMicros = Seq(
      -59006361600000000L, // 0100-03-01
      -46374377400000000L, // 0500-06-15
      -30610224001000000L
    ) // 0999-12-31
    ancientMicros
      .toDF("micros")
      .selectExpr("CAST(micros AS TIMESTAMP) AS a")
      .createOrReplaceTempView("ancient_ts")
    checkSparkAnswer(spark.sql("SELECT CAST(a AS STRING) FROM ancient_ts"))
    checkSparkAnswer(spark.sql("SELECT CAST(a AS BIGINT) FROM ancient_ts"))
  }
}

test("cast TimestampType to DateType") {
Expand Down Expand Up @@ -1465,6 +1569,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
BigDecimal("-2147483647.123123123"),
BigDecimal("-123456.789"),
BigDecimal("0.00000000000"),
// Small-magnitude non-zero: adj_exp = -9 + 0 = -9 < -6, so LEGACY produces
// scientific notation "1E-9" / "1.000000000E-9" rather than plain "0.000000001".
BigDecimal("0.000000001"),
BigDecimal("-0.000000001"),
BigDecimal("123456.789"),
// Int Max
BigDecimal("2147483647.123123123"),
Expand Down Expand Up @@ -1550,12 +1658,29 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

private def generateTimestamps(): DataFrame = {
val values =
Seq(
"2024-01-01T12:34:56.123456",
"2024-01-01T01:00:00Z",
"9999-12-31T01:00:00-02:00",
"2024-12-31T01:00:00+02:00")
val values = Seq(
// post-epoch with microseconds
"2024-01-01T12:34:56.123456",
// UTC, no fractional seconds (output has no decimal point)
"2024-01-01T01:00:00Z",
// year 9999 boundary
"9999-12-31T01:00:00-02:00",
// positive UTC offset
"2024-12-31T01:00:00+02:00",
// pre-epoch
"1960-01-01T00:00:00Z",
"1900-06-15T10:30:00Z",
// last microsecond before epoch
"1969-12-31T23:59:59.999999",
// fractional-second trailing-zero stripping
// .100000 → ".1", .123000 → ".123", .001000 → ".001", .000001 → ".000001"
"2024-06-01T00:00:00.100000",
"2024-06-01T00:00:00.123000",
"2024-06-01T00:00:00.001000",
"2024-06-01T00:00:00.000001",
// DST transition moments (America/New_York spring-forward / fall-back in UTC)
"2024-03-10T07:00:00Z",
"2024-11-03T06:00:00Z")
withNulls(values)
.toDF("str")
.withColumn("a", col("str").cast(DataTypes.TimestampType))
Expand Down Expand Up @@ -1763,7 +1888,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {

// Writes `df` to a uniquely-named Parquet file under `tempDir` and reads it back,
// so cast tests exercise the Parquet scan path rather than an in-memory relation.
private def roundtripParquet(df: DataFrame, tempDir: File): DataFrame = {
  val path = new File(tempDir, s"castTest_${System.currentTimeMillis()}.parquet").toString
  // CORRECTED mode writes timestamps as proleptic Gregorian without rebase.
  // Required because generateTimestamps() includes pre-1900 values (e.g. 1900-06-15)
  // which trigger INT96's default EXCEPTION mode when written with certain timezones.
  withSQLConf("spark.sql.parquet.int96RebaseModeInWrite" -> "CORRECTED") {
    df.write.mode(SaveMode.Overwrite).parquet(path)
  }
  spark.read.parquet(path)
}
}
Loading