Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/spark_expressions_support.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
- [ ] sequence
- [ ] shuffle
- [ ] slice
- [ ] sort_array
- [x] sort_array

### bitwise_funcs

Expand Down
37 changes: 17 additions & 20 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[ArrayMin] -> CometArrayMin,
classOf[ArrayRemove] -> CometArrayRemove,
classOf[ArrayRepeat] -> CometArrayRepeat,
classOf[SortArray] -> CometSortArray,
classOf[ArraysOverlap] -> CometArraysOverlap,
classOf[ArrayUnion] -> CometArrayUnion,
classOf[CreateArray] -> CometCreateArray,
Expand Down Expand Up @@ -778,30 +779,26 @@ object QueryPlanSerde extends Logging with CometExprShim {
* TODO: Include SparkSQL's [[YearMonthIntervalType]] and [[DayTimeIntervalType]]
*/
// scalastyle:on
def supportedSortType(op: SparkPlan, sortOrder: Seq[SortOrder]): Boolean = {
def canRank(dt: DataType): Boolean = {
dt match {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: DecimalType =>
true
case _: DateType | _: TimestampType | _: TimestampNTZType =>
true
case _: BooleanType | _: BinaryType | _: StringType => true
case _ => false
}
/**
 * Returns true when values of the given scalar [[DataType]] can be sorted natively by Comet.
 *
 * Only flat (non-nested) types are accepted here; nested types (arrays, maps, structs)
 * are validated by the callers, which recurse into their element/value types before
 * delegating the leaf types to this method.
 *
 * @param dt the element type to check
 * @return true if Comet can sort values of this type
 */
def supportedScalarSortElementType(dt: DataType): Boolean = {
  dt match {
    // All supported scalar types share the same outcome, so they are combined
    // into a single alternative pattern instead of separate duplicated arms.
    case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
        _: DoubleType | _: DecimalType | _: DateType | _: TimestampType |
        _: TimestampNTZType | _: BooleanType | _: BinaryType | _: StringType =>
      true
    case _ =>
      false
  }
}

def supportedSortType(op: SparkPlan, sortOrder: Seq[SortOrder]): Boolean = {
if (sortOrder.length == 1) {
val canSort = sortOrder.head.dataType match {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: DecimalType =>
true
case _: DateType | _: TimestampType | _: TimestampNTZType =>
true
case _: BooleanType | _: BinaryType | _: StringType => true
case ArrayType(elementType, _) => canRank(elementType)
case MapType(_, valueType, _) => canRank(valueType)
case _ => false
case ArrayType(elementType, _) => supportedScalarSortElementType(elementType)
case MapType(_, valueType, _) => supportedScalarSortElementType(valueType)
case _ => supportedScalarSortElementType(sortOrder.head.dataType)
}
if (!canSort) {
withInfo(op, s"Sort on single column of type ${sortOrder.head.dataType} is not supported")
Expand Down
77 changes: 76 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ package org.apache.comet.serde

import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size}
import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde._
import org.apache.comet.shims.CometExprShim
Expand Down Expand Up @@ -200,6 +201,80 @@ object CometArrayDistinct extends CometExpressionSerde[ArrayDistinct] {
}
}

object CometSortArray extends CometExpressionSerde[SortArray] {

  /** Returns true if the given type (recursively) contains a float or double field. */
  private def hasFloatingPoint(dt: DataType): Boolean = {
    dt match {
      case FloatType | DoubleType => true
      case ArrayType(inner, _) => hasFloatingPoint(inner)
      case StructType(fields) => fields.exists(field => hasFloatingPoint(field.dataType))
      case MapType(keyType, valueType, _) =>
        hasFloatingPoint(keyType) || hasFloatingPoint(valueType)
      case _ => false
    }
  }

  /**
   * Returns true if an array with the given element type can be sorted by Comet.
   *
   * DataFusion's array_sort compares nested arrays through Arrow's rank kernel,
   * which does not support Struct or Null child values. Consequently Struct and
   * Null elements are accepted only at the top level of the array, never once we
   * have descended inside a nested array (array<array<struct<...>>> or
   * array<array<null>> would fail at runtime).
   */
  private def isSortableElement(dt: DataType, insideArray: Boolean = false): Boolean = {
    dt match {
      case _: NullType if !insideArray =>
        true
      case ArrayType(inner, _) =>
        isSortableElement(inner, insideArray = true)
      case StructType(fields) if !insideArray =>
        fields.forall(field => isSortableElement(field.dataType))
      case other =>
        supportedScalarSortElementType(other)
    }
  }

  override def getSupportLevel(expr: SortArray): SupportLevel = {
    val elemType = expr.base.dataType.asInstanceOf[ArrayType].elementType
    val strictFp = CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get()

    if (!isSortableElement(elemType)) {
      Unsupported(Some(s"Sort on array element type $elemType is not supported"))
    } else if (strictFp && hasFloatingPoint(elemType)) {
      // Floating-point ordering (NaN / -0.0 handling) differs between Spark and
      // DataFusion, so this is only allowed when strict mode is off.
      Incompatible(
        Some(
          "Sorting on floating-point is not 100% compatible with Spark, and Comet is running " +
            s"with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
            s"${CometConf.COMPAT_GUIDE}"))
    } else {
      Compatible()
    }
  }

  override def convert(
      expr: SortArray,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val baseProto = exprToProtoInternal(expr.base, inputs, binding)

    // Spark's sort_array places nulls first when ascending and nulls last when
    // descending; the same ordering is forwarded to DataFusion's array_sort.
    val (directionProto, nullOrderProto) = expr.ascendingOrder match {
      case Literal(ascending: Boolean, BooleanType) =>
        val (direction, nullOrder) =
          if (ascending) ("ASC", "NULLS FIRST") else ("DESC", "NULLS LAST")
        (
          exprToProtoInternal(Literal(direction), inputs, binding),
          exprToProtoInternal(Literal(nullOrder), inputs, binding))
      case unsupported =>
        // A non-literal ordering cannot be translated; record the reason and let
        // the None arguments below cause the serialization to be abandoned.
        withInfo(expr, s"ascendingOrder must be a boolean literal: $unsupported")
        (None, None)
    }

    val sortedProto =
      scalarFunctionExprToProto("array_sort", baseProto, directionProto, nullOrderProto)
    optExprWithInfo(sortedProto, expr, expr.children: _*)
  }
}

object CometArrayIntersect extends CometExpressionSerde[ArrayIntersect] {

override def getSupportLevel(expr: ArrayIntersect): SupportLevel = Incompatible(None)
Expand Down
Loading
Loading