diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index f627b0c465..6f399fa9fd 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -149,7 +149,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[Ascii] -> CometScalarFunction("ascii"),
     classOf[BitLength] -> CometScalarFunction("bit_length"),
     classOf[Chr] -> CometScalarFunction("char"),
-    classOf[ConcatWs] -> CometScalarFunction("concat_ws"),
+    classOf[ConcatWs] -> CometConcatWs,
     classOf[Concat] -> CometConcat,
     classOf[Contains] -> CometScalarFunction("contains"),
     classOf[EndsWith] -> CometScalarFunction("ends_with"),
diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala
index db60709007..871efd3702 100644
--- a/spark/src/main/scala/org/apache/comet/serde/strings.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde
 
 import java.util.Locale
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper}
 import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -199,6 +199,27 @@ object CometConcat extends CometScalarFunction[Concat]("concat") {
   }
 }
 
+object CometConcatWs extends CometExpressionSerde[ConcatWs] {
+
+  override def convert(expr: ConcatWs, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
+    expr.children.headOption match {
+      // Match Spark behavior: when the separator is NULL, the result of concat_ws is NULL.
+      case Some(Literal(null, _)) =>
+        val nullLiteral = Literal.create(null, expr.dataType)
+        exprToProtoInternal(nullLiteral, inputs, binding)
+
+      case _ if expr.children.forall(_.foldable) =>
+        // Fall back to Spark for all-literal args so ConstantFolding can handle it
+        withInfo(expr, "all arguments are foldable")
+        None
+
+      case _ =>
+        // For all other cases, use the generic scalar function implementation.
+        CometScalarFunction[ConcatWs]("concat_ws").convert(expr, inputs, binding)
+    }
+  }
+}
+
 object CometLike extends CometExpressionSerde[Like] {
 
   override def convert(expr: Like, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
diff --git a/spark/src/test/resources/sql-tests/expressions/string/concat_ws.sql b/spark/src/test/resources/sql-tests/expressions/string/concat_ws.sql
index 4a3df68965..81ccfb0f36 100644
--- a/spark/src/test/resources/sql-tests/expressions/string/concat_ws.sql
+++ b/spark/src/test/resources/sql-tests/expressions/string/concat_ws.sql
@@ -42,6 +42,6 @@ INSERT INTO names VALUES(1, 'James', 'B', 'Taylor'), (2, 'Smith', 'C', 'Davis'),
 query
 SELECT concat_ws(' ', first_name, middle_initial, last_name) FROM names
 
--- literal + literal + literal
-query ignore(https://github.com/apache/datafusion-comet/issues/3339)
+-- literal + literal + literal (falls back to Spark when all args are foldable)
+query spark_answer_only
 SELECT concat_ws(',', 'hello', 'world'), concat_ws(',', '', ''), concat_ws(',', NULL, 'b', 'c'), concat_ws(NULL, 'a', 'b')