Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.paimon.partition.PartitionPredicate
import org.apache.paimon.predicate.{FullTextSearch, Predicate, TopN, VectorSearch}
import org.apache.paimon.table.{BucketMode, FileStoreTable, InnerTable}
import org.apache.paimon.table.source.{DataSplit, Split}
import org.apache.paimon.types.RowType

import org.apache.spark.sql.connector.expressions._
import org.apache.spark.sql.connector.read.SupportsReportPartitioning
Expand All @@ -37,7 +38,8 @@ case class PaimonScan(
override val pushedTopN: Option[TopN],
override val pushedVectorSearch: Option[VectorSearch],
override val pushedFullTextSearch: Option[FullTextSearch] = None,
bucketedScanDisabled: Boolean = false)
bucketedScanDisabled: Boolean = false,
variantProjections: Map[String, RowType] = Map.empty)
extends PaimonBaseScan(table)
with SupportsReportPartitioning {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@

package org.apache.paimon.spark.sql

import org.apache.paimon.spark.PaimonSparkTestBase

import org.apache.spark.SparkConf
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GetStructField
import org.apache.spark.sql.catalyst.expressions.variant.VariantGet
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.types.{StructType, VariantType}

class VariantTest extends VariantTestBase {
override protected def sparkConf: SparkConf = {
Expand All @@ -31,3 +38,97 @@ class VariantInferShreddingTest extends VariantTestBase {
super.sparkConf.set("spark.paimon.variant.inferShreddingSchema", "true")
}
}

/**
 * Spark 4-specific plan-shape tests for the PushDownVariantExtract optimizer rule.
 *
 * The rewrite under test operates at the Catalyst expression level: VariantGet(col,
 * Literal(path), targetType) → GetStructField(col, ordinal)
 *
 * As evidence that Parquet column pruning will apply at read time, the tests also check
 * that the variant column's output type in the scan relation flips from VariantType to a
 * StructType once the pushdown has fired.
 */
class VariantPushDownPlanTest extends PaimonSparkTestBase {

  override protected def sparkConf: SparkConf =
    super.sparkConf.set("spark.paimon.variant.inferShreddingSchema", "true")

  // Explicit 3-field shredding schema used across all tests in this class,
  // assembled from the per-field JSON snippets via mkString(start, sep, end).
  private val shreddedSchema3: String =
    Seq(
      """{"name":"age","type":"INT"}""",
      """{"name":"city","type":"STRING"}""",
      """{"name":"score","type":"DOUBLE"}""").mkString(
      """{"type":"ROW","fields":[{"name":"v","type":{"type":"ROW","fields":[""",
      ",",
      """]}}]}""")

  // Creates table T (id INT, v VARIANT) carrying the explicit shredding schema.
  private def createShreddedTable(): Unit = {
    sql(s"""
           |CREATE TABLE T (id INT, v VARIANT)
           |TBLPROPERTIES ('parquet.variant.shreddingSchema' = '$shreddedSchema3')
           |""".stripMargin)
  }

  // Returns the output attributes of the first DataSourceV2ScanRelation in the optimized plan.
  private def scanRelationOutput(query: String) =
    sql(query).queryExecution.optimizedPlan
      .collectFirst {
        case r: org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation => r
      }
      .get
      .output

  test("Paimon Variant: VariantGet is replaced by GetStructField in optimized plan") {
    createShreddedTable()
    sql("""
          |INSERT INTO T VALUES
          | (1, parse_json('{"age":26,"city":"Beijing","score":9.5}')),
          | (2, parse_json('{"age":27,"city":"Hangzhou","score":8.0}'))
          |""".stripMargin)

    val query =
      "SELECT variant_get(v, '$.age', 'int'), variant_get(v, '$.score', 'double') FROM T"
    checkAnswer(sql(query), Seq(Row(26, 9.5d), Row(27, 8.0d)))

    val project =
      sql(query).queryExecution.optimizedPlan.collectFirst { case p: Project => p }.get

    // After pushdown, no VariantGet should survive anywhere in the top-level project list.
    val survivingVariantGets =
      project.projectList.flatMap(_.collect { case vg: VariantGet => vg })
    assert(
      survivingVariantGets.isEmpty,
      "VariantGet should have been replaced by GetStructField after PushDownVariantExtract")

    // GetStructField nodes must now be present in its place.
    val structFieldAccesses =
      project.projectList.flatMap(_.collect { case g: GetStructField => g })
    assert(
      structFieldAccesses.nonEmpty,
      "expected GetStructField to appear in the optimized project list")
  }

  test("Paimon Variant: scan output type changes from VariantType to StructType after pushdown") {
    createShreddedTable()
    sql("INSERT INTO T VALUES (1, parse_json('{\"age\":26,\"city\":\"Beijing\",\"score\":9.5}'))")

    // Without pushdown (direct reference): the variant column stays VariantType in the scan.
    val plainVariantCol = scanRelationOutput("SELECT v FROM T").find(_.name == "v").get
    assert(
      plainVariantCol.dataType == VariantType,
      "without pushdown, the scan output type for 'v' should remain VariantType")

    // With pushdown (only VariantGet accesses): the variant column becomes a StructType
    // whose fields correspond only to the accessed sub-columns (age + score, not city).
    val pushedQuery =
      "SELECT variant_get(v, '$.age', 'int'), variant_get(v, '$.score', 'double') FROM T"
    val shreddedVariantCol = scanRelationOutput(pushedQuery).find(_.name == "v").get
    assert(
      shreddedVariantCol.dataType.isInstanceOf[StructType],
      "after pushdown, the scan output type for 'v' should be StructType (shredded sub-schema)")
    assert(
      shreddedVariantCol.dataType.asInstanceOf[StructType].length == 2,
      "the projected StructType should contain exactly 2 fields (age, score), not all 3")
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.paimon.predicate.{FullTextSearch, Predicate, TopN, VectorSearc
import org.apache.paimon.spark.commands.BucketExpression.quote
import org.apache.paimon.table.{BucketMode, FileStoreTable, InnerTable}
import org.apache.paimon.table.source.{DataSplit, Split}
import org.apache.paimon.types.RowType

import org.apache.spark.sql.PaimonUtils.fieldReference
import org.apache.spark.sql.connector.expressions._
Expand All @@ -43,11 +44,14 @@ case class PaimonScan(
override val pushedTopN: Option[TopN],
override val pushedVectorSearch: Option[VectorSearch],
override val pushedFullTextSearch: Option[FullTextSearch] = None,
bucketedScanDisabled: Boolean = false)
bucketedScanDisabled: Boolean = false,
variantProjections: Map[String, RowType] = Map.empty)
extends PaimonBaseScan(table)
with SupportsReportPartitioning
with SupportsReportOrdering {

override protected def variantProjectionMap: Map[String, RowType] = variantProjections

def disableBucketedScan(): PaimonScan = {
copy(bucketedScanDisabled = true)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import org.apache.paimon.spark.execution.{OldCompatibleStrategy, PaimonStrategy}
import org.apache.paimon.spark.execution.adaptive.DisableUnnecessaryPaimonBucketedScan

import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.paimon.shims.SparkShimLoader

/** Spark session extension to extends the syntax and adds the rules. */
Expand Down Expand Up @@ -66,6 +68,26 @@ class PaimonSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
// optimization rules
extensions.injectOptimizerRule(_ => OptimizeMetadataOnlyDeleteFromPaimonTable)
extensions.injectOptimizerRule(_ => MergePaimonScalarSubqueries)
SparkShimLoader.shim.variantExtractRule().foreach {
rule =>
// PushDownVariantExtract must run AFTER V2ScanRelationPushDown converts
// DataSourceV2Relation to DataSourceV2ScanRelation in the "Early Filter and Projection
// Push-Down" batch. injectOptimizerRule places rules in the "Operator Optimization" batch
// (part of super.defaultBatches), which runs BEFORE the scan push-down. The only batch
// that runs after scan building is "User Provided Optimizers", populated via
// experimentalMethods.extraOptimizations. We register there via a side effect and return
// a no-op placeholder for the injectOptimizerRule slot.
extensions.injectOptimizerRule {
session =>
if (!session.experimental.extraOptimizations.exists(_ eq rule)) {
session.experimental.extraOptimizations =
session.experimental.extraOptimizations :+ rule
}
new Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan
}
}
}

// planner extensions
extensions.injectPlannerStrategy(spark => PaimonStrategy(spark))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,31 @@ trait BaseScan extends Scan with SupportsReportStatistics with Logging {
}
}

private[paimon] val (readTableRowType, metadataFields) = {
/** Hook for subclasses to provide variant column projections (colName -> VariantRowType). */
protected def variantProjectionMap: Map[String, RowType] = Map.empty

// Splits the required schema into table columns vs. metadata-only columns, prunes the
// table side down to a Paimon RowType, then narrows any variant columns according to
// variantProjectionMap. Declared lazy — presumably so that a subclass constructor
// (which supplies variantProjectionMap) completes before first evaluation; TODO confirm.
private[paimon] lazy val (readTableRowType, metadataFields) = {
// Fail fast on any metadata column name the scan does not recognize.
requiredSchema.fields.foreach(f => checkMetadataColumn(f.name))
val (_requiredTableFields, _metadataFields) =
requiredSchema.fields.partition(field => tableRowType.containsField(field.name))
val _readTableRowType =
SparkTypeUtils.prunePaimonRowType(StructType(_requiredTableFields), tableRowType)
// Narrow shredded variant columns to only the sub-fields the query actually extracts.
val _finalReadType = applyVariantProjections(_readTableRowType, variantProjectionMap)
(_finalReadType, _metadataFields)
}

/**
 * Narrows each variant column that has an entry in `projections` down to its projected
 * sub-schema (the shredded sub-columns that the query actually reads).
 *
 * @param rowType the pruned read schema of the table
 * @param projections variant column name -> projected RowType; empty map is a no-op
 * @return `rowType` with matching fields re-typed, or `rowType` unchanged when empty
 */
private def applyVariantProjections(
rowType: RowType,
projections: Map[String, RowType]): RowType = {
  // Expression-oriented form: the original used an early `return`, which is
  // discouraged in Scala; an if/else expression is equivalent and idiomatic.
  if (projections.isEmpty) {
    rowType
  } else {
    val rewritten = rowType.getFields.asScala.map {
      field =>
        // Keep the field as-is unless a variant projection overrides its type.
        projections.get(field.name()).fold(field)(t => field.newType(t))
    }
    rowType.copy(rewritten.asJava)
  }
}

private def checkMetadataColumn(fieldName: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ trait SparkShim {
notMatchedBySourceActions: Seq[MergeAction],
withSchemaEvolution: Boolean): MergeIntoTable

// Hook for variant_get pushdown: Spark 4 shims supply the extraction rule,
// while Spark 3 shims keep this default and report no rule.
def variantExtractRule(): Option[Rule[LogicalPlan]] = Option.empty

// for variant
def toPaimonVariant(o: Object): Variant

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -982,4 +982,82 @@ abstract class VariantTestBase extends PaimonSparkTestBase {
Seq(Row(2, 2))
)
}

// ---- variant_get pushdown tests ----
// Column v is declared with a 3-field shredding schema (age, city, score). When a query
// touches only a subset of those fields via variant_get, the PushDownVariantExtract rule
// should make the scan project only the needed sub-columns.
// The schema string is assembled from per-field JSON snippets via mkString(start, sep, end).
private val shreddedSchema3 =
  Seq(
    """{"name":"age","type":"INT"}""",
    """{"name":"city","type":"STRING"}""",
    """{"name":"score","type":"DOUBLE"}""").mkString(
    """{"type":"ROW","fields":[{"name":"v","type":{"type":"ROW","fields":[""",
    ",",
    """]}}]}""")

test("Paimon Variant: variant_get pushdown reduces projected sub-columns for shredded variant") {
  sql(s"""
         |CREATE TABLE T (id INT, v VARIANT)
         |TBLPROPERTIES ('parquet.variant.shreddingSchema' = '$shreddedSchema3')
         |""".stripMargin)
  sql("""
        |INSERT INTO T VALUES
        | (1, parse_json('{"age":26,"city":"Beijing","score":9.5}')),
        | (2, parse_json('{"age":27,"city":"Hangzhou","score":8.0}'))
        |""".stripMargin)

  // Only 2 of the 3 shredded fields are read (age + score); city is skipped on purpose.
  val query =
    "SELECT variant_get(v, '$.age', 'int'), variant_get(v, '$.score', 'double') FROM T ORDER BY id"
  checkAnswer(sql(query), Seq(Row(26, 9.5d), Row(27, 8.0d)))

  // Performance assertion: the scan should carry a 2-field projection for 'v', not all 3.
  // Fewer projected sub-fields mean fewer Parquet column reads at runtime.
  val pushedScan = getPaimonScan(query)
  assert(
    pushedScan.variantProjections.contains("v"),
    "PushDownVariantExtract should have populated variantProjections for column 'v'")
  val projectedFieldCount = pushedScan.variantProjections("v").getFieldCount
  assert(
    projectedFieldCount == 2,
    s"expected 2 projected sub-fields (age, score) out of 3 total, got $projectedFieldCount")
}

test("Paimon Variant: variant_get pushdown does not fire when variant column is read directly") {
  sql(s"""
         |CREATE TABLE T (id INT, v VARIANT)
         |TBLPROPERTIES ('parquet.variant.shreddingSchema' = '$shreddedSchema3')
         |""".stripMargin)
  sql("INSERT INTO T VALUES (1, parse_json('{\"age\":26,\"city\":\"Beijing\",\"score\":9.5}'))")

  // Selecting the whole variant alongside an extraction must disable pushdown for it.
  val query = "SELECT v, variant_get(v, '$.age', 'int') FROM T"
  val expected = sql("SELECT parse_json('{\"age\":26,\"city\":\"Beijing\",\"score\":9.5}'), 26")
  checkAnswer(sql(query), expected)

  assert(
    getPaimonScan(query).variantProjections.isEmpty,
    "pushdown must NOT fire when the variant column itself is projected")
}

test("Paimon Variant: variant_get pushdown deduplicates repeated access to the same path") {
  sql(s"""
         |CREATE TABLE T (id INT, v VARIANT)
         |TBLPROPERTIES ('parquet.variant.shreddingSchema' = '$shreddedSchema3')
         |""".stripMargin)
  sql("INSERT INTO T VALUES (1, parse_json('{\"age\":26,\"city\":\"Beijing\",\"score\":9.5}'))")

  // '$.age' is extracted twice; the projection must still hold a single sub-column.
  val query =
    "SELECT variant_get(v, '$.age', 'int') AS a1, variant_get(v, '$.age', 'int') AS a2 FROM T"
  checkAnswer(sql(query), Seq(Row(26, 26)))

  val projections = getPaimonScan(query).variantProjections
  assert(projections.contains("v"))
  assert(
    projections("v").getFieldCount == 1,
    "two accesses to the same path should deduplicate to a single projected sub-column")
}
}
Loading
Loading