diff --git a/Cargo.lock b/Cargo.lock index 5ab0b8c84a563..92e89a13bc959 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1745,7 +1745,7 @@ dependencies = [ [[package]] name = "datafusion" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "arrow-schema", @@ -1817,7 +1817,7 @@ dependencies = [ [[package]] name = "datafusion-benchmarks" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "clap", @@ -1842,7 +1842,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "async-trait", @@ -1865,7 +1865,7 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "async-trait", @@ -1887,7 +1887,7 @@ dependencies = [ [[package]] name = "datafusion-cli" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "async-trait", @@ -1918,7 +1918,7 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "52.1.0" +version = "52.0.0" dependencies = [ "ahash", "apache-avro", @@ -1945,7 +1945,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "52.1.0" +version = "52.0.0" dependencies = [ "futures", "log", @@ -1954,7 +1954,7 @@ dependencies = [ [[package]] name = "datafusion-datasource" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "async-compression", @@ -1989,7 +1989,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-arrow" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "arrow-ipc", @@ -2012,7 +2012,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-avro" -version = "52.1.0" +version = "52.0.0" dependencies = [ "apache-avro", "arrow", @@ -2031,7 +2031,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-csv" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "async-trait", @@ -2052,7 +2052,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-json" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "async-trait", @@ -2072,7 +2072,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-parquet" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "async-trait", @@ -2101,11 +2101,11 @@ dependencies = [ [[package]] name = "datafusion-doc" -version = "52.1.0" +version = "52.0.0" [[package]] name = "datafusion-examples" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "arrow-flight", @@ -2129,6 +2129,7 @@ dependencies = [ "object_store", "prost", "rand 0.9.2", + "serde", "serde_json", "strum 0.27.2", "strum_macros 0.27.2", @@ -2144,7 +2145,7 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "async-trait", @@ -2165,7 +2166,7 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "async-trait", @@ -2189,7 +2190,7 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "datafusion-common", @@ -2200,7 +2201,7 @@ dependencies = [ [[package]] name = "datafusion-ffi" -version = "52.1.0" +version = "52.0.0" dependencies = [ "abi_stable", "arrow", @@ -2234,7 +2235,7 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "arrow-buffer", @@ -2267,7 +2268,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "52.1.0" +version = "52.0.0" dependencies = [ "ahash", "arrow", @@ -2288,7 +2289,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" -version = "52.1.0" +version = "52.0.0" dependencies = [ "ahash", "arrow", @@ -2301,7 +2302,7 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "arrow-ord", @@ -2324,7 +2325,7 @@ dependencies = [ [[package]] name = "datafusion-functions-table" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "async-trait", @@ -2338,7 +2339,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "datafusion-common", @@ -2354,7 +2355,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "52.1.0" +version = "52.0.0" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -2362,7 +2363,7 @@ dependencies = [ [[package]] name = "datafusion-macros" -version = "52.1.0" +version = "52.0.0" dependencies = [ "datafusion-doc", "quote", @@ -2371,7 +2372,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "async-trait", @@ -2398,7 +2399,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "52.1.0" +version = "52.0.0" dependencies = [ "ahash", "arrow", @@ -2425,7 +2426,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-adapter" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "datafusion-common", @@ -2438,7 +2439,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" -version = "52.1.0" +version = "52.0.0" dependencies = [ "ahash", "arrow", @@ -2453,7 +2454,7 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "datafusion-common", @@ -2473,7 +2474,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "52.1.0" +version = "52.0.0" dependencies = [ "ahash", "arrow", @@ -2509,7 +2510,7 @@ dependencies = [ [[package]] name = "datafusion-proto" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "async-trait", @@ -2539,6 +2540,7 @@ dependencies = [ "pbjson", "pretty_assertions", "prost", + "rand 0.9.2", "serde", "serde_json", "tokio", @@ -2546,7 +2548,7 @@ dependencies = [ [[package]] name = "datafusion-proto-common" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "datafusion-common", @@ -2558,7 +2560,7 @@ dependencies = [ [[package]] name = "datafusion-pruning" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "datafusion-common", @@ -2576,7 +2578,7 @@ dependencies = [ [[package]] name = "datafusion-session" -version = "52.1.0" +version = "52.0.0" dependencies = [ "async-trait", "datafusion-common", @@ -2588,7 +2590,7 @@ dependencies = [ [[package]] name = "datafusion-spark" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "bigdecimal", @@ -2610,7 +2612,7 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "bigdecimal", @@ -2636,7 +2638,7 @@ dependencies = [ [[package]] name = "datafusion-sqllogictest" -version = "52.1.0" +version = "52.0.0" dependencies = [ "arrow", "async-trait", @@ -2667,7 +2669,7 @@ dependencies = [ [[package]] name = "datafusion-substrait" -version = "52.1.0" +version = "52.0.0" dependencies = [ "async-recursion", "async-trait", @@ -2689,7 +2691,7 @@ dependencies = [ [[package]] name = "datafusion-wasmtest" -version = "52.1.0" +version = "52.0.0" dependencies = [ "chrono", "console_error_panic_hook", diff --git a/Cargo.toml b/Cargo.toml index 6424f512cc3df..7aef936189bc8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,7 +79,7 @@ repository = "https://github.com/apache/datafusion" # Define Minimum Supported Rust Version (MSRV) rust-version = "1.88.0" # Define DataFusion version -version = "52.1.0" +version = "52.0.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -112,43 +112,43 @@ chrono = { version = "0.4.42", default-features = false } criterion = "0.8" ctor = "0.6.3" dashmap = "6.0.1" -datafusion = { path = "datafusion/core", version = "52.1.0", default-features = false } -datafusion-catalog = { path = "datafusion/catalog", version = "52.1.0" } -datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "52.1.0" } -datafusion-common = { path = "datafusion/common", version = "52.1.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "52.1.0" } -datafusion-datasource = { path = "datafusion/datasource", version = "52.1.0", default-features = false } -datafusion-datasource-arrow = { path = "datafusion/datasource-arrow", version = "52.1.0", default-features = false } -datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "52.1.0", default-features = false } -datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "52.1.0", default-features = false } -datafusion-datasource-json = { path = "datafusion/datasource-json", version = "52.1.0", default-features = false } -datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "52.1.0", default-features = false } -datafusion-doc = { path = "datafusion/doc", version = "52.1.0" } -datafusion-execution = { path = "datafusion/execution", version = "52.1.0", default-features = false } -datafusion-expr = { path = "datafusion/expr", version = "52.1.0", default-features = false } -datafusion-expr-common = { path = "datafusion/expr-common", version = "52.1.0" } -datafusion-ffi = { path = "datafusion/ffi", version = "52.1.0" } -datafusion-functions = { path = "datafusion/functions", version = "52.1.0" } -datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "52.1.0" } -datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "52.1.0" } -datafusion-functions-nested = { path = "datafusion/functions-nested", version = "52.1.0", default-features = false } -datafusion-functions-table = { path = "datafusion/functions-table", version = "52.1.0" } -datafusion-functions-window = { path = "datafusion/functions-window", version = "52.1.0" } -datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "52.1.0" } -datafusion-macros = { path = "datafusion/macros", version = "52.1.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "52.1.0", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "52.1.0", default-features = false } -datafusion-physical-expr-adapter = { path = "datafusion/physical-expr-adapter", version = "52.1.0", default-features = false } -datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "52.1.0", default-features = false } -datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "52.1.0" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "52.1.0" } -datafusion-proto = { path = "datafusion/proto", version = "52.1.0" } -datafusion-proto-common = { path = "datafusion/proto-common", version = "52.1.0" } -datafusion-pruning = { path = "datafusion/pruning", version = "52.1.0" } -datafusion-session = { path = "datafusion/session", version = "52.1.0" } -datafusion-spark = { path = "datafusion/spark", version = "52.1.0" } -datafusion-sql = { path = "datafusion/sql", version = "52.1.0" } -datafusion-substrait = { path = "datafusion/substrait", version = "52.1.0" } +datafusion = { path = "datafusion/core", version = "52.0.0", default-features = false } +datafusion-catalog = { path = "datafusion/catalog", version = "52.0.0" } +datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "52.0.0" } +datafusion-common = { path = "datafusion/common", version = "52.0.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "52.0.0" } +datafusion-datasource = { path = "datafusion/datasource", version = "52.0.0", default-features = false } +datafusion-datasource-arrow = { path = "datafusion/datasource-arrow", version = "52.0.0", default-features = false } +datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "52.0.0", default-features = false } +datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "52.0.0", default-features = false } +datafusion-datasource-json = { path = "datafusion/datasource-json", version = "52.0.0", default-features = false } +datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "52.0.0", default-features = false } +datafusion-doc = { path = "datafusion/doc", version = "52.0.0" } +datafusion-execution = { path = "datafusion/execution", version = "52.0.0", default-features = false } +datafusion-expr = { path = "datafusion/expr", version = "52.0.0", default-features = false } +datafusion-expr-common = { path = "datafusion/expr-common", version = "52.0.0" } +datafusion-ffi = { path = "datafusion/ffi", version = "52.0.0" } +datafusion-functions = { path = "datafusion/functions", version = "52.0.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "52.0.0" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "52.0.0" } +datafusion-functions-nested = { path = "datafusion/functions-nested", version = "52.0.0", default-features = false } +datafusion-functions-table = { path = "datafusion/functions-table", version = "52.0.0" } +datafusion-functions-window = { path = "datafusion/functions-window", version = "52.0.0" } +datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "52.0.0" } +datafusion-macros = { path = "datafusion/macros", version = "52.0.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "52.0.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "52.0.0", default-features = false } +datafusion-physical-expr-adapter = { path = "datafusion/physical-expr-adapter", version = "52.0.0", default-features = false } +datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "52.0.0", default-features = false } +datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "52.0.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "52.0.0" } +datafusion-proto = { path = "datafusion/proto", version = "52.0.0" } +datafusion-proto-common = { path = "datafusion/proto-common", version = "52.0.0" } +datafusion-pruning = { path = "datafusion/pruning", version = "52.0.0" } +datafusion-session = { path = "datafusion/session", version = "52.0.0" } +datafusion-spark = { path = "datafusion/spark", version = "52.0.0" } +datafusion-sql = { path = "datafusion/sql", version = "52.0.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "52.0.0" } doc-comment = "0.3" env_logger = "0.11" diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index b0190dadf3c3f..cdc45dea1a757 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -59,6 +59,7 @@ mimalloc = { version = "0.1", default-features = false } object_store = { workspace = true, features = ["aws", "http"] } prost = { workspace = true } rand = { workspace = true } +serde = { version = "1", features = ["derive"] } serde_json = { workspace = true } strum = { workspace = true } strum_macros = { workspace = true } diff --git a/datafusion-examples/examples/custom_data_source/adapter_serialization.rs b/datafusion-examples/examples/custom_data_source/adapter_serialization.rs new file mode 100644 index 0000000000000..5b7122006f903 --- /dev/null +++ b/datafusion-examples/examples/custom_data_source/adapter_serialization.rs @@ -0,0 +1,546 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example demonstrates how to use the `PhysicalExtensionCodec` trait's +//! interception methods (`serialize_physical_plan` and `deserialize_physical_plan`) +//! to implement custom serialization logic. +//! +//! The key insight is that `FileScanConfig::expr_adapter_factory` is NOT serialized by +//! default. This example shows how to: +//! 1. Detect plans with custom adapters during serialization +//! 2. Wrap them as Extension nodes with JSON-serialized adapter metadata +//! 3. Unwrap and restore the adapter during deserialization +//! +//! This demonstrates nested serialization (protobuf outer, JSON inner) and the power +//! of the `PhysicalExtensionCodec` interception pattern. Both plan and expression +//! serialization route through the codec, enabling interception at every node in the tree. + +use std::fmt::Debug; +use std::sync::Arc; + +use arrow::array::record_batch; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::assert_batches_eq; +use datafusion::common::{Result, not_impl_err}; +use datafusion::datasource::listing::{ + ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, +}; +use datafusion::datasource::physical_plan::{FileScanConfig, FileScanConfigBuilder}; +use datafusion::datasource::source::DataSourceExec; +use datafusion::execution::TaskContext; +use datafusion::execution::context::SessionContext; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionConfig; +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, +}; +use datafusion_proto::bytes::{ + physical_plan_from_bytes_with_proto_converter, + physical_plan_to_bytes_with_proto_converter, +}; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr_with_converter; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr_with_converter; +use datafusion_proto::physical_plan::{ + PhysicalExtensionCodec, PhysicalProtoConverterExtension, +}; +use datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType; +use datafusion_proto::protobuf::{ + PhysicalExprNode, PhysicalExtensionNode, PhysicalPlanNode, +}; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, PutPayload}; +use prost::Message; +use serde::{Deserialize, Serialize}; + +/// Example showing how to preserve custom adapter information during plan serialization. +/// +/// This demonstrates: +/// 1. Creating a custom PhysicalExprAdapter with metadata +/// 2. Using PhysicalExtensionCodec to intercept serialization +/// 3. Wrapping adapter info as Extension nodes +/// 4. Restoring adapters during deserialization +pub async fn adapter_serialization() -> Result<()> { + println!("=== PhysicalExprAdapter Serialization Example ===\n"); + + // Step 1: Create sample Parquet data in memory + println!("Step 1: Creating sample Parquet data..."); + let store = Arc::new(InMemory::new()) as Arc; + let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))?; + let path = Path::from("data.parquet"); + write_parquet(&store, &path, &batch).await?; + + // Step 2: Set up session with custom adapter + println!("Step 2: Setting up session with custom adapter..."); + let logical_schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + let mut cfg = SessionConfig::new(); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.runtime_env().register_object_store( + ObjectStoreUrl::parse("memory://")?.as_ref(), + Arc::clone(&store), + ); + + // Create a table with our custom MetadataAdapterFactory + let adapter_factory = Arc::new(MetadataAdapterFactory::new("v1")); + let listing_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///data.parquet")?) + .infer_options(&ctx.state()) + .await? + .with_schema(logical_schema) + .with_expr_adapter_factory( + Arc::clone(&adapter_factory) as Arc + ); + let table = ListingTable::try_new(listing_config)?; + ctx.register_table("my_table", Arc::new(table))?; + + // Step 3: Create physical plan with filter + println!("Step 3: Creating physical plan with filter..."); + let df = ctx.sql("SELECT * FROM my_table WHERE id > 5").await?; + let original_plan = df.create_physical_plan().await?; + + // Verify adapter is present in original plan + let has_adapter_before = verify_adapter_in_plan(&original_plan, "original"); + println!(" Original plan has adapter: {has_adapter_before}"); + + // Step 4: Serialize with our custom codec + println!("\nStep 4: Serializing plan with AdapterPreservingCodec..."); + let codec = AdapterPreservingCodec; + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&original_plan), + &codec, + &codec, + )?; + println!(" Serialized {} bytes", bytes.len()); + println!(" (DataSourceExec with adapter was wrapped as PhysicalExtensionNode)"); + + // Step 5: Deserialize with our custom codec + println!("\nStep 5: Deserializing plan with AdapterPreservingCodec..."); + let task_ctx = ctx.task_ctx(); + let restored_plan = + physical_plan_from_bytes_with_proto_converter(&bytes, &task_ctx, &codec, &codec)?; + + // Verify adapter is restored + let has_adapter_after = verify_adapter_in_plan(&restored_plan, "restored"); + println!(" Restored plan has adapter: {has_adapter_after}"); + + // Step 6: Execute and compare results + println!("\nStep 6: Executing plans and comparing results..."); + let original_results = + datafusion::physical_plan::collect(Arc::clone(&original_plan), task_ctx.clone()) + .await?; + let restored_results = + datafusion::physical_plan::collect(restored_plan, task_ctx).await?; + + #[rustfmt::skip] + let expected = [ + "+----+", + "| id |", + "+----+", + "| 6 |", + "| 7 |", + "| 8 |", + "| 9 |", + "| 10 |", + "+----+", + ]; + + println!("\n Original plan results:"); + arrow::util::pretty::print_batches(&original_results)?; + assert_batches_eq!(expected, &original_results); + + println!("\n Restored plan results:"); + arrow::util::pretty::print_batches(&restored_results)?; + assert_batches_eq!(expected, &restored_results); + + println!("\n=== Example Complete! ==="); + println!("Key takeaways:"); + println!( + " 1. PhysicalExtensionCodec provides serialize_physical_plan/deserialize_physical_plan hooks" + ); + println!(" 2. Custom metadata can be wrapped as PhysicalExtensionNode"); + println!(" 3. Nested serialization (protobuf + JSON) works seamlessly"); + println!( + " 4. Both plans produce identical results despite serialization round-trip" + ); + println!(" 5. Adapters are fully preserved through the serialization round-trip"); + + Ok(()) +} + +// ============================================================================ +// MetadataAdapter - A simple custom adapter with a tag +// ============================================================================ + +/// A custom PhysicalExprAdapter that wraps another adapter. +/// The tag metadata is stored in the factory, not the adapter itself. +#[derive(Debug)] +struct MetadataAdapter { + inner: Arc, +} + +impl PhysicalExprAdapter for MetadataAdapter { + fn rewrite(&self, expr: Arc) -> Result> { + // Simply delegate to inner adapter + self.inner.rewrite(expr) + } +} + +// ============================================================================ +// MetadataAdapterFactory - Factory for creating MetadataAdapter instances +// ============================================================================ + +/// Factory for creating MetadataAdapter instances. +/// The tag is stored in the factory and extracted via Debug formatting in `extract_adapter_tag`. +#[derive(Debug)] +struct MetadataAdapterFactory { + // Note: This field is read via Debug formatting in `extract_adapter_tag`. + // Rust's dead code analysis doesn't recognize Debug-based field access. + // In PR #19234, this field is used by `with_partition_values`, but that method + // doesn't exist in upstream DataFusion's PhysicalExprAdapter trait. + #[allow(dead_code)] + tag: String, +} + +impl MetadataAdapterFactory { + fn new(tag: impl Into) -> Self { + Self { tag: tag.into() } + } +} + +impl PhysicalExprAdapterFactory for MetadataAdapterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Arc { + let inner = DefaultPhysicalExprAdapterFactory + .create(logical_file_schema, physical_file_schema); + Arc::new(MetadataAdapter { inner }) + } +} + +// ============================================================================ +// AdapterPreservingCodec - Custom codec that preserves adapters +// ============================================================================ + +/// Extension payload structure for serializing adapter info +#[derive(Serialize, Deserialize)] +struct ExtensionPayload { + /// Marker to identify this is our custom extension + marker: String, + /// JSON-serialized adapter metadata + adapter_metadata: AdapterMetadata, + /// Protobuf-serialized inner DataSourceExec (without adapter) + inner_plan_bytes: Vec, +} + +/// Metadata about the adapter to recreate it during deserialization +#[derive(Serialize, Deserialize)] +struct AdapterMetadata { + /// The adapter tag (e.g., "v1") + tag: String, +} + +const EXTENSION_MARKER: &str = "adapter_preserving_extension_v1"; + +/// A codec that intercepts serialization to preserve adapter information. +#[derive(Debug)] +struct AdapterPreservingCodec; + +impl PhysicalExtensionCodec for AdapterPreservingCodec { + // Required method: decode custom extension nodes + fn try_decode( + &self, + buf: &[u8], + _inputs: &[Arc], + ctx: &TaskContext, + _proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + // Try to parse as our extension payload + if let Ok(payload) = serde_json::from_slice::(buf) + && payload.marker == EXTENSION_MARKER + { + // Decode the inner plan + let inner_proto = PhysicalPlanNode::decode(&payload.inner_plan_bytes[..]) + .map_err(|e| { + datafusion::error::DataFusionError::Plan(format!( + "Failed to decode inner plan: {e}" + )) + })?; + + // Deserialize the inner plan using default implementation + let inner_plan = + inner_proto.try_into_physical_plan_with_converter(ctx, self, self)?; + + // Recreate the adapter factory + let adapter_factory = create_adapter_factory(&payload.adapter_metadata.tag); + + // Inject adapter into the plan + return inject_adapter_into_plan(inner_plan, adapter_factory); + } + + not_impl_err!("Unknown extension type") + } + + // Required method: encode custom execution plans + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + _proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result<()> { + // We don't need this for the example - we use serialize_physical_plan instead + not_impl_err!( + "try_encode not used - adapter wrapping happens in serialize_physical_plan" + ) + } +} + +impl PhysicalProtoConverterExtension for AdapterPreservingCodec { + fn execution_plan_to_proto( + &self, + plan: &Arc, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + // Check if this is a DataSourceExec with adapter + if let Some(exec) = plan.as_any().downcast_ref::() + && let Some(config) = + exec.data_source().as_any().downcast_ref::() + && let Some(adapter_factory) = &config.expr_adapter_factory + && let Some(tag) = extract_adapter_tag(adapter_factory.as_ref()) + { + // Try to extract our MetadataAdapterFactory's tag + println!(" [Serialize] Found DataSourceExec with adapter tag: {tag}"); + + // 1. Create adapter metadata + let adapter_metadata = AdapterMetadata { tag }; + + // 2. Create a copy of the config without the adapter + let config_without_adapter = rebuild_config_without_adapter(config); + + // 3. Create a new DataSourceExec without adapter + let plan_without_adapter: Arc = + DataSourceExec::from_data_source(config_without_adapter); + + // 4. Serialize the inner plan to protobuf bytes + let inner_proto = PhysicalPlanNode::try_from_physical_plan_with_converter( + plan_without_adapter, + extension_codec, + self, + )?; + + let mut inner_bytes = Vec::new(); + inner_proto.encode(&mut inner_bytes).map_err(|e| { + datafusion::error::DataFusionError::Plan(format!( + "Failed to encode inner plan: {e}" + )) + })?; + + // 5. Create extension payload + let payload = ExtensionPayload { + marker: EXTENSION_MARKER.to_string(), + adapter_metadata, + inner_plan_bytes: inner_bytes, + }; + let payload_bytes = serde_json::to_vec(&payload).map_err(|e| { + datafusion::error::DataFusionError::Plan(format!( + "Failed to serialize payload: {e}" + )) + })?; + + // 6. Return as PhysicalExtensionNode + return Ok(PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Extension( + PhysicalExtensionNode { + node: payload_bytes, + inputs: vec![], // Leaf node + }, + )), + }); + } + + // No adapter found - use default serialization + PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + extension_codec, + self, + ) + } + + // Interception point: override deserialization to unwrap adapters + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + proto: &PhysicalPlanNode, + ) -> Result> { + // Check if this is our custom extension wrapper + if let Some(PhysicalPlanType::Extension(extension)) = &proto.physical_plan_type + && let Ok(payload) = + serde_json::from_slice::(&extension.node) + && payload.marker == EXTENSION_MARKER + { + println!( + " [Deserialize] Found adapter extension with tag: {}", + payload.adapter_metadata.tag + ); + + // Decode the inner plan + let inner_proto = PhysicalPlanNode::decode(&payload.inner_plan_bytes[..]) + .map_err(|e| { + datafusion::error::DataFusionError::Plan(format!( + "Failed to decode inner plan: {e}" + )) + })?; + + // Deserialize the inner plan using default implementation + let inner_plan = inner_proto.try_into_physical_plan_with_converter( + ctx, + extension_codec, + self, + )?; + + // Recreate the adapter factory + let adapter_factory = create_adapter_factory(&payload.adapter_metadata.tag); + + // Inject adapter into the plan + return inject_adapter_into_plan(inner_plan, adapter_factory); + } + + // Not our extension - use default deserialization + proto.try_into_physical_plan_with_converter(ctx, extension_codec, self) + } + + fn proto_to_physical_expr( + &self, + proto: &PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + serialize_physical_expr_with_converter(expr, codec, self) + } +} + +// ============================================================================ +// Helper functions +// ============================================================================ + +/// Write a RecordBatch to Parquet in the object store +async fn write_parquet( + store: &dyn ObjectStore, + path: &Path, + batch: &arrow::record_batch::RecordBatch, +) -> Result<()> { + let mut buf = vec![]; + let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), None)?; + writer.write(batch)?; + writer.close()?; + + let payload = PutPayload::from_bytes(buf.into()); + store.put(path, payload).await?; + Ok(()) +} + +/// Extract the tag from a MetadataAdapterFactory. +/// +/// Note: Since `PhysicalExprAdapterFactory` doesn't provide `as_any()` for downcasting, +/// we parse the Debug output. In a production system, you might add a dedicated trait +/// method for metadata extraction. +fn extract_adapter_tag(factory: &dyn PhysicalExprAdapterFactory) -> Option { + let debug_str = format!("{factory:?}"); + if debug_str.contains("MetadataAdapterFactory") { + // Extract tag from debug output: MetadataAdapterFactory { tag: "v1" } + if let Some(start) = debug_str.find("tag: \"") { + let after_tag = &debug_str[start + 6..]; + if let Some(end) = after_tag.find('"') { + return Some(after_tag[..end].to_string()); + } + } + } + None +} + +/// Create an adapter factory from a tag +fn create_adapter_factory(tag: &str) -> Arc { + Arc::new(MetadataAdapterFactory::new(tag)) +} + +/// Rebuild a FileScanConfig without the adapter +fn rebuild_config_without_adapter(config: &FileScanConfig) -> FileScanConfig { + FileScanConfigBuilder::from(config.clone()) + .with_expr_adapter(None) + .build() +} + +/// Inject an adapter into a plan (assumes plan is a DataSourceExec with FileScanConfig) +fn inject_adapter_into_plan( + plan: Arc, + adapter_factory: Arc, +) -> Result> { + if let Some(exec) = plan.as_any().downcast_ref::() + && let Some(config) = exec.data_source().as_any().downcast_ref::() + { + let new_config = FileScanConfigBuilder::from(config.clone()) + .with_expr_adapter(Some(adapter_factory)) + .build(); + return Ok(DataSourceExec::from_data_source(new_config)); + } + // If not a DataSourceExec with FileScanConfig, return as-is + Ok(plan) +} + +/// Helper to verify if a plan has an adapter (for testing/validation) +fn verify_adapter_in_plan(plan: &Arc, label: &str) -> bool { + // Walk the plan tree to find DataSourceExec with adapter + fn check_plan(plan: &dyn ExecutionPlan) -> bool { + if let Some(exec) = plan.as_any().downcast_ref::() + && let Some(config) = + exec.data_source().as_any().downcast_ref::() + && config.expr_adapter_factory.is_some() + { + return true; + } + // Check children + for child in plan.children() { + if check_plan(child.as_ref()) { + return true; + } + } + false + } + + let has_adapter = check_plan(plan.as_ref()); + println!(" [Verify] {label} plan adapter check: {has_adapter}"); + has_adapter +} diff --git a/datafusion-examples/examples/custom_data_source/main.rs b/datafusion-examples/examples/custom_data_source/main.rs index 5846626d81380..73e65182a0b59 100644 --- a/datafusion-examples/examples/custom_data_source/main.rs +++ b/datafusion-examples/examples/custom_data_source/main.rs @@ -26,6 +26,7 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module +//! - `adapter_serialization` — preserve custom PhysicalExprAdapter information during plan serialization using PhysicalExtensionCodec interception //! - `csv_json_opener` — use low level FileOpener APIs to read CSV/JSON into Arrow RecordBatches //! - `csv_sql_streaming` — build and run a streaming query plan from a SQL statement against a local CSV file //! - `custom_datasource` — run queries against a custom datasource (TableProvider) @@ -34,6 +35,7 @@ //! - `default_column_values` — implement custom default value handling for missing columns using field metadata and PhysicalExprAdapter //! - `file_stream_provider` — run a query on FileStreamProvider which implements StreamProvider for reading and writing to arbitrary stream sources/sinks +mod adapter_serialization; mod csv_json_opener; mod csv_sql_streaming; mod custom_datasource; @@ -50,6 +52,7 @@ use strum_macros::{Display, EnumIter, EnumString, VariantNames}; #[strum(serialize_all = "snake_case")] enum ExampleKind { All, + AdapterSerialization, CsvJsonOpener, CsvSqlStreaming, CustomDatasource, @@ -74,6 +77,9 @@ impl ExampleKind { Box::pin(example.run()).await?; } } + ExampleKind::AdapterSerialization => { + adapter_serialization::adapter_serialization().await? + } ExampleKind::CsvJsonOpener => csv_json_opener::csv_json_opener().await?, ExampleKind::CsvSqlStreaming => { csv_sql_streaming::csv_sql_streaming().await? diff --git a/datafusion-examples/examples/proto/composed_extension_codec.rs b/datafusion-examples/examples/proto/composed_extension_codec.rs index f3910d461b6a8..7b4bcded6d0ad 100644 --- a/datafusion-examples/examples/proto/composed_extension_codec.rs +++ b/datafusion-examples/examples/proto/composed_extension_codec.rs @@ -43,6 +43,7 @@ use datafusion::physical_plan::{DisplayAs, ExecutionPlan}; use datafusion::prelude::SessionContext; use datafusion_proto::physical_plan::{ AsExecutionPlan, ComposedPhysicalExtensionCodec, PhysicalExtensionCodec, + PhysicalProtoConverterExtension, }; use datafusion_proto::protobuf; @@ -140,6 +141,7 @@ impl PhysicalExtensionCodec for ParentPhysicalExtensionCodec { buf: &[u8], inputs: &[Arc], _ctx: &TaskContext, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { if buf == "ParentExec".as_bytes() { Ok(Arc::new(ParentExec { @@ -150,7 +152,12 @@ impl PhysicalExtensionCodec for ParentPhysicalExtensionCodec { } } - fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + _proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result<()> { if node.as_any().downcast_ref::().is_some() { buf.extend_from_slice("ParentExec".as_bytes()); Ok(()) @@ -216,6 +223,7 @@ impl PhysicalExtensionCodec for ChildPhysicalExtensionCodec { buf: &[u8], _inputs: &[Arc], _ctx: &TaskContext, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { if buf == "ChildExec".as_bytes() { Ok(Arc::new(ChildExec {})) @@ -224,7 +232,12 @@ impl PhysicalExtensionCodec for ChildPhysicalExtensionCodec { } } - fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + _proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result<()> { if node.as_any().downcast_ref::().is_some() { buf.extend_from_slice("ChildExec".as_bytes()); Ok(()) diff --git a/datafusion-examples/examples/proto/expression_deduplication.rs b/datafusion-examples/examples/proto/expression_deduplication.rs new file mode 100644 index 0000000000000..a591f41d8682a --- /dev/null +++ b/datafusion-examples/examples/proto/expression_deduplication.rs @@ -0,0 +1,277 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example demonstrates how to use the `PhysicalExtensionCodec` trait's +//! interception methods to implement expression deduplication during deserialization. +//! +//! This pattern is inspired by PR #18192, which introduces expression caching +//! to reduce memory usage when deserializing plans with duplicate expressions. +//! +//! The key insight is that identical expressions serialize to identical protobuf bytes. +//! By caching deserialized expressions keyed by their protobuf bytes, we can: +//! 1. Return the same Arc for duplicate expressions +//! 2. Reduce memory allocation during deserialization +//! 3. Enable downstream optimizations that rely on Arc pointer equality +//! +//! This demonstrates the decorator pattern enabled by the `PhysicalExtensionCodec` trait, +//! where all expression serialization/deserialization routes through the codec methods. + +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::{Arc, RwLock}; + +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::common::Result; +use datafusion::execution::TaskContext; +use datafusion::logical_expr::Operator; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::expressions::{BinaryExpr, col}; +use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion::prelude::SessionContext; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr_with_converter; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr_with_converter; +use datafusion_proto::physical_plan::{ + DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, + PhysicalProtoConverterExtension, +}; +use datafusion_proto::protobuf::{PhysicalExprNode, PhysicalPlanNode}; +use prost::Message; + +/// Example showing how to implement expression deduplication using the codec decorator pattern. +/// +/// This demonstrates: +/// 1. Creating a CachingCodec that caches expressions by their protobuf bytes +/// 2. Intercepting deserialization to return cached Arcs for duplicate expressions +/// 3. Verifying that duplicate expressions share the same Arc after deserialization +/// +/// Deduplication is keyed by the protobuf bytes representing the expression, +/// in reality deduplication could be done based on e.g. the pointer address of the +/// serialized expression in memory, but this is simpler to demonstrate. +/// +/// In this case our expression is trivial and just for demonstration purposes. +/// In real scenarios, expressions can be much more complex, e.g. a large InList +/// expression could be megabytes in size, so deduplication can save significant memory +/// in addition to more correctly representing the original plan structure. +pub async fn expression_deduplication() -> Result<()> { + println!("=== Expression Deduplication Example ===\n"); + + // Create a schema for our test expressions + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, false)])); + + // Step 1: Create expressions with duplicates + println!("Step 1: Creating expressions with duplicates..."); + + // Create expression: col("a") + let a = col("a", &schema)?; + + // Create a clone to show duplicates + let a_clone = Arc::clone(&a); + + // Combine: a OR a_clone + let combined_expr = + Arc::new(BinaryExpr::new(a, Operator::Or, a_clone)) as Arc; + println!(" Created expression: a OR a with duplicates"); + println!(" Note: a appears twice in the expression tree\n"); + // Step 2: Create a filter plan with this expression + println!("Step 2: Creating physical plan with the expression..."); + + let input = Arc::new(PlaceholderRowExec::new(Arc::clone(&schema))); + let filter_plan: Arc = + Arc::new(FilterExec::try_new(combined_expr, input)?); + + println!(" Created FilterExec with duplicate sub-expressions\n"); + + // Step 3: Serialize with the caching codec + println!("Step 3: Serializing plan..."); + + let extension_codec = DefaultPhysicalExtensionCodec {}; + let caching_converter = CachingCodec::new(); + let proto = + caching_converter.execution_plan_to_proto(&filter_plan, &extension_codec)?; + + // Serialize to bytes + let mut bytes = Vec::new(); + proto.encode(&mut bytes).unwrap(); + println!(" Serialized plan to {} bytes\n", bytes.len()); + + // Step 4: Deserialize with the caching codec + println!("Step 4: Deserializing plan with CachingCodec..."); + + let ctx = SessionContext::new(); + let deserialized_plan = proto.try_into_physical_plan_with_converter( + &ctx.task_ctx(), + &extension_codec, + &caching_converter, + )?; + + // Step 5: check that we deduplicated expressions + println!("Step 5: Checking for deduplicated expressions..."); + let Some(filter_exec) = deserialized_plan.as_any().downcast_ref::() + else { + panic!("Deserialized plan is not a FilterExec"); + }; + let predicate = Arc::clone(filter_exec.predicate()); + let binary_expr = predicate + .as_any() + .downcast_ref::() + .expect("Predicate is not a BinaryExpr"); + let left = &binary_expr.left(); + let right = &binary_expr.right(); + // Check if left and right point to the same Arc + let deduplicated = Arc::ptr_eq(left, right); + if deduplicated { + println!(" Success: Duplicate expressions were deduplicated!"); + println!( + " Cache Stats: hits={}, misses={}", + caching_converter.stats.read().unwrap().cache_hits, + caching_converter.stats.read().unwrap().cache_misses, + ); + } else { + println!(" Failure: Duplicate expressions were NOT deduplicated."); + } + + Ok(()) +} + +// ============================================================================ +// CachingCodec - Implements expression deduplication +// ============================================================================ + +/// Statistics for cache performance monitoring +#[derive(Debug, Default)] +struct CacheStats { + cache_hits: usize, + cache_misses: usize, +} + +/// A codec that caches deserialized expressions to enable deduplication. +/// +/// When deserializing, if we've already seen the same protobuf bytes, +/// we return the cached Arc instead of creating a new allocation. +#[derive(Debug, Default)] +struct CachingCodec { + /// Cache mapping protobuf bytes -> deserialized expression + expr_cache: RwLock, Arc>>, + /// Statistics for demonstration + stats: RwLock, +} + +impl CachingCodec { + fn new() -> Self { + Self::default() + } +} + +impl PhysicalExtensionCodec for CachingCodec { + // Required: decode custom extension nodes + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc], + _ctx: &TaskContext, + _proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + datafusion::common::not_impl_err!("No custom extension nodes") + } + + // Required: encode custom execution plans + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + _proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result<()> { + datafusion::common::not_impl_err!("No custom extension nodes") + } +} + +impl PhysicalProtoConverterExtension for CachingCodec { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + proto: &PhysicalPlanNode, + ) -> Result> { + proto.try_into_physical_plan_with_converter(ctx, extension_codec, self) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + extension_codec, + self, + ) + } + + // CACHING IMPLEMENTATION: Intercept expression deserialization + fn proto_to_physical_expr( + &self, + proto: &PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + // Create cache key from protobuf bytes + let mut key = Vec::new(); + proto.encode(&mut key).map_err(|e| { + datafusion::error::DataFusionError::Internal(format!( + "Failed to encode proto for cache key: {e}" + )) + })?; + + // Check cache first + { + let cache = self.expr_cache.read().unwrap(); + if let Some(cached) = cache.get(&key) { + // Cache hit! Update stats and return cached Arc + let mut stats = self.stats.write().unwrap(); + stats.cache_hits += 1; + return Ok(Arc::clone(cached)); + } + } + + // Cache miss - deserialize and store + let expr = + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self)?; + + // Store in cache + { + let mut cache = self.expr_cache.write().unwrap(); + cache.insert(key, Arc::clone(&expr)); + let mut stats = self.stats.write().unwrap(); + stats.cache_misses += 1; + } + + Ok(expr) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + serialize_physical_expr_with_converter(expr, codec, self) + } +} diff --git a/datafusion-examples/examples/proto/main.rs b/datafusion-examples/examples/proto/main.rs index f56078b31997d..16fcd70f34eba 100644 --- a/datafusion-examples/examples/proto/main.rs +++ b/datafusion-examples/examples/proto/main.rs @@ -27,8 +27,10 @@ //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module //! - `composed_extension_codec` — example of using multiple extension codecs for serialization / deserialization +//! - `expression_deduplication` — example of expression caching/deduplication using the codec decorator pattern mod composed_extension_codec; +mod expression_deduplication; use datafusion::error::{DataFusionError, Result}; use strum::{IntoEnumIterator, VariantNames}; @@ -39,6 +41,7 @@ use strum_macros::{Display, EnumIter, EnumString, VariantNames}; enum ExampleKind { All, ComposedExtensionCodec, + ExpressionDeduplication, } impl ExampleKind { @@ -59,6 +62,9 @@ impl ExampleKind { ExampleKind::ComposedExtensionCodec => { composed_extension_codec::composed_extension_codec().await? } + ExampleKind::ExpressionDeduplication => { + expression_deduplication::expression_deduplication().await? + } } Ok(()) } diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 55a031d870122..860fe97f35bfa 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -698,10 +698,12 @@ impl DFSchema { // check nested fields match (dt1, dt2) { (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => { - v1.as_ref() == v2.as_ref() + Self::datatype_is_logically_equal(v1.as_ref(), v2.as_ref()) + } + (DataType::Dictionary(_, v1), othertype) + | (othertype, DataType::Dictionary(_, v1)) => { + Self::datatype_is_logically_equal(v1.as_ref(), othertype) } - (DataType::Dictionary(_, v1), othertype) => v1.as_ref() == othertype, - (othertype, DataType::Dictionary(_, v1)) => v1.as_ref() == othertype, (DataType::List(f1), DataType::List(f2)) | (DataType::LargeList(f1), DataType::LargeList(f2)) | (DataType::FixedSizeList(f1, _), DataType::FixedSizeList(f2, _)) => { @@ -1792,6 +1794,27 @@ mod tests { &DataType::Utf8, &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) )); + + // Dictionary is logically equal to logically equivalent value type + assert!(DFSchema::datatype_is_logically_equal( + &DataType::Utf8View, + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) + )); + + assert!(DFSchema::datatype_is_logically_equal( + &DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List( + Field::new("element", DataType::Utf8, false).into() + )) + ), + &DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List( + Field::new("element", DataType::Utf8View, false).into() + )) + ) + )); } #[test] diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index 7cedaf86cb52f..3fc980f386df8 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -44,6 +44,7 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal, binary, lit}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ @@ -62,8 +63,12 @@ use datafusion_physical_plan::execution_plan::ExecutionPlan; use datafusion_physical_plan::expressions::col; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::JoinOn; +use datafusion_physical_plan::joins::{ + DynamicFilterRoutingMode, HashJoinExec, PartitionMode, +}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; +use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::{ @@ -366,6 +371,60 @@ fn hash_join_exec( .unwrap() } +// Build a partitioned hash join for two inputs. +fn partitioned_hash_join_exec( + left: Arc, + right: Arc, + join_on: &JoinOn, + join_type: &JoinType, +) -> Arc { + hash_join_exec(left, right, join_on, join_type) +} + +// Traverses down the plan and returns the first hash join with the count of how +// many of its direct children are hash RepartitionExec nodes. +fn first_hash_join_and_direct_hash_repartition_children( + plan: &Arc, +) -> Option<(&HashJoinExec, usize)> { + if let Some(hash_join) = plan.as_any().downcast_ref::() { + let direct_hash_repartition_children = hash_join + .children() + .into_iter() + .filter(|child| { + child + .as_any() + .downcast_ref::() + .is_some_and(|repartition| { + matches!(repartition.partitioning(), Partitioning::Hash(_, _)) + }) + }) + .count(); + return Some((hash_join, direct_hash_repartition_children)); + } + + for child in plan.children() { + if let Some(result) = first_hash_join_and_direct_hash_repartition_children(child) + { + return Some(result); + } + } + None +} + +// Add RepartitionExec for the given input. +fn add_repartition( + input: Arc, + column_name: &str, + partition_count: usize, +) -> Arc { + let expr = Arc::new(Column::new_with_schema(column_name, &input.schema()).unwrap()) + as Arc; + Arc::new( + RepartitionExec::try_new(input, Partitioning::Hash(vec![expr], partition_count)) + .unwrap(), + ) +} + fn filter_exec(input: Arc) -> Arc { let predicate = Arc::new(BinaryExpr::new( col("c", &schema()).unwrap(), @@ -405,6 +464,23 @@ fn ensure_distribution_helper( ensure_distribution(distribution_context, &config).map(|item| item.data.plan) } +/// Like [`ensure_distribution_helper`] but uses bottom-up `transform_up`. +fn ensure_distribution_helper_transform_up( + plan: Arc, + target_partitions: usize, +) -> Result> { + let distribution_context = DistributionContext::new_default(plan); + let mut config = ConfigOptions::new(); + config.execution.target_partitions = target_partitions; + config.optimizer.enable_round_robin_repartition = false; + config.optimizer.repartition_file_scans = false; + config.optimizer.repartition_file_min_size = 1024; + config.optimizer.prefer_existing_sort = false; + distribution_context + .transform_up(|node| ensure_distribution(node, &config)) + .map(|item| item.data.plan) +} + fn test_suite_default_config_options() -> ConfigOptions { let mut config = ConfigOptions::new(); @@ -737,6 +813,195 @@ fn multi_hash_joins() -> Result<()> { Ok(()) } +// Verify that if the join inputs are not direct hash repartition children, +// enforce_distribution keeps direct children as file-partitioned scans. +#[test] +fn enforce_distribution_switches_to_partition_index_without_hash_repartition() +-> Result<()> { + let left = parquet_exec(); + let right = parquet_exec(); + + let join_on = vec![( + Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) + as Arc, + Arc::new(Column::new_with_schema("a", &right.schema()).unwrap()) + as Arc, + )]; + + let join = partitioned_hash_join_exec(left, right, &join_on, &JoinType::Inner); + + let optimized = ensure_distribution_helper_transform_up(join, 1)?; + assert_plan!(optimized, @r" + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + let (hash_join, direct_hash_repartition_children) = + first_hash_join_and_direct_hash_repartition_children(&optimized) + .expect("expected HashJoinExec"); + + assert_eq!( + hash_join.dynamic_filter_routing_mode, + DynamicFilterRoutingMode::PartitionIndex, + ); + assert_eq!(direct_hash_repartition_children, 0); + + Ok(()) +} + +#[test] +fn enforce_distribution_rejects_misaligned_left_repartitioned() -> Result<()> { + let left = parquet_exec_multiple(); + let right = parquet_exec(); + + let join_on = vec![( + Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) + as Arc, + Arc::new(Column::new_with_schema("a", &right.schema()).unwrap()) + as Arc, + )]; + + let join = partitioned_hash_join_exec(left, right, &join_on, &JoinType::Inner); + let result = ensure_distribution_helper_transform_up(join, 1); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.to_string() + .contains("incompatible partitioning schemes"), + "Expected error about incompatible partitioning, got: {err}", + ); + + Ok(()) +} + +#[test] +fn enforce_distribution_rejects_misaligned_right_repartitioned() -> Result<()> { + let left = parquet_exec(); + let right = parquet_exec_multiple(); + + let join_on = vec![( + Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) + as Arc, + Arc::new(Column::new_with_schema("a", &right.schema()).unwrap()) + as Arc, + )]; + + let join = partitioned_hash_join_exec(left, right, &join_on, &JoinType::Inner); + let result = ensure_distribution_helper_transform_up(join, 1); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.to_string() + .contains("incompatible partitioning schemes"), + "Expected error about incompatible partitioning, got: {err}", + ); + + Ok(()) +} + +#[test] +fn enforce_distribution_uses_case_hash_with_indirect_repartition() -> Result<()> { + let left = projection_exec_with_alias( + add_repartition(parquet_exec(), "a", 4), + vec![("a".to_string(), "a".to_string())], + ); + + let right = aggregate_exec_with_alias( + add_repartition(parquet_exec(), "a", 4), + vec![("a".to_string(), "a".to_string())], + ); + + let join_on = vec![( + Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) + as Arc, + Arc::new(Column::new_with_schema("a", &right.schema()).unwrap()) + as Arc, + )]; + + let join = partitioned_hash_join_exec(left, right, &join_on, &JoinType::Inner); + + let optimized = ensure_distribution_helper_transform_up(join, 4)?; + assert_plan!(optimized, @r" + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0)] + RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 + ProjectionExec: expr=[a@0 as a] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + + let (hash_join, direct_hash_repartition_children) = + first_hash_join_and_direct_hash_repartition_children(&optimized) + .expect("expected HashJoinExec"); + assert_eq!( + hash_join.dynamic_filter_routing_mode, + DynamicFilterRoutingMode::CaseHash, + ); + assert_eq!(direct_hash_repartition_children, 1); + + Ok(()) +} + +// Verify hash repartition under an aliased branch off the top join key path is not counted as a +// direct repartition child of the top join. +#[test] +fn enforce_distribution_ignores_hash_repartition_off_dynamic_filter_path() -> Result<()> { + let lower_left = projection_exec_with_alias( + add_repartition(parquet_exec(), "a", 4), + vec![("a".to_string(), "a2".to_string())], + ); + let lower_right: Arc = parquet_exec(); + + let lower_join_on = vec![( + Arc::new(Column::new_with_schema("a2", &lower_left.schema()).unwrap()) + as Arc, + Arc::new(Column::new_with_schema("a", &lower_right.schema()).unwrap()) + as Arc, + )]; + + let lower_join: Arc = Arc::new( + HashJoinExec::try_new( + lower_left.clone(), + lower_right.clone(), + lower_join_on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + let left = parquet_exec(); + let join_on = vec![( + Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) + as Arc, + Arc::new(Column::new_with_schema("a", &lower_join.schema()).unwrap()) + as Arc, + )]; + + let join = partitioned_hash_join_exec(left, lower_join, &join_on, &JoinType::Inner); + + let optimized = ensure_distribution_helper_transform_up(join, 1)?; + assert_plan!(optimized, @r" + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@1)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a2@0, a@0)] + ProjectionExec: expr=[a@0 as a2] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + let (_, direct_hash_repartition_children) = + first_hash_join_and_direct_hash_repartition_children(&optimized) + .expect("expected HashJoinExec"); + assert_eq!(direct_hash_repartition_children, 0); + + Ok(()) +} + #[test] fn multi_joins_after_alias() -> Result<()> { let left = parquet_exec(); @@ -3652,8 +3917,15 @@ fn test_replace_order_preserving_variants_with_fetch() -> Result<()> { // Create distribution context let dist_context = DistributionContext::new( spm_exec, - true, - vec![DistributionContext::new(parquet_exec, false, vec![])], + DistFlags { + dist_changing: true, + repartitioned: false, + }, + vec![DistributionContext::new( + parquet_exec, + DistFlags::default(), + vec![], + )], ); // Apply the function diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs index d6357fdf6bc7d..f453cb3f46b90 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs @@ -976,6 +976,38 @@ async fn test_topk_filter_passes_through_coalesce_batches() { ); } +/// Returns a `SessionConfig` with dynamic filter pushdown enabled and `batch_size=10`. +fn dynamic_filter_session_config() -> SessionConfig { + let mut config = SessionConfig::new().with_batch_size(10); + config.options_mut().execution.parquet.pushdown_filters = true; + config + .options_mut() + .optimizer + .enable_dynamic_filter_pushdown = true; + config +} + +/// Optimizes the plan with `FilterPushdown`, creates a session context, and collects results. +async fn optimize_and_collect( + plan: Arc, + session_config: SessionConfig, +) -> (Arc, Vec) { + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, session_config.options()) + .unwrap(); + let session_ctx = SessionContext::new_with_config(session_config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + (plan, batches) +} + #[tokio::test] async fn test_hashjoin_dynamic_filter_pushdown() { use datafusion_common::JoinType; @@ -1373,6 +1405,158 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { ); } +#[tokio::test] +async fn test_partitioned_hashjoin_no_repartition_dynamic_filter_pushdown() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{ + DynamicFilterRoutingMode, HashJoinExec, PartitionMode, + }; + + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("e", DataType::Float64, false), + ])); + + let build_p0 = vec![ + record_batch!( + ("a", Utf8, ["aa", "kk"]), + ("b", Utf8, ["ba", "gg"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + ]; + let build_p1 = vec![ + record_batch!( + ("a", Utf8, ["zz"]), + ("b", Utf8, ["zz"]), + ("c", Float64, [2.0]) + ) + .unwrap(), + ]; + + let probe_p0 = vec![ + record_batch!( + ("a", Utf8, ["aa", "kk"]), + ("b", Utf8, ["ba", "gg"]), + ("e", Float64, [10.0, 20.0]) + ) + .unwrap(), + ]; + let probe_p1 = vec![ + record_batch!( + ("a", Utf8, ["zz", "zz"]), + ("b", Utf8, ["zz", "zz"]), + ("e", Float64, [30.0, 40.0]) + ) + .unwrap(), + ]; + + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches_for_partition(build_p0) + .with_batches_for_partition(build_p1) + .with_file_group(FileGroup::new(vec![PartitionedFile::new( + "build_0.parquet", + 123, + )])) + .with_file_group(FileGroup::new(vec![PartitionedFile::new( + "build_1.parquet", + 123, + )])) + .build(); + + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches_for_partition(probe_p0) + .with_batches_for_partition(probe_p1) + .with_file_group(FileGroup::new(vec![PartitionedFile::new( + "probe_0.parquet", + 123, + )])) + .with_file_group(FileGroup::new(vec![PartitionedFile::new( + "probe_1.parquet", + 123, + )])) + .build(); + + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let hash_join = Arc::new( + HashJoinExec::try_new( + build_scan, + Arc::clone(&probe_scan), + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap() + .with_dynamic_filter_routing_mode(DynamicFilterRoutingMode::PartitionIndex), + ); + + let cp = Arc::new(CoalescePartitionsExec::new(hash_join)) as Arc; + let plan = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("a", &probe_side_schema).unwrap(), + SortOptions::new(true, false), + )]) + .unwrap(), + cp, + )) as Arc; + + let mut session_config = dynamic_filter_session_config(); + session_config + .options_mut() + .optimizer + .preserve_file_partitions = 1; + let (plan, batches) = optimize_and_collect(plan, session_config).await; + + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={2 groups: [[build_0.parquet], [build_1.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={2 groups: [[probe_0.parquet], [probe_1.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ {0: a@0 >= aa AND a@0 <= kk AND b@1 >= ba AND b@1 <= gg AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}, {c0:kk,c1:gg}]), 1: a@0 >= zz AND a@0 <= zz AND b@1 >= zz AND b@1 <= zz AND struct(a@0, b@1) IN (SET) ([{c0:zz,c1:zz}])} ] + " + ); + + let probe_scan_metrics = probe_scan.metrics().unwrap(); + assert_eq!(probe_scan_metrics.output_rows().unwrap(), 4); + + insta::assert_snapshot!( + format!("{}", pretty_format_batches(&batches).unwrap()), + @r" + +----+----+-----+----+----+------+ + | a | b | c | a | b | e | + +----+----+-----+----+----+------+ + | zz | zz | 2.0 | zz | zz | 30.0 | + | zz | zz | 2.0 | zz | zz | 40.0 | + | kk | gg | 2.0 | kk | gg | 20.0 | + | aa | ba | 1.0 | aa | ba | 10.0 | + +----+----+-----+----+----+------+ + " + ); +} + #[tokio::test] async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { use datafusion_common::JoinType; @@ -1705,6 +1889,129 @@ async fn test_nested_hashjoin_dynamic_filter_pushdown() { ); } +#[tokio::test] +async fn test_nested_hashjoin_with_repartition_dynamic_filters() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + let t1_batches = vec![ + record_batch!(("a", Utf8, ["aa", "ab"]), ("x", Float64, [1.0, 2.0])).unwrap(), + ]; + let t1_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("x", DataType::Float64, false), + ])); + let t1_scan = TestScanBuilder::new(Arc::clone(&t1_schema)) + .with_support(true) + .with_batches(t1_batches) + .build(); + + let t2_batches = vec![ + record_batch!( + ("b", Utf8, ["aa", "ab", "ac", "ad", "ae"]), + ("c", Utf8, ["ca", "cb", "cc", "cd", "ce"]), + ("y", Float64, [1.0, 2.0, 3.0, 4.0, 5.0]) + ) + .unwrap(), + ]; + let t2_schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Utf8, false), + Field::new("y", DataType::Float64, false), + ])); + let t2_scan = TestScanBuilder::new(Arc::clone(&t2_schema)) + .with_support(true) + .with_batches(t2_batches) + .build(); + + let t3_batches = vec![ + record_batch!( + ("d", Utf8, ["ca", "cb", "cc", "cd", "ce", "cf", "cg", "ch"]), + ("z", Float64, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + ) + .unwrap(), + ]; + let t3_schema = Arc::new(Schema::new(vec![ + Field::new("d", DataType::Utf8, false), + Field::new("z", DataType::Float64, false), + ])); + let t3_scan = TestScanBuilder::new(Arc::clone(&t3_schema)) + .with_support(true) + .with_batches(t3_batches) + .build(); + + let join1_on = vec![(col("c", &t2_schema).unwrap(), col("d", &t3_schema).unwrap())]; + let join1 = Arc::new( + HashJoinExec::try_new( + t2_scan, + t3_scan, + join1_on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ) as Arc; + + let join1_schema = join1.schema(); + let repartition = Arc::new( + RepartitionExec::try_new( + join1, + Partitioning::Hash(vec![col("b", &join1_schema).unwrap()], 1), + ) + .unwrap(), + ) as Arc; + + let join2_on = vec![( + col("b", &repartition.schema()).unwrap(), + col("a", &t1_schema).unwrap(), + )]; + let join2 = Arc::new( + HashJoinExec::try_new( + repartition, + Arc::clone(&t1_scan), + join2_on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ) as Arc; + + let mut session_config = dynamic_filter_session_config(); + session_config + .options_mut() + .optimizer + .preserve_file_partitions = 1; + + let plan = FilterPushdown::new_post_optimization() + .optimize(join2, session_config.options()) + .unwrap(); + let session_ctx = SessionContext::new_with_config(session_config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + stream.next().await.unwrap().unwrap(); + + let plan_str = format_plan_for_test(&plan).to_string(); + assert!( + plan_str.contains("projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilter"), + "expected dynamic filter on nested probe side after repartition:\n{plan_str}" + ); + assert!( + plan_str.contains("projection=[a, x], file_type=test, pushdown_supported=true, predicate=DynamicFilter"), + "expected dynamic filter on top-level probe side after repartition:\n{plan_str}" + ); +} + #[tokio::test] async fn test_hashjoin_parent_filter_pushdown() { use datafusion_common::JoinType; @@ -3587,7 +3894,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_is_used() { // Verify that a dynamic filter was created let dynamic_filter = hash_join - .dynamic_filter_for_test() + .dynamic_filter() .expect("Dynamic filter should be created"); // Verify that is_used() returns the expected value based on probe side support. diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs index 1afdc4823f0a4..b25c9337156e8 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs @@ -20,10 +20,12 @@ use arrow::{array::RecordBatch, compute::concat_batches}; use datafusion::{datasource::object_store::ObjectStoreUrl, physical_plan::PhysicalExpr}; use datafusion_common::{Result, config::ConfigOptions, internal_err}; use datafusion_datasource::{ - PartitionedFile, file::FileSource, file_scan_config::FileScanConfig, - file_scan_config::FileScanConfigBuilder, file_stream::FileOpenFuture, - file_stream::FileOpener, source::DataSourceExec, + PartitionedFile, file::FileSource, file_groups::FileGroup, + file_scan_config::FileScanConfig, file_scan_config::FileScanConfigBuilder, + file_stream::FileOpenFuture, file_stream::FileOpener, source::DataSourceExec, }; +use datafusion_physical_expr::expressions::DynamicFilterRuntimeContext; +use datafusion_physical_expr_common::physical_expr::bind_runtime_physical_expr; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::filter::batch_filter; @@ -52,6 +54,7 @@ pub struct TestOpener { batch_size: Option, projection: Option>, predicate: Option>, + partition: usize, } impl FileOpener for TestOpener { @@ -71,9 +74,16 @@ impl FileOpener for TestOpener { batches = new_batches.into_iter().collect(); } + let runtime_ctx = DynamicFilterRuntimeContext::for_partition(self.partition); + let predicate = self + .predicate + .clone() + .map(|expr| bind_runtime_physical_expr(expr, &runtime_ctx)) + .transpose()?; + let mut new_batches = Vec::new(); for batch in batches { - let batch = if let Some(predicate) = &self.predicate { + let batch = if let Some(predicate) = &predicate { batch_filter(&batch, predicate)? } else { batch @@ -102,6 +112,7 @@ pub struct TestSource { predicate: Option>, batch_size: Option, batches: Vec, + per_partition_batches: Option>>, metrics: ExecutionPlanMetricsSet, projection: Option>, table_schema: datafusion_datasource::TableSchema, @@ -109,12 +120,22 @@ pub struct TestSource { impl TestSource { pub fn new(schema: SchemaRef, support: bool, batches: Vec) -> Self { + Self::new_with_partitions(schema, support, batches, None) + } + + pub fn new_with_partitions( + schema: SchemaRef, + support: bool, + batches: Vec, + per_partition_batches: Option>>, + ) -> Self { let table_schema = datafusion_datasource::TableSchema::new(Arc::clone(&schema), vec![]); Self { support, metrics: ExecutionPlanMetricsSet::new(), batches, + per_partition_batches, predicate: None, batch_size: None, projection: None, @@ -128,13 +149,20 @@ impl FileSource for TestSource { &self, _object_store: Arc, _base_config: &FileScanConfig, - _partition: usize, + partition: usize, ) -> Result> { + let batches = if let Some(ref per_partition) = self.per_partition_batches { + per_partition.get(partition).cloned().unwrap_or_default() + } else { + self.batches.clone() + }; + Ok(Arc::new(TestOpener { - batches: self.batches.clone(), + batches, batch_size: self.batch_size, projection: self.projection.clone(), predicate: self.predicate.clone(), + partition, })) } @@ -219,7 +247,9 @@ impl FileSource for TestSource { pub struct TestScanBuilder { support: bool, batches: Vec, + per_partition_batches: Vec>, schema: SchemaRef, + file_groups: Vec, } impl TestScanBuilder { @@ -227,7 +257,9 @@ impl TestScanBuilder { Self { support: false, batches: vec![], + per_partition_batches: vec![], schema, + file_groups: vec![], } } @@ -241,17 +273,39 @@ impl TestScanBuilder { self } + pub fn with_batches_for_partition(mut self, batches: Vec) -> Self { + self.per_partition_batches.push(batches); + self + } + + pub fn with_file_group(mut self, group: FileGroup) -> Self { + self.file_groups.push(group); + self + } + pub fn build(self) -> Arc { - let source = Arc::new(TestSource::new( + let per_partition_batches = if self.per_partition_batches.is_empty() { + None + } else { + Some(self.per_partition_batches) + }; + + let source = Arc::new(TestSource::new_with_partitions( Arc::clone(&self.schema), self.support, self.batches, + per_partition_batches, )); - let base_config = - FileScanConfigBuilder::new(ObjectStoreUrl::parse("test://").unwrap(), source) - .with_file(PartitionedFile::new("test.parquet", 123)) - .build(); - DataSourceExec::from_data_source(base_config) + let mut builder = + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test://").unwrap(), source); + if self.file_groups.is_empty() { + builder = builder.with_file(PartitionedFile::new("test.parquet", 123)); + } else { + for group in self.file_groups { + builder = builder.with_file_group(group); + } + } + DataSourceExec::from_data_source(builder.build()) } } diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index 83bdf79c8fcc0..de1c093c433ae 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -26,6 +26,7 @@ use crate::{ use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::datatypes::DataType; use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener}; +use datafusion_physical_expr::expressions::DynamicFilterRuntimeContext; use datafusion_physical_expr::projection::ProjectionExprs; use datafusion_physical_expr::utils::reassign_expr_columns; use datafusion_physical_expr_adapter::replace_columns_with_literals; @@ -44,7 +45,7 @@ use datafusion_datasource::{PartitionedFile, TableSchema}; use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; use datafusion_physical_expr_common::physical_expr::{ - PhysicalExpr, is_dynamic_physical_expr, + PhysicalExpr, bind_runtime_physical_expr, is_dynamic_physical_expr, }; use datafusion_physical_plan::metrics::{ Count, ExecutionPlanMetricsSet, MetricBuilder, PruningMetrics, @@ -255,6 +256,15 @@ impl FileOpener for ParquetOpener { .transpose()?; } + // Bind runtime context for this opener partition. + // For partition-index dynamic filters this binds probe partition `i` + // to build-side filter `i`. + let runtime_ctx = + DynamicFilterRuntimeContext::for_partition(self.partition_index); + predicate = predicate + .map(|p| bind_runtime_physical_expr(p, &runtime_ctx)) + .transpose()?; + let reorder_predicates = self.reorder_filters; let pushdown_filters = self.pushdown_filters; let force_filter_selections = self.force_filter_selections; @@ -981,7 +991,9 @@ mod test { use datafusion_expr::{col, lit}; use datafusion_physical_expr::{ PhysicalExpr, - expressions::{Column, DynamicFilterPhysicalExpr, Literal}, + expressions::{ + BinaryExpr, Column, DynamicFilterPhysicalExpr, DynamicFilterUpdate, Literal, + }, planner::logical2physical, projection::ProjectionExprs, }; @@ -1068,6 +1080,12 @@ mod test { self } + /// Set the partition index. + fn with_partition_index(mut self, index: usize) -> Self { + self.partition_index = index; + self + } + /// Set the predicate. fn with_predicate(mut self, predicate: Arc) -> Self { self.predicate = Some(predicate); @@ -1954,4 +1972,114 @@ mod test { "Reverse scan with non-contiguous row groups should correctly map RowSelection" ); } + + #[tokio::test] + async fn test_partition_bind_in_opener() { + let store = Arc::new(InMemory::new()) as Arc; + + let batch = record_batch!(("a", Int32, vec![Some(1), Some(2), Some(3)])).unwrap(); + let data_size = + write_parquet(Arc::clone(&store), "test.parquet", batch.clone()).await; + + let schema = batch.schema(); + let file = PartitionedFile::new( + "test.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + + let col_a = Arc::new(Column::new("a", 0)) as Arc; + + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&col_a)], + datafusion_physical_expr::expressions::lit(true) as Arc, + )); + + let p0_filter = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::GtEq, + datafusion_physical_expr::expressions::lit(1i32) as Arc, + )) as Arc, + datafusion_expr::Operator::And, + Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::LtEq, + datafusion_physical_expr::expressions::lit(2i32) as Arc, + )) as Arc, + )) as Arc; + + let p1_filter = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::GtEq, + datafusion_physical_expr::expressions::lit(10i32) + as Arc, + )) as Arc, + datafusion_expr::Operator::And, + Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::LtEq, + datafusion_physical_expr::expressions::lit(20i32) + as Arc, + )) as Arc, + )) as Arc; + + let p2_filter = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::GtEq, + datafusion_physical_expr::expressions::lit(3i32) as Arc, + )) as Arc, + datafusion_expr::Operator::And, + Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::LtEq, + datafusion_physical_expr::expressions::lit(3i32) as Arc, + )) as Arc, + )) as Arc; + + dynamic_filter + .update(DynamicFilterUpdate::Partitioned(vec![ + Some(p0_filter), + Some(p1_filter), + Some(p2_filter), + ])) + .unwrap(); + + let opener_p0 = ParquetOpenerBuilder::new() + .with_store(Arc::clone(&store)) + .with_schema(Arc::clone(&schema)) + .with_projection_indices(&[0]) + .with_predicate(Arc::clone(&dynamic_filter) as Arc) + .with_partition_index(0) + .with_row_group_stats_pruning(true) + .build(); + let stream = opener_p0.open(file.clone()).unwrap().await.unwrap(); + let (_, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_rows, 3); + + let opener_p1 = ParquetOpenerBuilder::new() + .with_store(Arc::clone(&store)) + .with_schema(Arc::clone(&schema)) + .with_projection_indices(&[0]) + .with_predicate(Arc::clone(&dynamic_filter) as Arc) + .with_partition_index(1) + .with_row_group_stats_pruning(true) + .build(); + let stream = opener_p1.open(file.clone()).unwrap().await.unwrap(); + let (_, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_rows, 0); + + let opener_p2 = ParquetOpenerBuilder::new() + .with_store(Arc::clone(&store)) + .with_schema(Arc::clone(&schema)) + .with_projection_indices(&[0]) + .with_predicate(Arc::clone(&dynamic_filter) as Arc) + .with_partition_index(2) + .with_row_group_stats_pruning(true) + .build(); + let stream = opener_p2.open(file).unwrap().await.unwrap(); + let (_, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_rows, 3); + } } diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index 2e0919b1447de..10c7d05d24806 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -583,6 +583,10 @@ impl FileSource for ParquetSource { self.predicate.clone() } + fn with_filter(&self, filter: Arc) -> Option> { + Some(Arc::new(self.with_predicate(filter))) + } + fn with_batch_size(&self, batch_size: usize) -> Arc { let mut conf = self.clone(); conf.batch_size = Some(batch_size); diff --git a/datafusion/datasource/src/file.rs b/datafusion/datasource/src/file.rs index f5380c27ecc28..e4ea677f98f16 100644 --- a/datafusion/datasource/src/file.rs +++ b/datafusion/datasource/src/file.rs @@ -74,6 +74,10 @@ pub trait FileSource: Send + Sync { fn filter(&self) -> Option> { None } + /// Return a new [`FileSource`] with the specified filter, if supported. + fn with_filter(&self, _filter: Arc) -> Option> { + None + } /// Return the projection that will be applied to the output stream on top of the table schema. fn projection(&self) -> Option<&ProjectionExprs> { None diff --git a/datafusion/datasource/src/file_scan_config.rs b/datafusion/datasource/src/file_scan_config.rs index c8636343ccc5a..1d21a89a2b7e8 100644 --- a/datafusion/datasource/src/file_scan_config.rs +++ b/datafusion/datasource/src/file_scan_config.rs @@ -1351,7 +1351,7 @@ mod tests { use arrow::datatypes::Field; use datafusion_common::stats::Precision; - use datafusion_common::{ColumnStatistics, internal_err}; + use datafusion_common::{ColumnStatistics, ScalarValue, internal_err}; use datafusion_expr::{Operator, SortExpr}; use datafusion_physical_expr::create_physical_sort_expr; use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; diff --git a/datafusion/datasource/src/test_util.rs b/datafusion/datasource/src/test_util.rs index c8d5dd54cb8a2..2889b4996e34a 100644 --- a/datafusion/datasource/src/test_util.rs +++ b/datafusion/datasource/src/test_util.rs @@ -84,6 +84,12 @@ impl FileSource for MockSource { self.filter.clone() } + fn with_filter(&self, filter: Arc) -> Option> { + let mut source = self.clone(); + source.filter = Some(filter); + Some(Arc::new(source)) + } + fn with_batch_size(&self, _batch_size: usize) -> Arc { Arc::new(Self { ..self.clone() }) } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index a0faca76e91e4..f5f49c643c285 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -260,11 +260,7 @@ fn coerce_exprs_for_schema( } #[expect(deprecated)] Expr::Wildcard { .. } => Ok(expr), - _ => { - // maintain the original name when casting - let name = dst_schema.field(idx).name(); - Ok(expr.cast_to(new_type, src_schema)?.alias(name)) - } + _ => expr.cast_to(new_type, src_schema), } } else { Ok(expr) diff --git a/datafusion/ffi/src/proto/physical_extension_codec.rs b/datafusion/ffi/src/proto/physical_extension_codec.rs index 0577e72366478..7b579862e1bd2 100644 --- a/datafusion/ffi/src/proto/physical_extension_codec.rs +++ b/datafusion/ffi/src/proto/physical_extension_codec.rs @@ -26,7 +26,10 @@ use datafusion_expr::{ AggregateUDF, AggregateUDFImpl, ScalarUDF, ScalarUDFImpl, WindowUDF, WindowUDFImpl, }; use datafusion_physical_plan::ExecutionPlan; -use datafusion_proto::physical_plan::PhysicalExtensionCodec; +use datafusion_proto::physical_plan::{ + DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + PhysicalProtoConverterExtension, +}; use tokio::runtime::Handle; use crate::execution::FFI_TaskContextProvider; @@ -141,8 +144,12 @@ unsafe extern "C" fn try_decode_fn_wrapper( .collect::>>(); let inputs = rresult_return!(inputs); - let plan = - rresult_return!(codec.try_decode(buf.as_ref(), &inputs, task_ctx.as_ref())); + let plan = rresult_return!(codec.try_decode( + buf.as_ref(), + &inputs, + task_ctx.as_ref(), + &DefaultPhysicalProtoConverter + )); RResult::ROk(FFI_ExecutionPlan::new(plan, None)) } @@ -156,7 +163,7 @@ unsafe extern "C" fn try_encode_fn_wrapper( let plan: Arc = rresult_return!((&node).try_into()); let mut bytes = Vec::new(); - rresult_return!(codec.try_encode(plan, &mut bytes)); + rresult_return!(codec.try_encode(plan, &mut bytes, &DefaultPhysicalProtoConverter)); RResult::ROk(bytes.into()) } @@ -327,6 +334,7 @@ impl PhysicalExtensionCodec for ForeignPhysicalExtensionCodec { buf: &[u8], inputs: &[Arc], _ctx: &TaskContext, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let inputs = inputs .iter() @@ -340,7 +348,12 @@ impl PhysicalExtensionCodec for ForeignPhysicalExtensionCodec { Ok(plan) } - fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + _proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result<()> { let plan = FFI_ExecutionPlan::new(node, None); let bytes = df_result!(unsafe { (self.0.try_encode)(&self.0, plan) })?; @@ -418,7 +431,10 @@ pub(crate) mod tests { use datafusion_functions_aggregate::sum::Sum; use datafusion_functions_window::rank::{Rank, RankType}; use datafusion_physical_plan::ExecutionPlan; - use datafusion_proto::physical_plan::PhysicalExtensionCodec; + use datafusion_proto::physical_plan::{ + DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + PhysicalProtoConverterExtension, + }; use crate::execution_plan::tests::EmptyExec; use crate::proto::physical_extension_codec::FFI_PhysicalExtensionCodec; @@ -441,6 +457,7 @@ pub(crate) mod tests { buf: &[u8], _inputs: &[Arc], _ctx: &TaskContext, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { if buf[0] != Self::MAGIC_NUMBER { return exec_err!( @@ -459,6 +476,7 @@ pub(crate) mod tests { &self, node: Arc, buf: &mut Vec, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result<()> { buf.push(Self::MAGIC_NUMBER); @@ -579,10 +597,18 @@ pub(crate) mod tests { let exec = create_test_exec(); let input_execs = [create_test_exec()]; let mut bytes = Vec::new(); - foreign_codec.try_encode(Arc::clone(&exec), &mut bytes)?; - - let returned_exec = - foreign_codec.try_decode(&bytes, &input_execs, ctx.task_ctx().as_ref())?; + foreign_codec.try_encode( + Arc::clone(&exec), + &mut bytes, + &DefaultPhysicalProtoConverter, + )?; + + let returned_exec = foreign_codec.try_decode( + &bytes, + &input_execs, + ctx.task_ctx().as_ref(), + &DefaultPhysicalProtoConverter, + )?; assert!(returned_exec.as_any().is::()); diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 2bdc05abe3806..7aadfefcc31e0 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -85,7 +85,7 @@ md-5 = { version = "^0.10.0", optional = true } num-traits = { workspace = true } rand = { workspace = true } regex = { workspace = true, optional = true } -sha2 = { version = "^0.10.9", optional = true } +sha2 = { version = "^0.10.8", optional = true } unicode-segmentation = { version = "^1.7.1", optional = true } uuid = { version = "1.19", features = ["v4"], optional = true } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 02395c76bdd92..0e3d0c7b0f239 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1441,7 +1441,7 @@ mod test { true, plan.clone(), @r" - Projection: CAST(a AS LargeUtf8) AS a + Projection: CAST(a AS LargeUtf8) EmptyRelation: rows=0 " )?; @@ -1477,7 +1477,7 @@ mod test { true, plan.clone(), @r" - Projection: CAST(a AS LargeUtf8) AS a + Projection: CAST(a AS LargeUtf8) EmptyRelation: rows=0 " )?; @@ -1507,7 +1507,7 @@ mod test { true, sort_plan.clone(), @r" - Projection: CAST(a AS LargeUtf8) AS a + Projection: CAST(a AS LargeUtf8) Sort: a ASC NULLS FIRST Projection: a EmptyRelation: rows=0 @@ -1536,7 +1536,7 @@ mod test { true, plan.clone(), @r" - Projection: CAST(a AS LargeUtf8) AS a + Projection: CAST(a AS LargeUtf8) Sort: a ASC NULLS FIRST Projection: a EmptyRelation: rows=0 @@ -1572,7 +1572,7 @@ mod test { true, plan.clone(), @r" - Projection: CAST(a AS LargeBinary) AS a + Projection: CAST(a AS LargeBinary) EmptyRelation: rows=0 " )?; @@ -1629,7 +1629,7 @@ mod test { true, sort_plan.clone(), @r" - Projection: CAST(a AS LargeBinary) AS a + Projection: CAST(a AS LargeBinary) Sort: a ASC NULLS FIRST Projection: a EmptyRelation: rows=0 @@ -1660,7 +1660,7 @@ mod test { true, plan.clone(), @r" - Projection: CAST(a AS LargeBinary) AS a + Projection: CAST(a AS LargeBinary) Sort: a ASC NULLS FIRST Projection: a EmptyRelation: rows=0 diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 36a6df54ddaf0..c8dceed737f34 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -538,15 +538,14 @@ fn recursive_cte_projection_pushdown() -> Result<()> { // columns from the base table and recursive table, eliminating unused columns assert_snapshot!( format!("{plan}"), - @r" - SubqueryAlias: nodes - RecursiveQuery: is_distinct=false - Projection: test.col_int32 AS id - TableScan: test projection=[col_int32] - Projection: CAST(CAST(nodes.id AS Int64) + Int64(1) AS Int32) AS id - Filter: nodes.id < Int32(3) - TableScan: nodes projection=[id] - " + @r#"SubqueryAlias: nodes + RecursiveQuery: is_distinct=false + Projection: test.col_int32 AS id + TableScan: test projection=[col_int32] + Projection: CAST(CAST(nodes.id AS Int64) + Int64(1) AS Int32) + Filter: nodes.id < Int32(3) + TableScan: nodes projection=[id] +"# ); Ok(()) } @@ -562,16 +561,14 @@ fn recursive_cte_with_aliased_self_reference() -> Result<()> { assert_snapshot!( format!("{plan}"), - @r" - SubqueryAlias: nodes - RecursiveQuery: is_distinct=false - Projection: test.col_int32 AS id - TableScan: test projection=[col_int32] - Projection: CAST(CAST(child.id AS Int64) + Int64(1) AS Int32) AS id - SubqueryAlias: child - Filter: nodes.id < Int32(3) - TableScan: nodes projection=[id] - ", + @r#"SubqueryAlias: nodes + RecursiveQuery: is_distinct=false + Projection: test.col_int32 AS id + TableScan: test projection=[col_int32] + Projection: CAST(CAST(child.id AS Int64) + Int64(1) AS Int32) + SubqueryAlias: child + Filter: nodes.id < Int32(3) + TableScan: nodes projection=[id]"#, ); Ok(()) } @@ -624,16 +621,15 @@ fn recursive_cte_projection_pushdown_baseline() -> Result<()> { // and only the needed column is selected from the recursive table assert_snapshot!( format!("{plan}"), - @r" - SubqueryAlias: countdown - RecursiveQuery: is_distinct=false - Projection: test.col_int32 AS n - Filter: test.col_int32 = Int32(5) - TableScan: test projection=[col_int32] - Projection: CAST(CAST(countdown.n AS Int64) - Int64(1) AS Int32) AS n - Filter: countdown.n > Int32(1) - TableScan: countdown projection=[n] - " + @r#"SubqueryAlias: countdown + RecursiveQuery: is_distinct=false + Projection: test.col_int32 AS n + Filter: test.col_int32 = Int32(5) + TableScan: test projection=[col_int32] + Projection: CAST(CAST(countdown.n AS Int64) - Int64(1) AS Int32) + Filter: countdown.n > Int32(1) + TableScan: countdown projection=[n] +"# ); Ok(()) } diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 2358a21940912..822ab85c22796 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -74,7 +74,9 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash { /// Returns the physical expression as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; - /// Get the data type of this expression, given the schema of the input + /// Get the data type of this expression, given the schema of the input. + /// Returns an error if the data type cannot be determined, ex. if the + /// schema is missing a required field. fn data_type(&self, input_schema: &Schema) -> Result { Ok(self.return_field(input_schema)?.data_type().to_owned()) } @@ -415,6 +417,28 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash { 0 } + /// Bind runtime-specific data into this expression, if needed. + /// + /// This hook lets an expression replace itself with a runtime-bound version using the given + /// `context` (e.g. binding a per-partition view). + /// + /// Binding is single-pass over the existing tree. If this method returns a replacement + /// expression that itself contains additional bindable nodes, those newly introduced nodes are + /// not rebound in the same call. + /// + /// You should not call this method directly as it does not handle recursion. Instead use + /// [`bind_runtime_physical_expr`] to handle recursion and bind the full expression tree. + /// + /// Note for implementers: this method should *not* handle recursion. + /// Recursion is handled in [`bind_runtime_physical_expr`]. + fn bind_runtime( + &self, + _context: &(dyn Any + Send + Sync), + ) -> Result>> { + // By default, this expression does not need runtime binding. + Ok(None) + } + /// Returns true if the expression node is volatile, i.e. whether it can return /// different results when evaluated multiple times with the same input. /// @@ -607,6 +631,46 @@ pub fn snapshot_physical_expr_opt( }) } +/// Bind runtime-specific data into the given `PhysicalExpr`. +/// +/// See the documentation of [`PhysicalExpr::bind_runtime`] for more details. +/// +/// Runtime binding is applied once over the current expression tree. +/// +/// # Returns +/// +/// Returns a runtime-bound expression if any node required binding, +/// otherwise returns the original expression. +pub fn bind_runtime_physical_expr( + expr: Arc, + context: &(dyn Any + Send + Sync), +) -> Result> { + bind_runtime_physical_expr_opt(expr, context).data() +} + +/// Bind runtime-specific data into the given `PhysicalExpr`. +/// +/// See the documentation of [`PhysicalExpr::bind_runtime`] for more details. +/// +/// Runtime binding is applied once over the current expression tree. +/// +/// # Returns +/// +/// Returns a [`Transformed`] indicating whether any runtime binding happened, +/// along with the resulting expression. +pub fn bind_runtime_physical_expr_opt( + expr: Arc, + context: &(dyn Any + Send + Sync), +) -> Result>> { + expr.transform_up(|e| { + if let Some(bound) = e.bind_runtime(context)? { + Ok(Transformed::yes(bound)) + } else { + Ok(Transformed::no(Arc::clone(&e))) + } + }) +} + /// Check the generation of this `PhysicalExpr`. /// Dynamic `PhysicalExpr`s may have a generation that is incremented /// every time the state of the `PhysicalExpr` changes. @@ -666,6 +730,7 @@ mod test { use arrow::array::{Array, BooleanArray, Int64Array, RecordBatch}; use arrow::datatypes::{DataType, Schema}; use datafusion_expr_common::columnar_value::ColumnarValue; + use std::any::Any; use std::fmt::{Display, Formatter}; use std::sync::Arc; @@ -673,7 +738,7 @@ mod test { struct TestExpr {} impl PhysicalExpr for TestExpr { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -715,6 +780,161 @@ mod test { } } + #[derive(Debug, PartialEq, Eq, Hash)] + struct RuntimeBindableExpr { + name: &'static str, + // Selector used to decide if this node should bind for a given context. + bind_key: Option<&'static str>, + children: Vec>, + } + + impl RuntimeBindableExpr { + fn new( + name: &'static str, + bind_key: Option<&'static str>, + children: Vec>, + ) -> Self { + Self { + name, + bind_key, + children, + } + } + } + + impl PhysicalExpr for RuntimeBindableExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _schema: &Schema) -> datafusion_common::Result { + Ok(DataType::Int64) + } + + fn nullable(&self, _schema: &Schema) -> datafusion_common::Result { + Ok(false) + } + + fn evaluate( + &self, + batch: &RecordBatch, + ) -> datafusion_common::Result { + let data = vec![1; batch.num_rows()]; + Ok(ColumnarValue::Array(Arc::new(Int64Array::from(data)))) + } + + fn children(&self) -> Vec<&Arc> { + self.children.iter().collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(Self { + name: self.name, + bind_key: self.bind_key, + children, + })) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(self.name) + } + + fn bind_runtime( + &self, + context: &(dyn Any + Send + Sync), + ) -> datafusion_common::Result>> { + let Some(bind_key) = self.bind_key else { + return Ok(None); + }; + let Some(ctx) = context.downcast_ref::() else { + return Ok(None); + }; + // Bind only when selector in context matches this node's key. + if ctx.target_key != bind_key { + return Ok(None); + } + + Ok(Some(Arc::new(Self { + // Simulate replacing runtime placeholder with bound payload. + name: ctx.bound_name, + bind_key: None, + children: self.children.clone(), + }))) + } + } + + impl Display for RuntimeBindableExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.fmt_sql(f) + } + } + + #[derive(Debug, PartialEq, Eq, Hash)] + struct ErrorOnBindExpr; + + impl PhysicalExpr for ErrorOnBindExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _schema: &Schema) -> datafusion_common::Result { + Ok(DataType::Int64) + } + + fn nullable(&self, _schema: &Schema) -> datafusion_common::Result { + Ok(false) + } + + fn evaluate( + &self, + batch: &RecordBatch, + ) -> datafusion_common::Result { + let data = vec![1; batch.num_rows()]; + Ok(ColumnarValue::Array(Arc::new(Int64Array::from(data)))) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> datafusion_common::Result> { + Ok(self) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str("ErrorOnBindExpr") + } + + fn bind_runtime( + &self, + _context: &(dyn Any + Send + Sync), + ) -> datafusion_common::Result>> { + // Used to verify traversal propagates bind errors. + Err(datafusion_common::DataFusionError::Internal( + "forced bind_runtime error".to_string(), + )) + } + } + + impl Display for ErrorOnBindExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.fmt_sql(f) + } + } + + struct RuntimeBindContext { + // Which bindable nodes should be replaced for this call. + target_key: &'static str, + // Replacement used by matching nodes. + bound_name: &'static str, + } + macro_rules! assert_arrays_eq { ($EXPECTED: expr, $ACTUAL: expr, $MESSAGE: expr) => { let expected = $EXPECTED.to_array(1).unwrap(); @@ -845,4 +1065,109 @@ mod test { &BooleanArray::from(vec![true; 5]), ); } + + #[test] + fn test_bind_runtime_physical_expr_default_noop() { + // TestExpr does not override bind_runtime, so traversal is a no-op. + let expr: Arc = Arc::new(TestExpr {}); + let ctx = RuntimeBindContext { + target_key: "right", + bound_name: "bound", + }; + + let transformed = + super::bind_runtime_physical_expr_opt(Arc::clone(&expr), &ctx).unwrap(); + + assert!(!transformed.transformed); + assert!(Arc::ptr_eq(&expr, &transformed.data)); + } + + #[test] + fn test_bind_runtime_physical_expr_recurses() { + // Only the right child matches target_key and should be rewritten. + let left: Arc = + Arc::new(RuntimeBindableExpr::new("left", Some("left"), vec![])); + let right: Arc = + Arc::new(RuntimeBindableExpr::new("right", Some("right"), vec![])); + let root: Arc = Arc::new(RuntimeBindableExpr::new( + "root", + None, + vec![Arc::clone(&left), Arc::clone(&right)], + )); + let ctx = RuntimeBindContext { + target_key: "right", + bound_name: "right_bound", + }; + + let transformed = super::bind_runtime_physical_expr_opt(root, &ctx).unwrap(); + assert!(transformed.transformed); + + let root = transformed + .data + .as_any() + .downcast_ref::() + .expect("root should be RuntimeBindableExpr"); + let left = root.children[0] + .as_any() + .downcast_ref::() + .expect("left should be RuntimeBindableExpr"); + let right = root.children[1] + .as_any() + .downcast_ref::() + .expect("right should be RuntimeBindableExpr"); + + assert_eq!(left.name, "left"); + assert_eq!(right.name, "right_bound"); + assert_eq!(right.bind_key, None); + } + + #[test] + fn test_bind_runtime_physical_expr_returns_data() { + // The non-_opt helper should return the rewritten tree directly. + let expr: Arc = + Arc::new(RuntimeBindableExpr::new("right", Some("right"), vec![])); + let ctx = RuntimeBindContext { + target_key: "right", + bound_name: "right_bound", + }; + + let bound = super::bind_runtime_physical_expr(expr, &ctx).unwrap(); + let bound = bound + .as_any() + .downcast_ref::() + .expect("bound should be RuntimeBindableExpr"); + + assert_eq!(bound.name, "right_bound"); + assert_eq!(bound.bind_key, None); + } + + #[test] + fn test_bind_runtime_physical_expr_context_mismatch_no_transform() { + // Context mismatch returns no transform even for bindable nodes. + let expr: Arc = + Arc::new(RuntimeBindableExpr::new("left", Some("left"), vec![])); + let ctx = RuntimeBindContext { + target_key: "right", + bound_name: "right_bound", + }; + + let transformed = + super::bind_runtime_physical_expr_opt(Arc::clone(&expr), &ctx).unwrap(); + + assert!(!transformed.transformed); + assert!(Arc::ptr_eq(&expr, &transformed.data)); + } + + #[test] + fn test_bind_runtime_physical_expr_propagates_error() { + // A bind_runtime error from any node should fail the traversal. + let expr: Arc = Arc::new(ErrorOnBindExpr); + let ctx = RuntimeBindContext { + target_key: "right", + bound_name: "right_bound", + }; + + let err = super::bind_runtime_physical_expr_opt(expr, &ctx).unwrap_err(); + assert!(err.to_string().contains("forced bind_runtime error")); + } } diff --git a/datafusion/physical-expr/src/expressions/dynamic_filters.rs b/datafusion/physical-expr/src/expressions/dynamic_filters.rs index 7703d201aaea9..c0a22b1082f2d 100644 --- a/datafusion/physical-expr/src/expressions/dynamic_filters.rs +++ b/datafusion/physical-expr/src/expressions/dynamic_filters.rs @@ -20,9 +20,10 @@ use std::{any::Any, fmt::Display, hash::Hash, sync::Arc}; use tokio::sync::watch; use crate::PhysicalExpr; +use crate::expressions::lit; use arrow::datatypes::{DataType, Schema}; use datafusion_common::{ - Result, + Result, internal_err, tree_node::{Transformed, TransformedResult, TreeNode}, }; use datafusion_expr::ColumnarValue; @@ -46,12 +47,55 @@ impl FilterState { } } +/// Per-partition filter expressions indexed by partition number. +type PartitionedFilters = Vec>>; + +/// Payload for dynamic filter updates. +#[derive(Debug, Clone)] +pub enum DynamicFilterUpdate { + /// Update the global expression returned by [`DynamicFilterPhysicalExpr::current`]. + /// + /// This is used by CASE-hash / collect-left routing where a single filter + /// expression represents all partitions. + Global(Arc), + /// Update per-partition expressions used for partition-local lookup. + /// + /// Index `i` corresponds to partition `i`. Missing or out-of-range entries + /// are treated as `true` (fail-open) by `current_for_partition`. + Partitioned(PartitionedFilters), +} + +/// Runtime context for binding [`DynamicFilterPhysicalExpr`] instances. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct DynamicFilterRuntimeContext { + partition: usize, +} + +impl DynamicFilterRuntimeContext { + /// Create a runtime binding context for a probe partition index. + pub fn for_partition(partition: usize) -> Self { + Self { partition } + } + + /// Return the bound probe partition index. + pub fn partition(self) -> usize { + self.partition + } +} + /// A dynamic [`PhysicalExpr`] that can be updated by anyone with a reference to it. /// /// Any `ExecutionPlan` that uses this expression and holds a reference to it internally should probably also /// implement `ExecutionPlan::reset_state` to remain compatible with recursive queries and other situations where /// the same `ExecutionPlan` is reused with different data. -#[derive(Debug)] +/// +/// This means `evaluate()` doesn't need partition context since partition routing can be bound +/// once when setting up the execution stream. +/// +/// For more background, please also see the [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog] +/// +/// [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog]: https://datafusion.apache.org/blog/2025/09/10/dynamic-filters +#[derive(Debug, Clone)] pub struct DynamicFilterPhysicalExpr { /// The original children of this PhysicalExpr, if any. /// This is necessary because the dynamic filter may be initialized with a placeholder (e.g. `lit(true)`) @@ -63,6 +107,9 @@ pub struct DynamicFilterPhysicalExpr { remapped_children: Option>>, /// The source of dynamic filters. inner: Arc>, + /// Runtime-bound partition index for partition-local routing, if set. + /// When `None`, this expression represents the unbound/global view. + runtime_partition: Option, /// Broadcasts filter state (updates and completion) to all waiters. state_watch: watch::Sender, /// For testing purposes track the data type and nullability to make sure they don't change. @@ -82,6 +129,124 @@ struct Inner { /// This is redundant with the watch channel state, but allows us to return immediately /// from `wait_complete()` without subscribing if already complete. is_complete: bool, + /// Per-partition filter expressions for partition-index routing. + /// When both sides of a hash join preserve their file partitioning (no RepartitionExec(Hash)), + /// build-partition i corresponds to probe-partition i. This allows storing per-partition + /// filters so that each partition only sees its own bounds, giving tighter filtering. + partitioned_exprs: PartitionedFilters, +} + +/// An atomic snapshot of a [`DynamicFilterPhysicalExpr`] used to reconstruct the expression during +/// serialization / deserialization. +pub struct DynamicFilterSnapshot { + children: Vec>, + remapped_children: Option>>, + // Inner state. + generation: u64, + inner_expr: Arc, + is_complete: bool, +} + +impl DynamicFilterSnapshot { + pub fn new( + children: Vec>, + remapped_children: Option>>, + generation: u64, + inner_expr: Arc, + is_complete: bool, + ) -> Self { + Self { + children, + remapped_children, + generation, + inner_expr, + is_complete, + } + } + + pub fn children(&self) -> &[Arc] { + &self.children + } + + pub fn remapped_children(&self) -> Option<&[Arc]> { + self.remapped_children.as_deref() + } + + pub fn generation(&self) -> u64 { + self.generation + } + + pub fn inner_expr(&self) -> &Arc { + &self.inner_expr + } + + pub fn is_complete(&self) -> bool { + self.is_complete + } +} + +impl Display for DynamicFilterSnapshot { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "DynamicFilterSnapshot {{ children: {:?}, remapped_children: {:?}, generation: {}, inner_expr: {:?}, is_complete: {} }}", + self.children, + self.remapped_children, + self.generation, + self.inner_expr, + self.is_complete + ) + } +} + +impl From for DynamicFilterPhysicalExpr { + fn from(snapshot: DynamicFilterSnapshot) -> Self { + let DynamicFilterSnapshot { + children, + remapped_children, + generation, + inner_expr, + is_complete, + } = snapshot; + + let state = if is_complete { + FilterState::Complete { generation } + } else { + FilterState::InProgress { generation } + }; + let (state_watch, _) = watch::channel(state); + + Self { + children, + remapped_children, + inner: Arc::new(RwLock::new(Inner { + generation, + expr: inner_expr, + is_complete, + partitioned_exprs: vec![], + })), + state_watch, + data_type: Arc::new(RwLock::new(None)), + nullable: Arc::new(RwLock::new(None)), + runtime_partition: None, + } + } +} + +impl From<&DynamicFilterPhysicalExpr> for DynamicFilterSnapshot { + fn from(expr: &DynamicFilterPhysicalExpr) -> Self { + let (generation, inner_expr, is_complete) = { + let inner = expr.inner.read(); + (inner.generation, Arc::clone(&inner.expr), inner.is_complete) + }; + DynamicFilterSnapshot { + children: expr.children.clone(), + remapped_children: expr.remapped_children.clone(), + generation, + inner_expr, + is_complete, + } + } } impl Inner { @@ -92,6 +257,7 @@ impl Inner { generation: 1, expr, is_complete: false, + partitioned_exprs: Vec::new(), } } @@ -110,6 +276,7 @@ impl Hash for DynamicFilterPhysicalExpr { Arc::as_ptr(&self.inner).hash(state); self.children.dyn_hash(state); self.remapped_children.dyn_hash(state); + self.runtime_partition.hash(state); } } @@ -122,6 +289,7 @@ impl PartialEq for DynamicFilterPhysicalExpr { Arc::ptr_eq(&self.inner, &other.inner) && self.children == other.children && self.remapped_children == other.remapped_children + && self.runtime_partition == other.runtime_partition } } @@ -170,24 +338,62 @@ impl DynamicFilterPhysicalExpr { children, remapped_children: None, // Initially no remapped children inner: Arc::new(RwLock::new(Inner::new(inner))), + runtime_partition: None, state_watch, data_type: Arc::new(RwLock::new(None)), nullable: Arc::new(RwLock::new(None)), } } + /// Create a new [`DynamicFilterPhysicalExpr`] from `self`, except it overwrites the + /// internal state with the source filter's state. + /// + /// This is a low-level API intended for use by the proto deserialization layer. + /// + /// # Safety + /// + /// The dynamic filter should not be in use when calling this method, otherwise there + /// may be undefined behavior. Changing the inner state of a filter may do the following: + /// - transition the state to complete without notifying the watch + /// - cause a generation number to be emitted which is out of order + pub fn new_from_source( + self: &Arc, + source: &DynamicFilterPhysicalExpr, + ) -> Result { + // If there's any references to this filter or any watchers, we should not replace the + // inner state. + if self.is_used() { + return internal_err!( + "Cannot replace the inner state of a DynamicFilterPhysicalExpr that is in use" + ); + }; + + Ok(Self { + children: self.children.clone(), + remapped_children: self.remapped_children.clone(), + inner: Arc::clone(&source.inner), + // Reuse source watch channel so waiters on any relinked clone + // observe update/complete notifications from the producer. + state_watch: source.state_watch.clone(), + data_type: Arc::clone(&self.data_type), + nullable: Arc::clone(&self.nullable), + runtime_partition: None, + }) + } + fn remap_children( - children: &[Arc], - remapped_children: Option<&Vec>>, + &self, expr: Arc, ) -> Result> { - if let Some(remapped_children) = remapped_children { + if let Some(remapped_children) = &self.remapped_children { // Remap the children to the new children // of the expression. expr.transform_up(|child| { // Check if this is any of our original children - if let Some(pos) = - children.iter().position(|c| c.as_ref() == child.as_ref()) + if let Some(pos) = self + .children + .iter() + .position(|c| c.as_ref() == child.as_ref()) { // If so, remap it to the current children // of the expression. @@ -205,47 +411,59 @@ impl DynamicFilterPhysicalExpr { } } - /// Get the current generation of the expression. - fn current_generation(&self) -> u64 { - self.inner.read().generation - } - /// Get the current expression. /// This will return the current expression with any children /// remapped to match calls to [`PhysicalExpr::with_new_children`]. pub fn current(&self) -> Result> { - let expr = Arc::clone(self.inner.read().expr()); - Self::remap_children(&self.children, self.remapped_children.as_ref(), expr) - } - - /// Update the current expression and notify all waiters. - /// Any children of this expression must be a subset of the original children - /// passed to the constructor. - /// This should be called e.g.: - /// - When we've computed the probe side's hash table in a HashJoinExec - /// - After every batch is processed if we update the TopK heap in a SortExec using a TopK approach. - pub fn update(&self, new_expr: Arc) -> Result<()> { - // Remap the children of the new expression to match the original children - // We still do this again in `current()` but doing it preventively here - // reduces the work needed in some cases if `current()` is called multiple times - // and the same externally facing `PhysicalExpr` is used for both `with_new_children` and `update()`.` - let new_expr = Self::remap_children( - &self.children, - self.remapped_children.as_ref(), - new_expr, - )?; + if let Some(partition) = self.runtime_partition { + self.current_for_partition(partition) + } else { + let expr = Arc::clone(self.inner.read().expr()); + self.remap_children(expr) + } + } + + /// Update this dynamic filter and notify all waiters. + /// + /// This is called by producers when new bounds are available, e.g.: + /// - after building bounds in `HashJoinExec` + /// - as `TopK` thresholds become more selective + /// + /// This single API handles both update modes: + /// - [`DynamicFilterUpdate::Global`]: updates the global expression. + /// - [`DynamicFilterUpdate::Partitioned`]: updates per-partition + /// expressions, with one filter per partition index. + /// + /// Each update increments [`PhysicalExpr::snapshot_generation`] and + /// notifies waiters via `state_watch`. + pub fn update(&self, update: DynamicFilterUpdate) -> Result<()> { + let update = match update { + // Remap global expression children to match the original children. + // We still do this again in `current()` but doing it preventively here + // reduces the work needed in some cases if `current()` is called multiple times + // and the same externally facing `PhysicalExpr` is used for both + // `with_new_children` and `update()`. + DynamicFilterUpdate::Global(new_expr) => { + DynamicFilterUpdate::Global(self.remap_children(new_expr)?) + } + DynamicFilterUpdate::Partitioned(partition_exprs) => { + DynamicFilterUpdate::Partitioned(partition_exprs) + } + }; // Load the current inner, increment generation, and store the new one let mut current = self.inner.write(); let new_generation = current.generation + 1; - *current = Inner { - generation: new_generation, - expr: new_expr, - is_complete: current.is_complete, - }; - drop(current); // Release the lock before broadcasting + current.generation = new_generation; + match update { + DynamicFilterUpdate::Global(new_expr) => current.expr = new_expr, + DynamicFilterUpdate::Partitioned(partition_exprs) => { + current.partitioned_exprs = partition_exprs + } + } + drop(current); // Release the lock before broadcasting. - // Broadcast the new state to all waiters + // Broadcast the new state to all waiters. let _ = self.state_watch.send(FilterState::InProgress { generation: new_generation, }); @@ -268,10 +486,44 @@ impl DynamicFilterPhysicalExpr { }); } + /// Get the filter expression for a specific partition. + /// + /// Semantics when per-partition filters are present: + /// - `Some(Some(expr))`: use the partition-local filter. + /// - `Some(None)`: the build partition is known empty, so return `false`. + /// - `None` (out-of-range): return `true` (fail-open) to avoid incorrect pruning if + /// partition alignment/count assumptions are violated by a source. + /// + /// Returns: + /// - `Ok(Expr)`: Dynamic filter expression to be used for the given partition + /// - `Ok(lit(false))`: Filters out everything on probe side (build side is empty for this partition) + /// - `Ok(lit(true))`: No filtering applied, returns probe data as-is (fail-open for safety) + fn current_for_partition(&self, partition: usize) -> Result> { + let guard = self.inner.read(); + if guard.partitioned_exprs.is_empty() { + let expr = Arc::clone(guard.expr()); + drop(guard); + return self.remap_children(expr); + } + match guard.partitioned_exprs.get(partition) { + Some(Some(expr)) => { + let expr = Arc::clone(expr); + drop(guard); + self.remap_children(expr) + } + Some(None) => Ok(lit(false) as Arc), + None => Ok(lit(true) as Arc), + } + } + /// Wait asynchronously for any update to this filter. /// /// This method will return when [`Self::update`] is called and the generation increases. /// It does not guarantee that the filter is complete. + /// + /// Producers (e.g.) HashJoinExec may never update the expression or mark it as completed if there are no consumers. + /// If you call this method on a dynamic filter created by such a producer and there are no consumers registered this method would wait indefinitely. + /// This should not happen under normal operation and would indicate a programming error either in your producer or in DataFusion if the producer is a built in node. pub async fn wait_update(&self) { let mut rx = self.state_watch.subscribe(); // Get the current generation @@ -283,17 +535,16 @@ impl DynamicFilterPhysicalExpr { /// Wait asynchronously until this dynamic filter is marked as complete. /// - /// This method returns immediately if the filter is already complete or if the filter - /// is not being used by any consumers. + /// This method returns immediately if the filter is already complete. /// Otherwise, it waits until [`Self::mark_complete`] is called. /// /// Unlike [`Self::wait_update`], this method guarantees that when it returns, /// the filter is fully complete with no more updates expected. - pub async fn wait_complete(self: &Arc) { - if !self.is_used() { - return; - } - + /// + /// Producers (e.g.) HashJoinExec may never update the expression or mark it as completed if there are no consumers. + /// If you call this method on a dynamic filter created by such a producer and there are no consumers registered this method would wait indefinitely. + /// This should not happen under normal operation and would indicate a programming error either in your producer or in DataFusion if the producer is a built in node. + pub async fn wait_complete(&self) { if self.inner.read().is_complete { return; } @@ -310,31 +561,65 @@ impl DynamicFilterPhysicalExpr { /// that created the filter). This is useful to avoid computing expensive filter /// expressions when no consumer will actually use them. /// - /// Note: We check the inner Arc's strong_count, not the outer Arc's count, because - /// when filters are transformed (e.g., via reassign_expr_columns during filter pushdown), - /// new outer Arc instances are created via with_new_children(), but they all share the - /// same inner `Arc>`. This is what allows filter updates to propagate to - /// consumers even after transformation. + /// # Implementation Details + /// + /// We check both Arc counts to handle two cases: + /// - Transformed filters (via `with_new_children`) share the inner Arc (inner count > 1) + /// - Direct clones (via `Arc::clone`) increment the outer count (outer count > 1) pub fn is_used(self: &Arc) -> bool { // Strong count > 1 means at least one consumer is holding a reference beyond the producer. - Arc::strong_count(&self.inner) > 1 + Arc::strong_count(self) > 1 || Arc::strong_count(&self.inner) > 1 + } + + /// Returns a unique identifier for the inner shared state. + /// + /// Useful for checking if two [`Arc`] with the same + /// underlying [`DynamicFilterPhysicalExpr`] are the same. + pub fn inner_id(&self) -> u64 { + Arc::as_ptr(&self.inner) as *const () as u64 } fn render( &self, f: &mut std::fmt::Formatter<'_>, - render_expr: impl FnOnce( + render_expr: impl Fn( Arc, &mut std::fmt::Formatter<'_>, ) -> std::fmt::Result, ) -> std::fmt::Result { - let inner = self.current().map_err(|_| std::fmt::Error)?; - let current_generation = self.current_generation(); + let guard = self.inner.read(); + let current_generation = guard.generation; + let partitioned_exprs = guard.partitioned_exprs.clone(); + drop(guard); + + if let Some(partition) = self.runtime_partition { + write!(f, "DynamicFilter(partition={partition}) [ ")?; + let current = self + .current_for_partition(partition) + .map_err(|_| std::fmt::Error)?; + render_expr(current, f)?; + return write!(f, " ]"); + } + write!(f, "DynamicFilter [ ")?; - if current_generation == 1 { + if !partitioned_exprs.is_empty() { + write!(f, "{{")?; + for (i, partition) in partitioned_exprs.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{i}: ")?; + match partition { + Some(expr) => render_expr(Arc::clone(expr), f)?, + None => write!(f, "pruned")?, + } + } + write!(f, "}}")?; + } else if current_generation == 1 { write!(f, "empty")?; } else { - render_expr(inner, f)?; + let current = self.current().map_err(|_| std::fmt::Error)?; + render_expr(current, f)?; } write!(f, " ]") @@ -362,12 +647,42 @@ impl PhysicalExpr for DynamicFilterPhysicalExpr { children: self.children.clone(), remapped_children: Some(children), inner: Arc::clone(&self.inner), + runtime_partition: self.runtime_partition, state_watch: self.state_watch.clone(), data_type: Arc::clone(&self.data_type), nullable: Arc::clone(&self.nullable), })) } + fn bind_runtime( + &self, + context: &(dyn Any + Send + Sync), + ) -> Result>> { + let partition = + if let Some(ctx) = context.downcast_ref::() { + ctx.partition() + } else if let Some(partition) = context.downcast_ref::() { + // Backward-compatible fallback for callers that pass the partition index directly. + *partition + } else { + return Ok(None); + }; + + if self.runtime_partition == Some(partition) { + return Ok(None); + } + + Ok(Some(Arc::new(Self { + children: self.children.clone(), + remapped_children: self.remapped_children.clone(), + inner: Arc::clone(&self.inner), + runtime_partition: Some(partition), + state_watch: self.state_watch.clone(), + data_type: Arc::clone(&self.data_type), + nullable: Arc::clone(&self.nullable), + }))) + } + fn data_type(&self, input_schema: &Schema) -> Result { let res = self.current()?.data_type(input_schema)?; #[cfg(test)] @@ -454,8 +769,11 @@ mod test { datatypes::{DataType, Field, Schema}, }; use datafusion_common::ScalarValue; + use datafusion_physical_expr_common::physical_expr::{ + bind_runtime_physical_expr, snapshot_generation, + }; - use super::*; + use super::{DynamicFilterRuntimeContext, *}; #[test] fn test_remap_children() { @@ -543,7 +861,9 @@ mod test { lit(43) as Arc, )); dynamic_filter - .update(Arc::clone(&new_expr) as Arc) + .update(DynamicFilterUpdate::Global( + Arc::clone(&new_expr) as Arc + )) .expect("Failed to update expression"); // Now we should be able to evaluate the new expression on both batches let result_1 = dynamic_filter_1.evaluate(&batch_1).unwrap(); @@ -573,7 +893,9 @@ mod test { // Update the current expression let new_expr = lit(100) as Arc; - dynamic_filter.update(Arc::clone(&new_expr)).unwrap(); + dynamic_filter + .update(DynamicFilterUpdate::Global(Arc::clone(&new_expr))) + .unwrap(); // Take another snapshot let snapshot = dynamic_filter.snapshot().unwrap(); assert_eq!(snapshot, Some(new_expr)); @@ -602,7 +924,9 @@ mod test { // Now change the current expression to something else. dynamic_filter - .update(lit(ScalarValue::Utf8(None)) as Arc) + .update(DynamicFilterUpdate::Global( + lit(ScalarValue::Utf8(None)) as Arc + )) .expect("Failed to update expression"); // Check that we error if we call data_type, nullable or evaluate after changing the expression. assert!( @@ -779,7 +1103,9 @@ mod test { // Update changes the underlying expression filter - .update(lit(false) as Arc) + .update(DynamicFilterUpdate::Global( + lit(false) as Arc + )) .expect("Update should succeed"); // Compute hash AFTER update @@ -845,7 +1171,9 @@ mod test { // Update the expression filter - .update(lit(false) as Arc) + .update(DynamicFilterUpdate::Global( + lit(false) as Arc + )) .expect("Update should succeed"); // Hash should STILL be the same (identity-based) @@ -860,4 +1188,330 @@ mod test { "Hash should be stable after update (identity-based)" ); } + + #[test] + fn test_update_partitioned_and_current_for_partition() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let col_a = col("a", &schema).unwrap(); + + let dynamic_filter = DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&col_a)], + lit(true) as Arc, + ); + + // Create per-partition expressions + let partition_0_expr = Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::GtEq, + lit(10) as Arc, + )) as Arc; + let partition_1_expr = Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::GtEq, + lit(20) as Arc, + )) as Arc; + + let partition_exprs = vec![ + Some(Arc::clone(&partition_0_expr)), + Some(Arc::clone(&partition_1_expr)), + ]; + + dynamic_filter + .update(DynamicFilterUpdate::Partitioned(partition_exprs)) + .unwrap(); + + // Partition 0 should get its specific filter + let p0 = dynamic_filter.current_for_partition(0).unwrap(); + assert_eq!(format!("{p0}"), format!("{partition_0_expr}")); + + // Partition 1 should get its specific filter + let p1 = dynamic_filter.current_for_partition(1).unwrap(); + assert_eq!(format!("{p1}"), format!("{partition_1_expr}")); + } + + #[test] + fn test_current_for_partition_empty_and_out_of_range() { + let dynamic_filter = + DynamicFilterPhysicalExpr::new(vec![], lit(true) as Arc); + + let partition_exprs = vec![ + Some(lit(42) as Arc), + None, // Empty partition + ]; + + dynamic_filter + .update(DynamicFilterUpdate::Partitioned(partition_exprs)) + .unwrap(); + + // Partition 1 is empty, should return lit(false) + let p1 = dynamic_filter.current_for_partition(1).unwrap(); + assert_eq!(format!("{p1}"), "false"); + + // Partition 5 is out of range, should fail-open to lit(true) + let p5 = dynamic_filter.current_for_partition(5).unwrap(); + assert_eq!(format!("{p5}"), "true"); + } + + #[test] + fn test_bind_dynamic_filters_for_partition_with_data() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let col_a = col("a", &schema).unwrap(); + + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&col_a)], + lit(true) as Arc, + )); + + let partition_0_expr = Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::GtEq, + lit(10) as Arc, + )) as Arc; + let partition_1_expr = Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::LtEq, + lit(20) as Arc, + )) as Arc; + + dynamic_filter + .update(DynamicFilterUpdate::Partitioned(vec![ + Some(Arc::clone(&partition_0_expr)), + Some(Arc::clone(&partition_1_expr)), + ])) + .unwrap(); + + // Bind within a parent expression to verify recursive tree rewrite. + let wrapper = Arc::new(BinaryExpr::new( + Arc::clone(&dynamic_filter) as Arc, + datafusion_expr::Operator::And, + lit(true) as Arc, + )) as Arc; + + let ctx = DynamicFilterRuntimeContext::for_partition(1); + let bound = bind_runtime_physical_expr(Arc::clone(&wrapper), &ctx).unwrap(); + + assert!( + format!("{bound}").contains("<="), + "Expected partition 1 expression in runtime-bound dynamic filter" + ); + assert!( + format!("{bound}").contains("partition=1"), + "Expected runtime-bound expression to retain partition context" + ); + } + + #[test] + fn test_bind_dynamic_filters_for_partition_without_partitioned_data() { + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![], + lit(42) as Arc, + )); + + let ctx = DynamicFilterRuntimeContext::for_partition(0); + let bound = bind_runtime_physical_expr( + Arc::clone(&dynamic_filter) as Arc, + &ctx, + ) + .unwrap(); + + assert!( + bound + .as_any() + .downcast_ref::() + .is_some(), + "Runtime binding should preserve dynamic filter type" + ); + assert!( + format!("{bound}").contains("partition=0"), + "Runtime binding should include partition context" + ); + } + + #[test] + fn test_runtime_bound_dynamic_filter_tracks_updates() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let col_a = col("a", &schema).unwrap(); + + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&col_a)], + lit(true) as Arc, + )); + + let initial_expr = Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::GtEq, + lit(10) as Arc, + )) as Arc; + dynamic_filter + .update(DynamicFilterUpdate::Partitioned(vec![Some(initial_expr)])) + .unwrap(); + + let ctx = DynamicFilterRuntimeContext::for_partition(0); + let bound = bind_runtime_physical_expr( + Arc::clone(&dynamic_filter) as Arc, + &ctx, + ) + .unwrap(); + let generation_before = snapshot_generation(&bound); + assert!( + format!("{bound}").contains("10"), + "Expected initial partition-bound dynamic filter to reference literal 10" + ); + + let updated_expr = Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::GtEq, + lit(20) as Arc, + )) as Arc; + dynamic_filter + .update(DynamicFilterUpdate::Partitioned(vec![Some(updated_expr)])) + .unwrap(); + + let generation_after = snapshot_generation(&bound); + assert_ne!( + generation_before, generation_after, + "Runtime-bound dynamic filter should track source generation changes" + ); + assert!( + format!("{bound}").contains("20"), + "Expected runtime-bound dynamic filter to reflect updated partition expression" + ); + } + + #[test] + fn test_runtime_bind_before_partitioned_update_tracks_new_partition_data() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let col_a = col("a", &schema).unwrap(); + + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&col_a)], + lit(true) as Arc, + )); + + // Bind before any partition-local payload exists. + let ctx = DynamicFilterRuntimeContext::for_partition(1); + let bound = bind_runtime_physical_expr( + Arc::clone(&dynamic_filter) as Arc, + &ctx, + ) + .unwrap(); + assert!( + format!("{bound}").contains("true"), + "Before partitioned update, bound dynamic filter should evaluate as global expression" + ); + + let generation_before = snapshot_generation(&bound); + + let partition_0_expr = Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::GtEq, + lit(10) as Arc, + )) as Arc; + let partition_1_expr = Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::GtEq, + lit(30) as Arc, + )) as Arc; + dynamic_filter + .update(DynamicFilterUpdate::Partitioned(vec![ + Some(partition_0_expr), + Some(partition_1_expr), + ])) + .unwrap(); + + let generation_after = snapshot_generation(&bound); + assert_ne!( + generation_before, generation_after, + "Bound dynamic filter generation should change after source updates" + ); + assert!( + format!("{bound}").contains("30"), + "Bound dynamic filter should route to partition-local payload after update" + ); + } + + #[test] + fn test_new_from_source() { + // Create a source filter + let source = Arc::new(DynamicFilterPhysicalExpr::new( + vec![], + lit(42) as Arc, + )); + + // Update and mark complete + source + .update(DynamicFilterUpdate::Global( + lit(100) as Arc + )) + .unwrap(); + source.mark_complete(); + + // Create a target filter with different children + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)])); + let col_x = col("x", &schema).unwrap(); + let target = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&col_x)], + lit(0) as Arc, + )); + + // Create new filter from source's inner state + let combined = target.new_from_source(&source).unwrap(); + + // Verify inner state is shared (same inner_id) + assert_eq!( + combined.inner_id(), + source.inner_id(), + "new_from_source should share inner state with source" + ); + + // Verify children are from target, not source + let combined_snapshot = DynamicFilterSnapshot::from(&combined); + assert_eq!( + combined_snapshot.children().len(), + 1, + "Combined filter should have target's children" + ); + assert_eq!( + format!("{:?}", combined_snapshot.children()[0]), + format!("{:?}", col_x), + "Combined filter should have target's children" + ); + + // Verify inner expression comes from source + assert_eq!( + format!("{:?}", combined_snapshot.inner_expr()), + format!("{:?}", lit(100)), + "Combined filter should have source's inner expression" + ); + } + + #[tokio::test] + async fn test_new_from_source_wait_complete_notifications() { + // Create an incomplete source and relink a second filter to it. + let source = Arc::new(DynamicFilterPhysicalExpr::new( + vec![], + lit(42) as Arc, + )); + let target = Arc::new(DynamicFilterPhysicalExpr::new( + vec![], + lit(0) as Arc, + )); + let combined = Arc::new(target.new_from_source(&source).unwrap()); + + let waiter = tokio::spawn({ + let combined = Arc::clone(&combined); + async move { + combined.wait_complete().await; + } + }); + + // Ensure waiter has subscribed before completion is signalled. + tokio::task::yield_now().await; + source.mark_complete(); + + tokio::time::timeout(std::time::Duration::from_secs(1), waiter) + .await + .expect("wait_complete should be notified by source mark_complete") + .expect("wait_complete task should not panic"); + } } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index c9e02708d6c28..ce8ef61800efb 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -46,6 +46,9 @@ pub use cast_column::CastColumnExpr; pub use column::{Column, col, with_new_schema}; pub use datafusion_expr::utils::format_state_name; pub use dynamic_filters::DynamicFilterPhysicalExpr; +pub use dynamic_filters::DynamicFilterRuntimeContext; +pub use dynamic_filters::DynamicFilterSnapshot; +pub use dynamic_filters::DynamicFilterUpdate; pub use in_list::{InListExpr, in_list}; pub use is_not_null::{IsNotNullExpr, is_not_null}; pub use is_null::{IsNullExpr, is_null}; diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs index 6348c4663fa16..1713d7208ad06 100644 --- a/datafusion/physical-optimizer/src/enforce_distribution.rs +++ b/datafusion/physical-optimizer/src/enforce_distribution.rs @@ -34,6 +34,7 @@ use crate::utils::{ use arrow::compute::SortOptions; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; +use datafusion_common::plan_err; use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::logical_plan::{Aggregate, JoinType}; @@ -49,7 +50,8 @@ use datafusion_physical_plan::aggregates::{ use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::EmissionType; use datafusion_physical_plan::joins::{ - CrossJoinExec, HashJoinExec, PartitionMode, SortMergeJoinExec, + CrossJoinExec, DynamicFilterRoutingMode, HashJoinExec, PartitionMode, + SortMergeJoinExec, }; use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -295,6 +297,7 @@ pub fn adjust_input_keys_ordering( projection, mode, null_equality, + dynamic_filter_routing_mode, .. }) = plan.as_any().downcast_ref::() { @@ -304,7 +307,7 @@ pub fn adjust_input_keys_ordering( Vec<(PhysicalExprRef, PhysicalExprRef)>, Vec, )| { - HashJoinExec::try_new( + let join = HashJoinExec::try_new( Arc::clone(left), Arc::clone(right), new_conditions.0, @@ -314,8 +317,9 @@ pub fn adjust_input_keys_ordering( projection.clone(), PartitionMode::Partitioned, *null_equality, - ) - .map(|e| Arc::new(e) as _) + )? + .with_dynamic_filter_routing_mode(*dynamic_filter_routing_mode); + Ok(Arc::new(join) as _) }; return reorder_partitioned_join_keys( requirements, @@ -618,6 +622,7 @@ pub fn reorder_join_keys_to_inputs( projection, mode, null_equality, + dynamic_filter_routing_mode, .. }) = plan_any.downcast_ref::() { @@ -635,7 +640,7 @@ pub fn reorder_join_keys_to_inputs( right_keys, } = join_keys; let new_join_on = new_join_conditions(&left_keys, &right_keys); - return Ok(Arc::new(HashJoinExec::try_new( + let join = HashJoinExec::try_new( Arc::clone(left), Arc::clone(right), new_join_on, @@ -644,7 +649,9 @@ pub fn reorder_join_keys_to_inputs( projection.clone(), PartitionMode::Partitioned, *null_equality, - )?)); + )? + .with_dynamic_filter_routing_mode(*dynamic_filter_routing_mode); + return Ok(Arc::new(join)); } } } else if let Some(SortMergeJoinExec { @@ -862,7 +869,14 @@ fn add_roundrobin_on_top( let new_plan = Arc::new(repartition) as _; - Ok(DistributionContext::new(new_plan, true, vec![input])) + Ok(DistributionContext::new( + new_plan, + DistFlags { + dist_changing: true, + repartitioned: true, + }, + vec![input], + )) } else { // Partition is not helpful, we already have desired number of partitions. Ok(input) @@ -931,7 +945,14 @@ fn add_hash_on_top( .with_preserve_order(); let plan = Arc::new(repartition) as _; - return Ok(DistributionContext::new(plan, true, vec![input])); + return Ok(DistributionContext::new( + plan, + DistFlags { + dist_changing: true, + repartitioned: true, + }, + vec![input], + )); } Ok(input) @@ -968,7 +989,14 @@ fn add_merge_on_top(input: DistributionContext) -> DistributionContext { Arc::new(CoalescePartitionsExec::new(Arc::clone(&input.plan))) as _ }; - DistributionContext::new(new_plan, true, vec![input]) + DistributionContext::new( + new_plan, + DistFlags { + dist_changing: true, + repartitioned: input.data.repartitioned, + }, + vec![input], + ) } else { input } @@ -1032,7 +1060,7 @@ pub fn replace_order_preserving_variants( .children .into_iter() .map(|child| { - if child.data { + if child.data.dist_changing { replace_order_preserving_variants(child) } else { Ok(child) @@ -1369,7 +1397,7 @@ pub fn ensure_distribution( .ordering_satisfy_requirement(sort_req.clone())?; if (!ordering_satisfied || !order_preserving_variants_desirable) - && child.data + && child.data.dist_changing { child = replace_order_preserving_variants(child)?; // If ordering requirements were satisfied before repartitioning, @@ -1387,7 +1415,7 @@ pub fn ensure_distribution( } } // Stop tracking distribution changing operators - child.data = false; + child.data.dist_changing = false; } else { // no ordering requirement match requirement { @@ -1446,21 +1474,96 @@ pub fn ensure_distribution( plan.with_new_children(children_plans)? }; + /// Helper to describe partitioning scheme for error messages + fn partitioning_scheme_name(is_repartitioned: bool) -> &'static str { + if is_repartitioned { + "hash-repartitioned" + } else { + "file-grouped" + } + } + + // For partitioned hash joins, decide dynamic filter routing mode. + // + // Dynamic filtering requires matching partitioning schemes on both sides: + // - PartitionIndex: Both sides use file-grouped partitioning (value-based). + // Partition i on build corresponds to partition i on probe by partition value. + // - CaseHash: Both sides use hash repartitioning (hash-based). + // Uses CASE expression with hash(row) % N to route to correct partition filter. + // + // NOTE: If partitioning schemes are misaligned (one file-grouped, one hash-repartitioned), + // the partitioned join itself is incorrect. + // Partition assignments don't match: + // - File-grouped: partition 0 = all rows where column="A" (value-based) + // - Hash-repartitioned: partition 0 = all rows where hash(column) % N == 0 (hash-based) + // These are incompatible, so the join will miss matching rows. + plan = if let Some(hash_join) = plan.as_any().downcast_ref::() + && matches!(hash_join.mode, PartitionMode::Partitioned) + { + let routing_mode = match ( + children[0].data.repartitioned, + children[1].data.repartitioned, + ) { + (false, false) => DynamicFilterRoutingMode::PartitionIndex, + (true, true) => DynamicFilterRoutingMode::CaseHash, + _ => { + return plan_err!( + "Partitioned hash join has incompatible partitioning schemes: \ + left side is {}, right side is {}.", + partitioning_scheme_name(children[0].data.repartitioned), + partitioning_scheme_name(children[1].data.repartitioned) + ); + } + }; + + if routing_mode != hash_join.dynamic_filter_routing_mode { + let rebuilt = HashJoinExec::try_new( + Arc::clone(hash_join.left()), + Arc::clone(hash_join.right()), + hash_join + .on() + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .collect(), + hash_join.filter().cloned(), + hash_join.join_type(), + hash_join.projection.clone(), + PartitionMode::Partitioned, + hash_join.null_equality(), + )? + .with_dynamic_filter_routing_mode(routing_mode); + Arc::new(rebuilt) + } else { + plan + } + } else { + plan + }; + Ok(Transformed::yes(DistributionContext::new( plan, data, children, ))) } -/// Keeps track of distribution changing operators (like `RepartitionExec`, -/// `SortPreservingMergeExec`, `CoalescePartitionsExec`) and their ancestors. -/// Using this information, we can optimize distribution of the plan if/when -/// necessary. -pub type DistributionContext = PlanContext; +/// State propagated during the bottom-up pass in [`ensure_distribution`]. +#[derive(Clone, Copy, Debug, Default)] +pub struct DistFlags { + /// Whether a distribution-changing operator (`RepartitionExec`, `SortPreservingMergeExec`, + /// `CoalescePartitionsExec`) exists in the subtree. + pub dist_changing: bool, + /// Whether the output partitioning originates from a [`RepartitionExec`]. + /// Used by partitioned hash joins to choose the dynamic filter routing mode. + pub repartitioned: bool, +} + +pub type DistributionContext = PlanContext; fn update_children(mut dist_context: DistributionContext) -> Result { for child_context in dist_context.children.iter_mut() { let child_plan_any = child_context.plan.as_any(); - child_context.data = + + // Track distribution-changing operators for order-preservation optimization. + child_context.data.dist_changing = if let Some(repartition) = child_plan_any.downcast_ref::() { !matches!( repartition.partitioning(), @@ -1470,23 +1573,46 @@ fn update_children(mut dist_context: DistributionContext) -> Result() || child_plan_any.is::() || child_context.plan.children().is_empty() - || child_context.children[0].data + || child_context.children[0].data.dist_changing || child_context .plan .required_input_distribution() .iter() .zip(child_context.children.iter()) .any(|(required_dist, child_context)| { - child_context.data + child_context.data.dist_changing && matches!( required_dist, Distribution::UnspecifiedDistribution ) }) - } + }; + + // Track whether partitioning originates from a RepartitionExec, following the + // partition-determining path through the context tree. + child_context.data.repartitioned = + if let Some(repartition) = child_plan_any.downcast_ref::() { + !matches!( + repartition.partitioning(), + Partitioning::UnknownPartitioning(_) + ) + } else if child_context.plan.children().is_empty() { + false + } else if let Some(hash_join) = child_plan_any.downcast_ref::() + && matches!(hash_join.mode, PartitionMode::CollectLeft) + { + // CollectLeft output partitioning is inherited from the probe (right) side. + child_context + .children + .get(1) + .map(|c| c.data.repartitioned) + .unwrap_or(false) + } else { + child_context.children.iter().any(|c| c.data.repartitioned) + }; } - dist_context.data = false; + dist_context.data = DistFlags::default(); Ok(dist_context) } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 06f12a90195d2..e8e5f194db85b 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -869,6 +869,34 @@ impl AggregateExec { &self.input_order_mode } + /// Returns the dynamic filter expression for this aggregate, if set. + pub fn dynamic_filter(&self) -> Option<&Arc> { + self.dynamic_filter.as_ref().map(|df| &df.filter) + } + + /// Replace the dynamic filter expression, recomputing any internal state + /// which may depend on the previous dynamic filter. + /// + /// This is a no-op if the aggregate does not support dynamic filtering. + /// + /// If dynamic filtering is supported, this method returns an error if the filter's + /// children reference invalid columns in the aggregate's input schema. + pub fn with_dynamic_filter( + mut self, + filter: Arc, + ) -> Result { + if let Some(supported_accumulators_info) = self.supported_accumulators_info() { + for child in filter.children() { + child.data_type(&self.input_schema)?; + } + self.dynamic_filter = Some(Arc::new(AggrDynFilter { + filter, + supported_accumulators_info, + })); + } + Ok(self) + } + fn statistics_inner(&self, child_statistics: &Statistics) -> Result { // TODO stats: group expressions: // - once expressions will be able to compute their own stats, use it here @@ -949,27 +977,40 @@ impl AggregateExec { /// - If yes, init one inside `AggregateExec`'s `dynamic_filter` field. /// - If not supported, `self.dynamic_filter` should be kept `None` fn init_dynamic_filter(&mut self) { - if (!self.group_by.is_empty()) || (!matches!(self.mode, AggregateMode::Partial)) { - debug_assert!( - self.dynamic_filter.is_none(), - "The current operator node does not support dynamic filter" - ); - return; - } - // Already initialized. if self.dynamic_filter.is_some() { return; } - // Collect supported accumulators - // It is assumed the order of aggregate expressions are not changed from `AggregateExec` - // to `AggregateStream` + if let Some(supported_accumulators_info) = self.supported_accumulators_info() { + // Collect column references for the dynamic filter expression. + let all_cols: Vec> = supported_accumulators_info + .iter() + .map(|info| Arc::clone(&self.aggr_expr[info.aggr_index].expressions()[0])) + .collect(); + + self.dynamic_filter = Some(Arc::new(AggrDynFilter { + filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))), + supported_accumulators_info, + })); + } + } + + /// Returns the supported accumulator info if this aggregate supports + /// dynamic filtering, or `None` otherwise. + /// + /// Dynamic filtering requires: + /// - `Partial` aggregation mode with no group-by expressions + /// - All aggregate functions are `min` or `max` with a single column arg + fn supported_accumulators_info(&self) -> Option> { + if !self.group_by.is_empty() || !matches!(self.mode, AggregateMode::Partial) { + return None; + } + + // Collect supported accumulators. + // It is assumed the order of aggregate expressions are not changed + // from `AggregateExec` to `AggregateStream`. let mut aggr_dyn_filters = Vec::new(); - // All column references in the dynamic filter, used when initializing the dynamic - // filter, and it's used to decide if this dynamic filter is able to get push - // through certain node during optimization. - let mut all_cols: Vec> = Vec::new(); for (i, aggr_expr) in self.aggr_expr.iter().enumerate() { // 1. Only `min` or `max` aggregate function let fun_name = aggr_expr.fun().name(); @@ -987,7 +1028,6 @@ impl AggregateExec { if let [arg] = aggr_expr.expressions().as_slice() && arg.as_any().is::() { - all_cols.push(Arc::clone(arg)); aggr_dyn_filters.push(PerAccumulatorDynFilter { aggr_type, aggr_index: i, @@ -996,11 +1036,10 @@ impl AggregateExec { } } - if !aggr_dyn_filters.is_empty() { - self.dynamic_filter = Some(Arc::new(AggrDynFilter { - filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))), - supported_accumulators_info: aggr_dyn_filters, - })) + if aggr_dyn_filters.is_empty() { + None + } else { + Some(aggr_dyn_filters) } } @@ -1830,6 +1869,7 @@ mod tests { use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf}; use datafusion_functions_aggregate::median::median_udaf; + use datafusion_functions_aggregate::min_max::min_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::PhysicalSortExpr; @@ -3328,13 +3368,10 @@ mod tests { // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count). let aggregates: Vec> = vec![ Arc::new( - AggregateExprBuilder::new( - datafusion_functions_aggregate::min_max::min_udaf(), - vec![col("b", &schema)?], - ) - .schema(Arc::clone(&schema)) - .alias("MIN(b)") - .build()?, + AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("MIN(b)") + .build()?, ), Arc::new( AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) @@ -3473,13 +3510,10 @@ mod tests { // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count). let aggregates: Vec> = vec![ Arc::new( - AggregateExprBuilder::new( - datafusion_functions_aggregate::min_max::min_udaf(), - vec![col("b", &schema)?], - ) - .schema(Arc::clone(&schema)) - .alias("MIN(b)") - .build()?, + AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("MIN(b)") + .build()?, ), Arc::new( AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index a55d70ca6fb27..e9cb33833b0a9 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -29,7 +29,7 @@ use datafusion_common::{Result, ScalarValue, internal_datafusion_err, internal_e use datafusion_execution::TaskContext; use datafusion_expr::Operator; use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_expr::expressions::{BinaryExpr, lit}; +use datafusion_physical_expr::expressions::{BinaryExpr, DynamicFilterUpdate, lit}; use futures::stream::BoxStream; use std::borrow::Cow; use std::cmp::Ordering; @@ -188,7 +188,9 @@ impl AggregateStreamInner { // Step 2: Sync the dynamic filter physical expression with reader let predicate = self.build_dynamic_filter_from_accumulator_bounds()?; - filter_state.filter.update(predicate)?; + filter_state + .filter + .update(DynamicFilterUpdate::Global(predicate))?; Ok(()) } diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 91fc1ee4436ee..128db69d54c63 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -92,6 +92,15 @@ use super::partitioned_hash_eval::SeededRandomState; pub(crate) const HASH_JOIN_SEED: SeededRandomState = SeededRandomState::with_seeds('J' as u64, 'O' as u64, 'I' as u64, 'N' as u64); +/// Routing mode for partitioned dynamic filters in hash joins. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum DynamicFilterRoutingMode { + /// Route probe rows by hash (CASE expression). + CaseHash, + /// Route by partition index (`i -> i`). + PartitionIndex, +} + /// HashTable and input data for the left (build side) of a join pub(super) struct JoinLeftData { /// The hash table with indices into `batch` @@ -340,6 +349,8 @@ pub struct HashJoinExec { random_state: SeededRandomState, /// Partitioning mode to use pub mode: PartitionMode, + /// Optimizer-selected dynamic filter routing mode. + pub dynamic_filter_routing_mode: DynamicFilterRoutingMode, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// The projection indices of the columns in the output schema of join @@ -377,6 +388,10 @@ impl fmt::Debug for HashJoinExec { .field("left_fut", &self.left_fut) .field("random_state", &self.random_state) .field("mode", &self.mode) + .field( + "dynamic_filter_routing_mode", + &self.dynamic_filter_routing_mode, + ) .field("metrics", &self.metrics) .field("projection", &self.projection) .field("column_indices", &self.column_indices) @@ -450,6 +465,7 @@ impl HashJoinExec { left_fut: Default::default(), random_state, mode: partition_mode, + dynamic_filter_routing_mode: DynamicFilterRoutingMode::CaseHash, metrics: ExecutionPlanMetricsSet::new(), projection, column_indices, @@ -503,20 +519,58 @@ impl HashJoinExec { &self.mode } + /// Returns the optimizer-selected dynamic filter routing mode. + pub fn dynamic_filter_routing_mode(&self) -> DynamicFilterRoutingMode { + self.dynamic_filter_routing_mode + } + + /// Returns a new [`HashJoinExec`] with the given dynamic filter routing mode. + pub fn with_dynamic_filter_routing_mode( + mut self, + mode: DynamicFilterRoutingMode, + ) -> Self { + self.dynamic_filter_routing_mode = mode; + self + } + /// Get null_equality pub fn null_equality(&self) -> NullEquality { self.null_equality } /// Get the dynamic filter expression for testing purposes. - /// Returns `None` if no dynamic filter has been set. + /// Returns the dynamic filter expression for this hash join, if set. + pub fn dynamic_filter(&self) -> Option<&Arc> { + self.dynamic_filter.as_ref().map(|df| &df.filter) + } + + fn should_use_partition_index(&self) -> bool { + matches!(self.mode, PartitionMode::Partitioned) + && self.dynamic_filter_routing_mode + == DynamicFilterRoutingMode::PartitionIndex + } + + /// Set the dynamic filter on this hash join. + /// + /// Resets any internal state that depends on any previous dynamic filter. /// - /// This method is intended for testing only and should not be used in production code. - #[doc(hidden)] - pub fn dynamic_filter_for_test(&self) -> Option> { - self.dynamic_filter - .as_ref() - .map(|df| Arc::clone(&df.filter)) + /// Validates that the filter's children reference valid columns in + /// the probe (right) side's schema. + pub fn with_dynamic_filter( + mut self, + filter: Arc, + ) -> Result { + let probe_schema = self.right.schema(); + for child in filter.children() { + child.data_type(&probe_schema)?; + } + self.dynamic_filter = Some(HashJoinExecDynamicFilter { + filter, + // Initialize with an empty accumulator which will be lazily populated + // during execution. + build_accumulator: OnceLock::new(), + }); + Ok(self) } /// Calculate order preservation flags for this hash join. @@ -556,7 +610,7 @@ impl HashJoinExec { }, None => None, }; - Self::try_new( + let new_join = Self::try_new( Arc::clone(&self.left), Arc::clone(&self.right), self.on.clone(), @@ -565,7 +619,9 @@ impl HashJoinExec { projection, self.mode, self.null_equality, - ) + )? + .with_dynamic_filter_routing_mode(self.dynamic_filter_routing_mode); + Ok(new_join) } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -688,7 +744,8 @@ impl HashJoinExec { ), partition_mode, self.null_equality(), - )?; + )? + .with_dynamic_filter_routing_mode(self.dynamic_filter_routing_mode); // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again if matches!( self.join_type(), @@ -866,6 +923,7 @@ impl ExecutionPlan for HashJoinExec { left_fut: Arc::clone(&self.left_fut), random_state: self.random_state.clone(), mode: self.mode, + dynamic_filter_routing_mode: self.dynamic_filter_routing_mode, metrics: ExecutionPlanMetricsSet::new(), projection: self.projection.clone(), column_indices: self.column_indices.clone(), @@ -896,6 +954,7 @@ impl ExecutionPlan for HashJoinExec { left_fut: Arc::new(OnceAsync::default()), random_state: self.random_state.clone(), mode: self.mode, + dynamic_filter_routing_mode: self.dynamic_filter_routing_mode, metrics: ExecutionPlanMetricsSet::new(), projection: self.projection.clone(), column_indices: self.column_indices.clone(), @@ -1018,6 +1077,11 @@ impl ExecutionPlan for HashJoinExec { // Initialize build_accumulator lazily with runtime partition counts (only if enabled) // Use RepartitionExec's random state (seeds: 0,0,0,0) for partition routing let repartition_random_state = REPARTITION_RANDOM_STATE; + let dynamic_filter_routing = if self.should_use_partition_index() { + DynamicFilterRoutingMode::PartitionIndex + } else { + DynamicFilterRoutingMode::CaseHash + }; let build_accumulator = enable_dynamic_filter_pushdown .then(|| { self.dynamic_filter.as_ref().map(|df| { @@ -1035,6 +1099,7 @@ impl ExecutionPlan for HashJoinExec { filter, on_right, repartition_random_state, + dynamic_filter_routing, )) }))) }) @@ -1134,7 +1199,7 @@ impl ExecutionPlan for HashJoinExec { &schema, self.filter(), )? { - Ok(Some(Arc::new(HashJoinExec::try_new( + let new_join = HashJoinExec::try_new( Arc::new(projected_left_child), Arc::new(projected_right_child), join_on, @@ -1144,7 +1209,9 @@ impl ExecutionPlan for HashJoinExec { None, *self.partition_mode(), self.null_equality, - )?))) + )? + .with_dynamic_filter_routing_mode(self.dynamic_filter_routing_mode); + Ok(Some(Arc::new(new_join))) } else { try_embed_projection(projection, self) } @@ -1232,6 +1299,7 @@ impl ExecutionPlan for HashJoinExec { left_fut: Arc::clone(&self.left_fut), random_state: self.random_state.clone(), mode: self.mode, + dynamic_filter_routing_mode: self.dynamic_filter_routing_mode, metrics: ExecutionPlanMetricsSet::new(), projection: self.projection.clone(), column_indices: self.column_indices.clone(), @@ -1617,7 +1685,6 @@ mod tests { let schema = batch.schema(); TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() } - fn join( left: Arc, right: Arc, @@ -4635,11 +4702,6 @@ mod tests { let dynamic_filter = HashJoinExec::create_dynamic_filter(&on); let dynamic_filter_clone = Arc::clone(&dynamic_filter); - // Simulate a consumer by creating a transformed copy (what happens during filter pushdown) - let _consumer = Arc::clone(&dynamic_filter) - .with_new_children(vec![]) - .unwrap(); - // Create HashJoinExec with the dynamic filter let mut join = HashJoinExec::try_new( left, @@ -4688,11 +4750,6 @@ mod tests { let dynamic_filter = HashJoinExec::create_dynamic_filter(&on); let dynamic_filter_clone = Arc::clone(&dynamic_filter); - // Simulate a consumer by creating a transformed copy (what happens during filter pushdown) - let _consumer = Arc::clone(&dynamic_filter) - .with_new_children(vec![]) - .unwrap(); - // Create HashJoinExec with the dynamic filter let mut join = HashJoinExec::try_new( left, diff --git a/datafusion/physical-plan/src/joins/hash_join/mod.rs b/datafusion/physical-plan/src/joins/hash_join/mod.rs index 8592e1d968535..2f04fca8a3fc9 100644 --- a/datafusion/physical-plan/src/joins/hash_join/mod.rs +++ b/datafusion/physical-plan/src/joins/hash_join/mod.rs @@ -17,7 +17,7 @@ //! [`HashJoinExec`] Partitioned Hash Join Operator -pub use exec::HashJoinExec; +pub use exec::{DynamicFilterRoutingMode, HashJoinExec}; pub use partitioned_hash_eval::{HashExpr, HashTableLookupExpr, SeededRandomState}; mod exec; diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs index 7d34ce9acbd57..aae2f0ca2035c 100644 --- a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use crate::ExecutionPlan; use crate::ExecutionPlanProperties; use crate::joins::PartitionMode; -use crate::joins::hash_join::exec::HASH_JOIN_SEED; +use crate::joins::hash_join::exec::{DynamicFilterRoutingMode, HASH_JOIN_SEED}; use crate::joins::hash_join::inlist_builder::build_struct_fields; use crate::joins::hash_join::partitioned_hash_eval::{ HashExpr, HashTableLookupExpr, SeededRandomState, @@ -37,7 +37,7 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Operator; use datafusion_functions::core::r#struct as struct_func; use datafusion_physical_expr::expressions::{ - BinaryExpr, CaseExpr, DynamicFilterPhysicalExpr, InListExpr, lit, + BinaryExpr, CaseExpr, DynamicFilterPhysicalExpr, DynamicFilterUpdate, InListExpr, lit, }; use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef, ScalarFunctionExpr}; @@ -87,7 +87,7 @@ impl PartitionBounds { /// Supports both single-column and multi-column joins using struct expressions. fn create_membership_predicate( on_right: &[PhysicalExprRef], - pushdown: PushdownStrategy, + pushdown: &PushdownStrategy, random_state: &SeededRandomState, schema: &Schema, ) -> Result>> { @@ -123,7 +123,7 @@ fn create_membership_predicate( // Use in_list_from_array() helper to create InList with static_filter optimization (hash-based lookup) Ok(Some(Arc::new(InListExpr::try_new_from_array( expr, - in_list_array, + Arc::clone(in_list_array), false, )?))) } @@ -137,7 +137,7 @@ fn create_membership_predicate( Ok(Some(Arc::new(HashTableLookupExpr::new( lookup_hash_expr, - hash_map, + Arc::clone(hash_map), "hash_lookup".to_string(), )) as Arc)) } @@ -224,6 +224,8 @@ pub(crate) struct SharedBuildAccumulator { /// Build-side data protected by a single mutex to avoid ordering concerns inner: Mutex, barrier: Barrier, + /// How to route partitioned dynamic filters. + routing: DynamicFilterRoutingMode, /// Dynamic filter for pushdown to probe side dynamic_filter: Arc, /// Right side join expressions needed for creating filter expressions @@ -277,6 +279,41 @@ enum AccumulatedBuildData { } impl SharedBuildAccumulator { + fn combine_bounds_and_membership( + membership_expr: Option>, + bounds_expr: Option>, + ) -> Option> { + match (membership_expr, bounds_expr) { + (Some(membership), Some(bounds)) => { + Some(Arc::new(BinaryExpr::new(bounds, Operator::And, membership)) + as Arc) + } + (Some(membership), None) => Some(membership), + (None, Some(bounds)) => Some(bounds), + (None, None) => None, + } + } + + fn build_partition_filter_expr( + &self, + pushdown: &PushdownStrategy, + bounds: &PartitionBounds, + ) -> Result>> { + let membership_expr = create_membership_predicate( + &self.on_right, + pushdown, + &HASH_JOIN_SEED, + self.probe_schema.as_ref(), + )?; + + let bounds_expr = create_bounds_predicate(&self.on_right, bounds); + + Ok(Self::combine_bounds_and_membership( + membership_expr, + bounds_expr, + )) + } + /// Creates a new SharedBuildAccumulator configured for the given partition mode /// /// This method calculates how many times `collect_build_side` will be called based on the @@ -309,6 +346,7 @@ impl SharedBuildAccumulator { dynamic_filter: Arc, on_right: Vec, repartition_random_state: SeededRandomState, + routing: DynamicFilterRoutingMode, ) -> Self { // Troubleshooting: If partition counts are incorrect, verify this logic matches // the actual execution pattern in collect_build_side() @@ -345,6 +383,7 @@ impl SharedBuildAccumulator { Self { inner: Mutex::new(mode_data), barrier: Barrier::new(expected_calls), + routing, dynamic_filter, on_right, repartition_random_state, @@ -409,20 +448,6 @@ impl SharedBuildAccumulator { // CollectLeft: Simple conjunction of bounds and membership check AccumulatedBuildData::CollectLeft { data } => { if let Some(partition_data) = data { - // Create membership predicate (InList for small build sides, hash lookup otherwise) - let membership_expr = create_membership_predicate( - &self.on_right, - partition_data.pushdown.clone(), - &HASH_JOIN_SEED, - self.probe_schema.as_ref(), - )?; - - // Create bounds check expression (if bounds available) - let bounds_expr = create_bounds_predicate( - &self.on_right, - &partition_data.bounds, - ); - // Combine membership and bounds expressions for multi-layer optimization: // - Bounds (min/max): Enable statistics-based pruning (Parquet row group/file skipping) // - Membership (InList/hash lookup): Enables: @@ -430,37 +455,12 @@ impl SharedBuildAccumulator { // * Bloom filter utilization (if present in Parquet files) // * Better pruning for data types where min/max isn't effective (e.g., UUIDs) // Together, they provide complementary benefits and maximize data skipping. - // Only update the filter if we have something to push down - if let Some(filter_expr) = match (membership_expr, bounds_expr) { - (Some(membership), Some(bounds)) => { - // Both available: combine with AND - Some(Arc::new(BinaryExpr::new( - bounds, - Operator::And, - membership, - )) - as Arc) - } - (Some(membership), None) => { - // Membership available but no bounds - // This is reachable when we have data but bounds aren't available - // (e.g., unsupported data types or no columns with bounds) - Some(membership) - } - (None, Some(bounds)) => { - // Bounds available but no membership. - // This should be unreachable in practice: we can always push down a reference - // to the hash table. - // But it seems safer to handle it defensively. - Some(bounds) - } - (None, None) => { - // No filter available (e.g., empty build side) - // Don't update the filter, but continue to mark complete - None - } - } { - self.dynamic_filter.update(filter_expr)?; + if let Some(filter_expr) = self.build_partition_filter_expr( + &partition_data.pushdown, + &partition_data.bounds, + )? { + self.dynamic_filter + .update(DynamicFilterUpdate::Global(filter_expr))?; } } } @@ -471,119 +471,121 @@ impl SharedBuildAccumulator { partitions.iter().filter_map(|p| p.as_ref()).collect(); if !partition_data.is_empty() { - // Build a CASE expression that combines range checks AND membership checks - // CASE (hash_repartition(join_keys) % num_partitions) - // WHEN 0 THEN (col >= min_0 AND col <= max_0 AND ...) AND membership_check_0 - // WHEN 1 THEN (col >= min_1 AND col <= max_1 AND ...) AND membership_check_1 - // ... - // ELSE false - // END - - let num_partitions = partition_data.len(); - - // Create base expression: hash_repartition(join_keys) % num_partitions - let routing_hash_expr = Arc::new(HashExpr::new( - self.on_right.clone(), - self.repartition_random_state.clone(), - "hash_repartition".to_string(), - )) - as Arc; - - let modulo_expr = Arc::new(BinaryExpr::new( - routing_hash_expr, - Operator::Modulo, - lit(ScalarValue::UInt64(Some(num_partitions as u64))), - )) - as Arc; - - // Create WHEN branches for each partition - let when_then_branches: Vec<( - Arc, - Arc, - )> = partitions - .iter() - .enumerate() - .filter_map(|(partition_id, partition_opt)| { - partition_opt.as_ref().and_then(|partition| { - // Skip empty partitions - they would always return false anyway - match &partition.pushdown { - PushdownStrategy::Empty => None, - _ => Some((partition_id, partition)), - } - }) - }) - .map(|(partition_id, partition)| -> Result<_> { - // WHEN partition_id - let when_expr = - lit(ScalarValue::UInt64(Some(partition_id as u64))); - - // THEN: Combine bounds check AND membership predicate - - // 1. Create membership predicate (InList for small build sides, hash lookup otherwise) - let membership_expr = create_membership_predicate( - &self.on_right, - partition.pushdown.clone(), - &HASH_JOIN_SEED, - self.probe_schema.as_ref(), - )?; + match self.routing { + DynamicFilterRoutingMode::CaseHash => { + // Build a CASE expression that combines range checks AND membership checks + // CASE (hash_repartition(join_keys) % num_partitions) + // WHEN 0 THEN (col >= min_0 AND col <= max_0 AND ...) AND membership_check_0 + // WHEN 1 THEN (col >= min_1 AND col <= max_1 AND ...) AND membership_check_1 + // ... + // ELSE false + // END + + let num_partitions = partition_data.len(); + + // Create base expression: hash_repartition(join_keys) % num_partitions + let routing_hash_expr = Arc::new(HashExpr::new( + self.on_right.clone(), + self.repartition_random_state.clone(), + "hash_repartition".to_string(), + )) + as Arc; - // 2. Create bounds check expression for this partition (if bounds available) - let bounds_expr = create_bounds_predicate( - &self.on_right, - &partition.bounds, - ); - - // 3. Combine membership and bounds expressions - let then_expr = match (membership_expr, bounds_expr) { - (Some(membership), Some(bounds)) => { - // Both available: combine with AND - Arc::new(BinaryExpr::new( - bounds, - Operator::And, - membership, - )) - as Arc - } - (Some(membership), None) => { - // Membership available but no bounds (e.g., unsupported data types) - membership - } - (None, Some(bounds)) => { - // Bounds available but no membership. - // This should be unreachable in practice: we can always push down a reference - // to the hash table. - // But it seems safer to handle it defensively. - bounds - } - (None, None) => { - // No filter for this partition - should not happen due to filter_map above - // but handle defensively by returning a "true" literal - lit(true) - } + let modulo_expr = Arc::new(BinaryExpr::new( + routing_hash_expr, + Operator::Modulo, + lit(ScalarValue::UInt64(Some(num_partitions as u64))), + )) + as Arc; + + // Create WHEN branches for each partition + let when_then_branches: Vec<( + Arc, + Arc, + )> = partitions + .iter() + .enumerate() + .filter_map(|(partition_id, partition_opt)| { + partition_opt.as_ref().and_then(|partition| { + // Skip empty partitions - they would always return false anyway + match &partition.pushdown { + PushdownStrategy::Empty => None, + _ => Some((partition_id, partition)), + } + }) + }) + .filter_map(|(partition_id, partition)| { + match self.build_partition_filter_expr( + &partition.pushdown, + &partition.bounds, + ) { + Ok(Some(filter_expr)) => Some(Ok(( + lit(ScalarValue::UInt64(Some( + partition_id as u64, + ))), + filter_expr, + ))), + Ok(None) => None, + Err(e) => Some(Err(e)), + } + }) + .collect::>>()?; + + // Optimize for single partition: skip CASE expression entirely + let filter_expr = if when_then_branches.is_empty() { + // All partitions are empty: no rows can match + lit(false) + } else if when_then_branches.len() == 1 { + // Single partition: just use the condition directly + // since hash % 1 == 0 always, the WHEN 0 branch will always match + Arc::clone(&when_then_branches[0].1) + } else { + // Multiple partitions: create CASE expression + Arc::new(CaseExpr::try_new( + Some(modulo_expr), + when_then_branches, + Some(lit(false)), // ELSE false + )?) + as Arc }; - Ok((when_expr, then_expr)) - }) - .collect::>>()?; - - // Optimize for single partition: skip CASE expression entirely - let filter_expr = if when_then_branches.is_empty() { - // All partitions are empty: no rows can match - lit(false) - } else if when_then_branches.len() == 1 { - // Single partition: just use the condition directly - // since hash % 1 == 0 always, the WHEN 0 branch will always match - Arc::clone(&when_then_branches[0].1) - } else { - // Multiple partitions: create CASE expression - Arc::new(CaseExpr::try_new( - Some(modulo_expr), - when_then_branches, - Some(lit(false)), // ELSE false - )?) as Arc - }; - - self.dynamic_filter.update(filter_expr)?; + self.dynamic_filter + .update(DynamicFilterUpdate::Global(filter_expr))?; + } + DynamicFilterRoutingMode::PartitionIndex => { + let mut partition_filters: Vec< + Option>, + > = vec![None; partitions.len()]; + + for (partition_id, partition) in + partitions.iter().enumerate().filter_map(|(i, p)| { + p.as_ref().map(|partition| (i, partition)) + }) + { + if matches!( + partition.pushdown, + PushdownStrategy::Empty + ) { + continue; + } + + let filter_expr = self + .build_partition_filter_expr( + &partition.pushdown, + &partition.bounds, + )? + // Defensive fallback: if no partition-local + // predicate can be built, keep this partition + // fail-open. + .unwrap_or_else(|| lit(true)); + partition_filters[partition_id] = Some(filter_expr); + } + + self.dynamic_filter.update( + DynamicFilterUpdate::Partitioned(partition_filters), + )?; + } + } } } } @@ -599,3 +601,243 @@ impl fmt::Debug for SharedBuildAccumulator { write!(f, "SharedBuildAccumulator") } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::TestMemoryExec; + use arrow::array::{ArrayRef, Int32Array}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_physical_expr::expressions::{ + Column, DynamicFilterRuntimeContext, Literal, + }; + use datafusion_physical_expr_common::physical_expr::bind_runtime_physical_expr; + + fn make_exec(schema: Arc, partitions: usize) -> Arc { + let batch = RecordBatch::new_empty(Arc::clone(&schema)); + let mut partitioned_batches = Vec::with_capacity(partitions); + for _ in 0..partitions { + partitioned_batches.push(vec![batch.clone()]); + } + TestMemoryExec::try_new_exec(&partitioned_batches, schema, None).unwrap() + } + + fn make_bounds(min: i32, max: i32) -> PartitionBounds { + PartitionBounds::new(vec![ColumnBounds::new( + ScalarValue::Int32(Some(min)), + ScalarValue::Int32(Some(max)), + )]) + } + + fn make_in_list(values: &[i32]) -> PushdownStrategy { + let array: ArrayRef = Arc::new(Int32Array::from(values.to_vec())); + PushdownStrategy::InList(array) + } + + fn contains_hash_expr(expr: &Arc) -> bool { + if expr.as_any().downcast_ref::().is_some() { + return true; + } + expr.children() + .iter() + .any(|child| contains_hash_expr(child)) + } + + fn contains_case_expr(expr: &Arc) -> bool { + if expr.as_any().downcast_ref::().is_some() { + return true; + } + expr.children() + .iter() + .any(|child| contains_case_expr(child)) + } + + fn is_literal_true(expr: &Arc) -> bool { + if let Some(literal) = expr.as_any().downcast_ref::() { + matches!(literal.value(), ScalarValue::Boolean(Some(true))) + } else { + false + } + } + + #[tokio::test] + async fn partitioned_dynamic_filter_uses_hash_routing_when_enabled() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])); + let left = make_exec(Arc::clone(&schema), 2); + let right = make_exec(Arc::clone(&schema), 2); + let on_right: Vec = + vec![Arc::new(Column::new_with_schema("b", &schema)?)]; + let dynamic_filter = + Arc::new(DynamicFilterPhysicalExpr::new(on_right.clone(), lit(true))); + let accumulator = SharedBuildAccumulator::new_from_partition_mode( + PartitionMode::Partitioned, + left.as_ref(), + right.as_ref(), + Arc::clone(&dynamic_filter), + on_right, + SeededRandomState::with_seeds(0, 0, 0, 0), + DynamicFilterRoutingMode::CaseHash, + ); + + tokio::try_join!( + accumulator.report_build_data(PartitionBuildData::Partitioned { + partition_id: 0, + pushdown: make_in_list(&[1, 2]), + bounds: make_bounds(1, 2), + }), + accumulator.report_build_data(PartitionBuildData::Partitioned { + partition_id: 1, + pushdown: make_in_list(&[10, 11]), + bounds: make_bounds(10, 11), + }) + )?; + + let expr = dynamic_filter.current()?; + assert!( + contains_hash_expr(&expr), + "expected hash routing expression" + ); + Ok(()) + } + + #[tokio::test] + async fn collect_left_dynamic_filter_never_uses_hash_routing() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])); + let left = make_exec(Arc::clone(&schema), 1); + let right = make_exec(Arc::clone(&schema), 2); + let on_right: Vec = + vec![Arc::new(Column::new_with_schema("b", &schema)?)]; + let dynamic_filter = + Arc::new(DynamicFilterPhysicalExpr::new(on_right.clone(), lit(true))); + let accumulator = SharedBuildAccumulator::new_from_partition_mode( + PartitionMode::CollectLeft, + left.as_ref(), + right.as_ref(), + Arc::clone(&dynamic_filter), + on_right, + SeededRandomState::with_seeds(0, 0, 0, 0), + DynamicFilterRoutingMode::CaseHash, + ); + + tokio::try_join!( + accumulator.report_build_data(PartitionBuildData::CollectLeft { + pushdown: make_in_list(&[1, 2]), + bounds: make_bounds(1, 2), + }), + accumulator.report_build_data(PartitionBuildData::CollectLeft { + pushdown: make_in_list(&[1, 2]), + bounds: make_bounds(1, 2), + }) + )?; + + let expr = dynamic_filter.current()?; + assert!( + !contains_hash_expr(&expr), + "collect-left should not introduce hash routing" + ); + Ok(()) + } + + #[tokio::test] + async fn partitioned_dynamic_filter_or_path_ignores_empty_partitions() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])); + let left = make_exec(Arc::clone(&schema), 2); + let right = make_exec(Arc::clone(&schema), 2); + let on_right: Vec = + vec![Arc::new(Column::new_with_schema("b", &schema)?)]; + let dynamic_filter = + Arc::new(DynamicFilterPhysicalExpr::new(on_right.clone(), lit(true))); + let accumulator = SharedBuildAccumulator::new_from_partition_mode( + PartitionMode::Partitioned, + left.as_ref(), + right.as_ref(), + Arc::clone(&dynamic_filter), + on_right, + SeededRandomState::with_seeds(0, 0, 0, 0), + DynamicFilterRoutingMode::PartitionIndex, + ); + + tokio::try_join!( + accumulator.report_build_data(PartitionBuildData::Partitioned { + partition_id: 0, + pushdown: PushdownStrategy::Empty, + bounds: make_bounds(0, 0), + }), + accumulator.report_build_data(PartitionBuildData::Partitioned { + partition_id: 1, + pushdown: make_in_list(&[10, 11]), + bounds: make_bounds(10, 11), + }) + )?; + + let runtime_ctx = DynamicFilterRuntimeContext::for_partition(1); + let bound = bind_runtime_physical_expr( + Arc::clone(&dynamic_filter) as Arc, + &runtime_ctx, + )?; + let per_partition = bound.snapshot()?.expect("expected snapshot"); + assert!( + !contains_hash_expr(&per_partition), + "partition-index routing should not introduce hash routing" + ); + assert!( + !contains_case_expr(&per_partition), + "partition-index routing should not introduce CASE routing" + ); + assert!( + !is_literal_true(&per_partition), + "expected a concrete per-partition filter" + ); + Ok(()) + } + + #[tokio::test] + async fn partitioned_dynamic_filter_single_non_empty_skips_case_expr() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])); + let left = make_exec(Arc::clone(&schema), 2); + let right = make_exec(Arc::clone(&schema), 2); + let on_right: Vec = + vec![Arc::new(Column::new_with_schema("b", &schema)?)]; + let dynamic_filter = + Arc::new(DynamicFilterPhysicalExpr::new(on_right.clone(), lit(true))); + let accumulator = SharedBuildAccumulator::new_from_partition_mode( + PartitionMode::Partitioned, + left.as_ref(), + right.as_ref(), + Arc::clone(&dynamic_filter), + on_right, + SeededRandomState::with_seeds(0, 0, 0, 0), + DynamicFilterRoutingMode::PartitionIndex, + ); + + tokio::try_join!( + accumulator.report_build_data(PartitionBuildData::Partitioned { + partition_id: 0, + pushdown: make_in_list(&[1, 2]), + bounds: make_bounds(1, 2), + }), + accumulator.report_build_data(PartitionBuildData::Partitioned { + partition_id: 1, + pushdown: PushdownStrategy::Empty, + bounds: make_bounds(0, 0), + }) + )?; + + let runtime_ctx = DynamicFilterRuntimeContext::for_partition(0); + let bound = bind_runtime_physical_expr( + Arc::clone(&dynamic_filter) as Arc, + &runtime_ctx, + )?; + let expr = bound.snapshot()?.expect("expected snapshot"); + assert!( + !contains_hash_expr(&expr), + "single non-empty partition should not need hash routing" + ); + assert!( + !contains_case_expr(&expr), + "single non-empty partition should not need CASE routing" + ); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 3ff61ecf1dacc..5f30972591627 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -20,7 +20,10 @@ use arrow::array::BooleanBufferBuilder; pub use cross_join::CrossJoinExec; use datafusion_physical_expr::PhysicalExprRef; -pub use hash_join::{HashExpr, HashJoinExec, HashTableLookupExpr, SeededRandomState}; +pub use hash_join::{ + DynamicFilterRoutingMode, HashExpr, HashJoinExec, HashTableLookupExpr, + SeededRandomState, +}; pub use nested_loop_join::NestedLoopJoinExec; use parking_lot::Mutex; // Note: SortMergeJoin is not used in plans yet diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 3e8fdf1f3ed7e..8a579dc8bf274 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -1071,6 +1071,30 @@ impl SortExec { self.fetch } + /// Returns the dynamic filter expression for this sort (TopK), if set. + pub fn dynamic_filter(&self) -> Option> { + self.filter.as_ref().map(|f| f.read().expr()) + } + + /// Replace the dynamic filter expression for this sort. + /// + /// + /// Resets any internal state which may depend on the previous dynamic filter. + /// + /// Validates that the filter's children reference valid columns in + /// the sort's input schema. + pub fn with_dynamic_filter( + mut self, + filter: Arc, + ) -> Result { + let input_schema = self.input.schema(); + for child in filter.children() { + child.data_type(&input_schema)?; + } + self.filter = Some(Arc::new(RwLock::new(TopKDynamicFilters::new(filter)))); + Ok(self) + } + fn output_partitioning_helper( input: &Arc, preserve_partitioning: bool, @@ -1440,6 +1464,7 @@ mod tests { use super::*; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::collect; + use crate::empty::EmptyExec; use crate::execution_plan::Boundedness; use crate::expressions::col; use crate::test; @@ -2715,4 +2740,62 @@ mod tests { Ok(()) } + + #[test] + fn test_with_dynamic_filter() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let child = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let sort = SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }]) + .unwrap(), + child, + ) + .with_fetch(Some(10)); + + // SortExec with fetch creates a dynamic filter automatically. + let original_df = sort + .dynamic_filter() + .expect("should have dynamic filter with fetch"); + + // with_dynamic_filter replaces it with a new TopKDynamicFilters. + let new_df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("a", 0)) as _], + lit(true), + )); + let sort = sort.with_dynamic_filter(Arc::clone(&new_df))?; + let restored = sort + .dynamic_filter() + .expect("should still have dynamic filter"); + assert_eq!(restored.inner_id(), new_df.inner_id()); + assert_ne!(restored.inner_id(), original_df.inner_id()); + Ok(()) + } + + #[test] + fn test_with_dynamic_filter_rejects_invalid_columns() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let child = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let sort = SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }]) + .unwrap(), + child, + ) + .with_fetch(Some(10)); + + // Column index 99 is out of bounds for the input schema. + let df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("bad", 99)) as _], + lit(true), + )); + assert!(sort.with_dynamic_filter(df).is_err()); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index ebac497f4fbc3..9c66aadd44b3c 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -43,7 +43,10 @@ use datafusion_execution::{ }; use datafusion_physical_expr::{ PhysicalExpr, - expressions::{BinaryExpr, DynamicFilterPhysicalExpr, is_not_null, is_null, lit}, + expressions::{ + BinaryExpr, DynamicFilterPhysicalExpr, DynamicFilterUpdate, is_not_null, is_null, + lit, + }, }; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use parking_lot::RwLock; @@ -412,7 +415,7 @@ impl TopK { if let Some(pred) = predicate && !pred.eq(&lit(true)) { - filter.expr.update(pred)?; + filter.expr.update(DynamicFilterUpdate::Global(pred))?; } Ok(()) diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index b00bd0dcc6bfd..313bc79870f54 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -69,6 +69,7 @@ datafusion-proto-common = { workspace = true } object_store = { workspace = true } pbjson = { workspace = true, optional = true } prost = { workspace = true } +rand = { workspace = true } serde = { version = "1.0", optional = true } serde_json = { workspace = true, optional = true } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index bd7dd3a6aff3c..ac57f16c4f103 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -837,6 +837,20 @@ message PhysicalExprNode { // Was date_time_interval_expr reserved 17; + // Unique identifier for this expression to do deduplication during deserialization. + // When serializing, this is set to a unique identifier for each combination of + // expression, process and serialization run. + // When deserializing, if this ID has been seen before, the cached Arc is returned + // instead of creating a new one, enabling reconstruction of referential integrity + // across serde roundtrips. + optional uint64 expr_id = 30; + + // For DynamicFilterPhysicalExpr, this identifies the shared inner state. + // Multiple expressions may have different expr_id values (different outer Arc wrappers) + // but the same dynamic_filter_inner_id (shared inner state). + // Used to reconstruct shared inner state during deserialization. + optional uint64 dynamic_filter_inner_id = 31; + oneof ExprType { // column references PhysicalColumn column = 1; @@ -874,9 +888,19 @@ message PhysicalExprNode { UnknownColumn unknown_column = 20; PhysicalHashExprNode hash_expr = 21; + + PhysicalDynamicFilterNode dynamic_filter = 22; } } +message PhysicalDynamicFilterNode { + repeated PhysicalExprNode children = 1; + repeated PhysicalExprNode remapped_children = 2; + uint64 generation = 3; + PhysicalExprNode inner_expr = 4; + bool is_complete = 5; +} + message PhysicalScalarUdfNode { string name = 1; repeated PhysicalExprNode args = 2; @@ -1102,6 +1126,11 @@ enum PartitionMode { AUTO = 2; } +enum HashJoinDynamicFilterRoutingMode { + CASE_HASH = 0; + PARTITION_INDEX = 1; +} + message HashJoinExecNode { PhysicalPlanNode left = 1; PhysicalPlanNode right = 2; @@ -1111,6 +1140,10 @@ message HashJoinExecNode { datafusion_common.NullEquality null_equality = 7; JoinFilter filter = 8; repeated uint32 projection = 9; + // Optional dynamic filter expression for pushing down to the probe side. + PhysicalExprNode dynamic_filter = 10; + // Selected routing strategy for partitioned dynamic filter expressions. + HashJoinDynamicFilterRoutingMode dynamic_filter_routing_mode = 11; } enum StreamPartitionMode { @@ -1235,6 +1268,8 @@ message AggregateExecNode { repeated MaybeFilter filter_expr = 10; AggLimit limit = 11; bool has_grouping_set = 12; + // Optional dynamic filter expression for pushing down to the child. + PhysicalExprNode dynamic_filter = 13; } message GlobalLimitExecNode { @@ -1256,6 +1291,8 @@ message SortExecNode { // Maximum number of highest/lowest rows to fetch; negative means no limit int64 fetch = 3; bool preserve_partitioning = 4; + // Optional dynamic filter expression for TopK pushdown. + PhysicalExprNode dynamic_filter = 5; } message SortPreservingMergeExecNode { diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index d95bdd388699e..84b15ea9a8920 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -21,7 +21,8 @@ use crate::logical_plan::{ self, AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec, }; use crate::physical_plan::{ - AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, + DefaultPhysicalExtensionCodec, DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + PhysicalProtoConverterExtension, }; use crate::protobuf; use datafusion_common::{Result, plan_datafusion_err}; @@ -276,16 +277,18 @@ pub fn logical_plan_from_json_with_extension_codec( /// Serialize a PhysicalPlan as bytes pub fn physical_plan_to_bytes(plan: Arc) -> Result { let extension_codec = DefaultPhysicalExtensionCodec {}; - physical_plan_to_bytes_with_extension_codec(plan, &extension_codec) + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_to_bytes_with_proto_converter(plan, &extension_codec, &proto_converter) } /// Serialize a PhysicalPlan as JSON #[cfg(feature = "json")] pub fn physical_plan_to_json(plan: Arc) -> Result { let extension_codec = DefaultPhysicalExtensionCodec {}; - let protobuf = - protobuf::PhysicalPlanNode::try_from_physical_plan(plan, &extension_codec) - .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; + let proto_converter = DefaultPhysicalProtoConverter {}; + let protobuf = proto_converter + .execution_plan_to_proto(&plan, &extension_codec) + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; serde_json::to_string(&protobuf) .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}")) } @@ -295,8 +298,18 @@ pub fn physical_plan_to_bytes_with_extension_codec( plan: Arc, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result { - let protobuf = - protobuf::PhysicalPlanNode::try_from_physical_plan(plan, extension_codec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_to_bytes_with_proto_converter(plan, extension_codec, &proto_converter) +} + +/// Serialize a PhysicalPlan as bytes, using the provided extension codec +/// and protobuf converter. +pub fn physical_plan_to_bytes_with_proto_converter( + plan: Arc, + extension_codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, +) -> Result { + let protobuf = proto_converter.execution_plan_to_proto(&plan, extension_codec)?; let mut buffer = BytesMut::new(); protobuf .encode(&mut buffer) @@ -313,7 +326,8 @@ pub fn physical_plan_from_json( let back: protobuf::PhysicalPlanNode = serde_json::from_str(json) .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; let extension_codec = DefaultPhysicalExtensionCodec {}; - back.try_into_physical_plan(ctx, &extension_codec) + let proto_converter = DefaultPhysicalProtoConverter {}; + proto_converter.proto_to_execution_plan(ctx, &extension_codec, &back) } /// Deserialize a PhysicalPlan from bytes @@ -322,7 +336,13 @@ pub fn physical_plan_from_bytes( ctx: &TaskContext, ) -> Result> { let extension_codec = DefaultPhysicalExtensionCodec {}; - physical_plan_from_bytes_with_extension_codec(bytes, ctx, &extension_codec) + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_from_bytes_with_proto_converter( + bytes, + ctx, + &extension_codec, + &proto_converter, + ) } /// Deserialize a PhysicalPlan from bytes @@ -330,8 +350,24 @@ pub fn physical_plan_from_bytes_with_extension_codec( bytes: &[u8], ctx: &TaskContext, extension_codec: &dyn PhysicalExtensionCodec, +) -> Result> { + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_from_bytes_with_proto_converter( + bytes, + ctx, + extension_codec, + &proto_converter, + ) +} + +/// Deserialize a PhysicalPlan from bytes +pub fn physical_plan_from_bytes_with_proto_converter( + bytes: &[u8], + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let protobuf = protobuf::PhysicalPlanNode::decode(bytes) .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; - protobuf.try_into_physical_plan(ctx, extension_codec) + proto_converter.proto_to_execution_plan(ctx, extension_codec, &protobuf) } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index e269606d163a3..0ae35e0478e77 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -137,6 +137,9 @@ impl serde::Serialize for AggregateExecNode { if self.has_grouping_set { len += 1; } + if self.dynamic_filter.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExecNode", len)?; if !self.group_expr.is_empty() { struct_ser.serialize_field("groupExpr", &self.group_expr)?; @@ -176,6 +179,9 @@ impl serde::Serialize for AggregateExecNode { if self.has_grouping_set { struct_ser.serialize_field("hasGroupingSet", &self.has_grouping_set)?; } + if let Some(v) = self.dynamic_filter.as_ref() { + struct_ser.serialize_field("dynamicFilter", v)?; + } struct_ser.end() } } @@ -206,6 +212,8 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "limit", "has_grouping_set", "hasGroupingSet", + "dynamic_filter", + "dynamicFilter", ]; #[allow(clippy::enum_variant_names)] @@ -222,6 +230,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { FilterExpr, Limit, HasGroupingSet, + DynamicFilter, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -255,6 +264,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "filterExpr" | "filter_expr" => Ok(GeneratedField::FilterExpr), "limit" => Ok(GeneratedField::Limit), "hasGroupingSet" | "has_grouping_set" => Ok(GeneratedField::HasGroupingSet), + "dynamicFilter" | "dynamic_filter" => Ok(GeneratedField::DynamicFilter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -286,6 +296,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { let mut filter_expr__ = None; let mut limit__ = None; let mut has_grouping_set__ = None; + let mut dynamic_filter__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::GroupExpr => { @@ -360,6 +371,12 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { } has_grouping_set__ = Some(map_.next_value()?); } + GeneratedField::DynamicFilter => { + if dynamic_filter__.is_some() { + return Err(serde::de::Error::duplicate_field("dynamicFilter")); + } + dynamic_filter__ = map_.next_value()?; + } } } Ok(AggregateExecNode { @@ -375,6 +392,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { filter_expr: filter_expr__.unwrap_or_default(), limit: limit__, has_grouping_set: has_grouping_set__.unwrap_or_default(), + dynamic_filter: dynamic_filter__, }) } } @@ -8009,6 +8027,77 @@ impl<'de> serde::Deserialize<'de> for GroupingSetNode { deserializer.deserialize_struct("datafusion.GroupingSetNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for HashJoinDynamicFilterRoutingMode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::CaseHash => "CASE_HASH", + Self::PartitionIndex => "PARTITION_INDEX", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for HashJoinDynamicFilterRoutingMode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "CASE_HASH", + "PARTITION_INDEX", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = HashJoinDynamicFilterRoutingMode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "CASE_HASH" => Ok(HashJoinDynamicFilterRoutingMode::CaseHash), + "PARTITION_INDEX" => Ok(HashJoinDynamicFilterRoutingMode::PartitionIndex), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for HashJoinExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -8041,6 +8130,12 @@ impl serde::Serialize for HashJoinExecNode { if !self.projection.is_empty() { len += 1; } + if self.dynamic_filter.is_some() { + len += 1; + } + if self.dynamic_filter_routing_mode != 0 { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.HashJoinExecNode", len)?; if let Some(v) = self.left.as_ref() { struct_ser.serialize_field("left", v)?; @@ -8072,6 +8167,14 @@ impl serde::Serialize for HashJoinExecNode { if !self.projection.is_empty() { struct_ser.serialize_field("projection", &self.projection)?; } + if let Some(v) = self.dynamic_filter.as_ref() { + struct_ser.serialize_field("dynamicFilter", v)?; + } + if self.dynamic_filter_routing_mode != 0 { + let v = HashJoinDynamicFilterRoutingMode::try_from(self.dynamic_filter_routing_mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.dynamic_filter_routing_mode)))?; + struct_ser.serialize_field("dynamicFilterRoutingMode", &v)?; + } struct_ser.end() } } @@ -8093,6 +8196,10 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { "nullEquality", "filter", "projection", + "dynamic_filter", + "dynamicFilter", + "dynamic_filter_routing_mode", + "dynamicFilterRoutingMode", ]; #[allow(clippy::enum_variant_names)] @@ -8105,6 +8212,8 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { NullEquality, Filter, Projection, + DynamicFilter, + DynamicFilterRoutingMode, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8134,6 +8243,8 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality), "filter" => Ok(GeneratedField::Filter), "projection" => Ok(GeneratedField::Projection), + "dynamicFilter" | "dynamic_filter" => Ok(GeneratedField::DynamicFilter), + "dynamicFilterRoutingMode" | "dynamic_filter_routing_mode" => Ok(GeneratedField::DynamicFilterRoutingMode), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8161,6 +8272,8 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { let mut null_equality__ = None; let mut filter__ = None; let mut projection__ = None; + let mut dynamic_filter__ = None; + let mut dynamic_filter_routing_mode__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Left => { @@ -8214,6 +8327,18 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { .into_iter().map(|x| x.0).collect()) ; } + GeneratedField::DynamicFilter => { + if dynamic_filter__.is_some() { + return Err(serde::de::Error::duplicate_field("dynamicFilter")); + } + dynamic_filter__ = map_.next_value()?; + } + GeneratedField::DynamicFilterRoutingMode => { + if dynamic_filter_routing_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("dynamicFilterRoutingMode")); + } + dynamic_filter_routing_mode__ = Some(map_.next_value::()? as i32); + } } } Ok(HashJoinExecNode { @@ -8225,6 +8350,8 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { null_equality: null_equality__.unwrap_or_default(), filter: filter__, projection: projection__.unwrap_or_default(), + dynamic_filter: dynamic_filter__, + dynamic_filter_routing_mode: dynamic_filter_routing_mode__.unwrap_or_default(), }) } } @@ -15866,6 +15993,172 @@ impl<'de> serde::Deserialize<'de> for PhysicalDateTimeIntervalExprNode { deserializer.deserialize_struct("datafusion.PhysicalDateTimeIntervalExprNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PhysicalDynamicFilterNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.children.is_empty() { + len += 1; + } + if !self.remapped_children.is_empty() { + len += 1; + } + if self.generation != 0 { + len += 1; + } + if self.inner_expr.is_some() { + len += 1; + } + if self.is_complete { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalDynamicFilterNode", len)?; + if !self.children.is_empty() { + struct_ser.serialize_field("children", &self.children)?; + } + if !self.remapped_children.is_empty() { + struct_ser.serialize_field("remappedChildren", &self.remapped_children)?; + } + if self.generation != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("generation", ToString::to_string(&self.generation).as_str())?; + } + if let Some(v) = self.inner_expr.as_ref() { + struct_ser.serialize_field("innerExpr", v)?; + } + if self.is_complete { + struct_ser.serialize_field("isComplete", &self.is_complete)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PhysicalDynamicFilterNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "children", + "remapped_children", + "remappedChildren", + "generation", + "inner_expr", + "innerExpr", + "is_complete", + "isComplete", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Children, + RemappedChildren, + Generation, + InnerExpr, + IsComplete, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "children" => Ok(GeneratedField::Children), + "remappedChildren" | "remapped_children" => Ok(GeneratedField::RemappedChildren), + "generation" => Ok(GeneratedField::Generation), + "innerExpr" | "inner_expr" => Ok(GeneratedField::InnerExpr), + "isComplete" | "is_complete" => Ok(GeneratedField::IsComplete), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PhysicalDynamicFilterNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PhysicalDynamicFilterNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut children__ = None; + let mut remapped_children__ = None; + let mut generation__ = None; + let mut inner_expr__ = None; + let mut is_complete__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Children => { + if children__.is_some() { + return Err(serde::de::Error::duplicate_field("children")); + } + children__ = Some(map_.next_value()?); + } + GeneratedField::RemappedChildren => { + if remapped_children__.is_some() { + return Err(serde::de::Error::duplicate_field("remappedChildren")); + } + remapped_children__ = Some(map_.next_value()?); + } + GeneratedField::Generation => { + if generation__.is_some() { + return Err(serde::de::Error::duplicate_field("generation")); + } + generation__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::InnerExpr => { + if inner_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("innerExpr")); + } + inner_expr__ = map_.next_value()?; + } + GeneratedField::IsComplete => { + if is_complete__.is_some() { + return Err(serde::de::Error::duplicate_field("isComplete")); + } + is_complete__ = Some(map_.next_value()?); + } + } + } + Ok(PhysicalDynamicFilterNode { + children: children__.unwrap_or_default(), + remapped_children: remapped_children__.unwrap_or_default(), + generation: generation__.unwrap_or_default(), + inner_expr: inner_expr__, + is_complete: is_complete__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.PhysicalDynamicFilterNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PhysicalExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -15874,10 +16167,26 @@ impl serde::Serialize for PhysicalExprNode { { use serde::ser::SerializeStruct; let mut len = 0; + if self.expr_id.is_some() { + len += 1; + } + if self.dynamic_filter_inner_id.is_some() { + len += 1; + } if self.expr_type.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExprNode", len)?; + if let Some(v) = self.expr_id.as_ref() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("exprId", ToString::to_string(&v).as_str())?; + } + if let Some(v) = self.dynamic_filter_inner_id.as_ref() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("dynamicFilterInnerId", ToString::to_string(&v).as_str())?; + } if let Some(v) = self.expr_type.as_ref() { match v { physical_expr_node::ExprType::Column(v) => { @@ -15937,6 +16246,9 @@ impl serde::Serialize for PhysicalExprNode { physical_expr_node::ExprType::HashExpr(v) => { struct_ser.serialize_field("hashExpr", v)?; } + physical_expr_node::ExprType::DynamicFilter(v) => { + struct_ser.serialize_field("dynamicFilter", v)?; + } } } struct_ser.end() @@ -15949,6 +16261,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "expr_id", + "exprId", + "dynamic_filter_inner_id", + "dynamicFilterInnerId", "column", "literal", "binary_expr", @@ -15981,10 +16297,14 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "unknownColumn", "hash_expr", "hashExpr", + "dynamic_filter", + "dynamicFilter", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { + ExprId, + DynamicFilterInnerId, Column, Literal, BinaryExpr, @@ -16004,6 +16324,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { Extension, UnknownColumn, HashExpr, + DynamicFilter, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16025,6 +16346,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { E: serde::de::Error, { match value { + "exprId" | "expr_id" => Ok(GeneratedField::ExprId), + "dynamicFilterInnerId" | "dynamic_filter_inner_id" => Ok(GeneratedField::DynamicFilterInnerId), "column" => Ok(GeneratedField::Column), "literal" => Ok(GeneratedField::Literal), "binaryExpr" | "binary_expr" => Ok(GeneratedField::BinaryExpr), @@ -16044,6 +16367,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "extension" => Ok(GeneratedField::Extension), "unknownColumn" | "unknown_column" => Ok(GeneratedField::UnknownColumn), "hashExpr" | "hash_expr" => Ok(GeneratedField::HashExpr), + "dynamicFilter" | "dynamic_filter" => Ok(GeneratedField::DynamicFilter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16063,9 +16387,27 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { where V: serde::de::MapAccess<'de>, { + let mut expr_id__ = None; + let mut dynamic_filter_inner_id__ = None; let mut expr_type__ = None; while let Some(k) = map_.next_key()? { match k { + GeneratedField::ExprId => { + if expr_id__.is_some() { + return Err(serde::de::Error::duplicate_field("exprId")); + } + expr_id__ = + map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) + ; + } + GeneratedField::DynamicFilterInnerId => { + if dynamic_filter_inner_id__.is_some() { + return Err(serde::de::Error::duplicate_field("dynamicFilterInnerId")); + } + dynamic_filter_inner_id__ = + map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) + ; + } GeneratedField::Column => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("column")); @@ -16197,11 +16539,20 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { return Err(serde::de::Error::duplicate_field("hashExpr")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::HashExpr) +; + } + GeneratedField::DynamicFilter => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("dynamicFilter")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::DynamicFilter) ; } } } Ok(PhysicalExprNode { + expr_id: expr_id__, + dynamic_filter_inner_id: dynamic_filter_inner_id__, expr_type: expr_type__, }) } @@ -21167,6 +21518,9 @@ impl serde::Serialize for SortExecNode { if self.preserve_partitioning { len += 1; } + if self.dynamic_filter.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.SortExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; @@ -21182,6 +21536,9 @@ impl serde::Serialize for SortExecNode { if self.preserve_partitioning { struct_ser.serialize_field("preservePartitioning", &self.preserve_partitioning)?; } + if let Some(v) = self.dynamic_filter.as_ref() { + struct_ser.serialize_field("dynamicFilter", v)?; + } struct_ser.end() } } @@ -21197,6 +21554,8 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { "fetch", "preserve_partitioning", "preservePartitioning", + "dynamic_filter", + "dynamicFilter", ]; #[allow(clippy::enum_variant_names)] @@ -21205,6 +21564,7 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { Expr, Fetch, PreservePartitioning, + DynamicFilter, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -21230,6 +21590,7 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { "expr" => Ok(GeneratedField::Expr), "fetch" => Ok(GeneratedField::Fetch), "preservePartitioning" | "preserve_partitioning" => Ok(GeneratedField::PreservePartitioning), + "dynamicFilter" | "dynamic_filter" => Ok(GeneratedField::DynamicFilter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -21253,6 +21614,7 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { let mut expr__ = None; let mut fetch__ = None; let mut preserve_partitioning__ = None; + let mut dynamic_filter__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { @@ -21281,6 +21643,12 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { } preserve_partitioning__ = Some(map_.next_value()?); } + GeneratedField::DynamicFilter => { + if dynamic_filter__.is_some() { + return Err(serde::de::Error::duplicate_field("dynamicFilter")); + } + dynamic_filter__ = map_.next_value()?; + } } } Ok(SortExecNode { @@ -21288,6 +21656,7 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { expr: expr__.unwrap_or_default(), fetch: fetch__.unwrap_or_default(), preserve_partitioning: preserve_partitioning__.unwrap_or_default(), + dynamic_filter: dynamic_filter__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index cf343e0258d0b..804d3cf4d4578 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1274,9 +1274,23 @@ pub struct PhysicalExtensionNode { /// physical expressions #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalExprNode { + /// Unique identifier for this expression to do deduplication during deserialization. + /// When serializing, this is set to a unique identifier for each combination of + /// expression, process and serialization run. + /// When deserializing, if this ID has been seen before, the cached Arc is returned + /// instead of creating a new one, enabling reconstruction of referential integrity + /// across serde roundtrips. + #[prost(uint64, optional, tag = "30")] + pub expr_id: ::core::option::Option, + /// For DynamicFilterPhysicalExpr, this identifies the shared inner state. + /// Multiple expressions may have different expr_id values (different outer Arc wrappers) + /// but the same dynamic_filter_inner_id (shared inner state). + /// Used to reconstruct shared inner state during deserialization. + #[prost(uint64, optional, tag = "31")] + pub dynamic_filter_inner_id: ::core::option::Option, #[prost( oneof = "physical_expr_node::ExprType", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21, 22" )] pub expr_type: ::core::option::Option, } @@ -1329,9 +1343,24 @@ pub mod physical_expr_node { UnknownColumn(super::UnknownColumn), #[prost(message, tag = "21")] HashExpr(super::PhysicalHashExprNode), + #[prost(message, tag = "22")] + DynamicFilter(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PhysicalDynamicFilterNode { + #[prost(message, repeated, tag = "1")] + pub children: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "2")] + pub remapped_children: ::prost::alloc::vec::Vec, + #[prost(uint64, tag = "3")] + pub generation: u64, + #[prost(message, optional, boxed, tag = "4")] + pub inner_expr: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(bool, tag = "5")] + pub is_complete: bool, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalScalarUdfNode { #[prost(string, tag = "1")] pub name: ::prost::alloc::string::String, @@ -1688,6 +1717,12 @@ pub struct HashJoinExecNode { pub filter: ::core::option::Option, #[prost(uint32, repeated, tag = "9")] pub projection: ::prost::alloc::vec::Vec, + /// Optional dynamic filter expression for pushing down to the probe side. + #[prost(message, optional, tag = "10")] + pub dynamic_filter: ::core::option::Option, + /// Selected routing strategy for partitioned dynamic filter expressions. + #[prost(enumeration = "HashJoinDynamicFilterRoutingMode", tag = "11")] + pub dynamic_filter_routing_mode: i32, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct SymmetricHashJoinExecNode { @@ -1858,6 +1893,9 @@ pub struct AggregateExecNode { pub limit: ::core::option::Option, #[prost(bool, tag = "12")] pub has_grouping_set: bool, + /// Optional dynamic filter expression for pushing down to the child. + #[prost(message, optional, tag = "13")] + pub dynamic_filter: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct GlobalLimitExecNode { @@ -1888,6 +1926,9 @@ pub struct SortExecNode { pub fetch: i64, #[prost(bool, tag = "4")] pub preserve_partitioning: bool, + /// Optional dynamic filter expression for TopK pushdown. + #[prost(message, optional, tag = "5")] + pub dynamic_filter: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortPreservingMergeExecNode { @@ -2304,6 +2345,32 @@ impl PartitionMode { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum HashJoinDynamicFilterRoutingMode { + CaseHash = 0, + PartitionIndex = 1, +} +impl HashJoinDynamicFilterRoutingMode { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::CaseHash => "CASE_HASH", + Self::PartitionIndex => "PARTITION_INDEX", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "CASE_HASH" => Some(Self::CaseHash), + "PARTITION_INDEX" => Some(Self::PartitionIndex), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum StreamPartitionMode { SinglePartition = 0, PartitionedExec = 1, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 073fdd858cdd3..5c05a91a65bc8 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -21,14 +21,9 @@ use std::sync::Arc; use arrow::array::RecordBatch; use arrow::compute::SortOptions; -use arrow::datatypes::Field; +use arrow::datatypes::{Field, Schema}; use arrow::ipc::reader::StreamReader; use chrono::{TimeZone, Utc}; -use datafusion_expr::dml::InsertOp; -use object_store::ObjectMeta; -use object_store::path::Path; - -use arrow::datatypes::Schema; use datafusion_common::{DataFusionError, Result, internal_datafusion_err, not_impl_err}; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_groups::FileGroup; @@ -42,6 +37,7 @@ use datafusion_datasource_parquet::file_format::ParquetSink; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{FunctionRegistry, TaskContext}; use datafusion_expr::WindowFunctionDefinition; +use datafusion_expr::dml::InsertOp; use datafusion_physical_expr::projection::{ProjectionExpr, ProjectionExprs}; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion_physical_plan::expressions::{ @@ -52,13 +48,19 @@ use datafusion_physical_plan::joins::{HashExpr, SeededRandomState}; use datafusion_physical_plan::windows::{create_window_expr, schema_add_window_field}; use datafusion_physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; use datafusion_proto_common::common::proto_error; +use object_store::ObjectMeta; +use object_store::path::Path; -use crate::convert_required; +use super::{ + DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + PhysicalProtoConverterExtension, +}; use crate::logical_plan::{self}; -use crate::protobuf; use crate::protobuf::physical_expr_node::ExprType; - -use super::PhysicalExtensionCodec; +use crate::{convert_required, protobuf}; +use datafusion_physical_expr::expressions::{ + DynamicFilterPhysicalExpr, DynamicFilterSnapshot, +}; impl From<&protobuf::PhysicalColumn> for Column { fn from(c: &protobuf::PhysicalColumn) -> Column { @@ -80,9 +82,15 @@ pub fn parse_physical_sort_expr( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { if let Some(expr) = &proto.expr { - let expr = parse_physical_expr(expr.as_ref(), ctx, input_schema, codec)?; + let expr = proto_converter.proto_to_physical_expr( + expr.as_ref(), + ctx, + input_schema, + codec, + )?; let options = SortOptions { descending: !proto.asc, nulls_first: proto.nulls_first, @@ -107,10 +115,13 @@ pub fn parse_physical_sort_exprs( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { proto .iter() - .map(|sort_expr| parse_physical_sort_expr(sort_expr, ctx, input_schema, codec)) + .map(|sort_expr| { + parse_physical_sort_expr(sort_expr, ctx, input_schema, codec, proto_converter) + }) .collect() } @@ -129,12 +140,25 @@ pub fn parse_physical_window_expr( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let window_node_expr = parse_physical_exprs(&proto.args, ctx, input_schema, codec)?; - let partition_by = - parse_physical_exprs(&proto.partition_by, ctx, input_schema, codec)?; - - let order_by = parse_physical_sort_exprs(&proto.order_by, ctx, input_schema, codec)?; + let window_node_expr = + parse_physical_exprs(&proto.args, ctx, input_schema, codec, proto_converter)?; + let partition_by = parse_physical_exprs( + &proto.partition_by, + ctx, + input_schema, + codec, + proto_converter, + )?; + + let order_by = parse_physical_sort_exprs( + &proto.order_by, + ctx, + input_schema, + codec, + proto_converter, + )?; let window_frame = proto .window_frame @@ -188,13 +212,14 @@ pub fn parse_physical_exprs<'a, I>( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result>> where I: IntoIterator, { protos .into_iter() - .map(|p| parse_physical_expr(p, ctx, input_schema, codec)) + .map(|p| proto_converter.proto_to_physical_expr(p, ctx, input_schema, codec)) .collect::>>() } @@ -212,6 +237,32 @@ pub fn parse_physical_expr( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, +) -> Result> { + parse_physical_expr_with_converter( + proto, + ctx, + input_schema, + codec, + &DefaultPhysicalProtoConverter {}, + ) +} + +/// Parses a physical expression from a protobuf. +/// +/// # Arguments +/// +/// * `proto` - Input proto with physical expression node +/// * `registry` - A registry knows how to build logical expressions out of user-defined function names +/// * `input_schema` - The Arrow schema for the input, used for determining expression data types +/// when performing type coercion. +/// * `codec` - An extension codec used to decode custom UDFs. +/// * `proto_converter` - Conversion functions for physical plans and expressions +pub fn parse_physical_expr_with_converter( + proto: &protobuf::PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let expr_type = proto .expr_type @@ -232,6 +283,7 @@ pub fn parse_physical_expr( "left", input_schema, codec, + proto_converter, )?, logical_plan::from_proto::from_proto_binary_op(&binary_expr.op)?, parse_required_physical_expr( @@ -240,6 +292,7 @@ pub fn parse_physical_expr( "right", input_schema, codec, + proto_converter, )?, )), ExprType::AggregateExpr(_) => { @@ -262,6 +315,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?)) } ExprType::IsNotNullExpr(e) => { @@ -271,6 +325,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?)) } ExprType::NotExpr(e) => Arc::new(NotExpr::new(parse_required_physical_expr( @@ -279,6 +334,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?)), ExprType::Negative(e) => { Arc::new(NegativeExpr::new(parse_required_physical_expr( @@ -287,6 +343,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?)) } ExprType::InList(e) => in_list( @@ -296,15 +353,23 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?, - parse_physical_exprs(&e.list, ctx, input_schema, codec)?, + parse_physical_exprs(&e.list, ctx, input_schema, codec, proto_converter)?, &e.negated, input_schema, )?, ExprType::Case(e) => Arc::new(CaseExpr::try_new( e.expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), ctx, input_schema, codec)) + .map(|e| { + proto_converter.proto_to_physical_expr( + e.as_ref(), + ctx, + input_schema, + codec, + ) + }) .transpose()?, e.when_then_expr .iter() @@ -316,6 +381,7 @@ pub fn parse_physical_expr( "when_expr", input_schema, codec, + proto_converter, )?, parse_required_physical_expr( e.then_expr.as_ref(), @@ -323,13 +389,21 @@ pub fn parse_physical_expr( "then_expr", input_schema, codec, + proto_converter, )?, )) }) .collect::>>()?, e.else_expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), ctx, input_schema, codec)) + .map(|e| { + proto_converter.proto_to_physical_expr( + e.as_ref(), + ctx, + input_schema, + codec, + ) + }) .transpose()?, )?), ExprType::Cast(e) => Arc::new(CastExpr::new( @@ -339,6 +413,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?, convert_required!(e.arrow_type)?, None, @@ -350,6 +425,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?, convert_required!(e.arrow_type)?, )), @@ -362,7 +438,8 @@ pub fn parse_physical_expr( }; let scalar_fun_def = Arc::clone(&udf); - let args = parse_physical_exprs(&e.args, ctx, input_schema, codec)?; + let args = + parse_physical_exprs(&e.args, ctx, input_schema, codec, proto_converter)?; let config_options = Arc::clone(ctx.session_config().options()); @@ -391,6 +468,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?, parse_required_physical_expr( like_expr.pattern.as_deref(), @@ -398,11 +476,17 @@ pub fn parse_physical_expr( "pattern", input_schema, codec, + proto_converter, )?, )), ExprType::HashExpr(hash_expr) => { - let on_columns = - parse_physical_exprs(&hash_expr.on_columns, ctx, input_schema, codec)?; + let on_columns = parse_physical_exprs( + &hash_expr.on_columns, + ctx, + input_schema, + codec, + proto_converter, + )?; Arc::new(HashExpr::new( on_columns, SeededRandomState::with_seeds( @@ -414,13 +498,57 @@ pub fn parse_physical_expr( hash_expr.description.clone(), )) } + ExprType::DynamicFilter(dynamic_filter) => { + let children = parse_physical_exprs( + &dynamic_filter.children, + ctx, + input_schema, + codec, + proto_converter, + )?; + + let remapped_children = if !dynamic_filter.remapped_children.is_empty() { + Some(parse_physical_exprs( + &dynamic_filter.remapped_children, + ctx, + input_schema, + codec, + proto_converter, + )?) + } else { + None + }; + + let inner_expr = parse_required_physical_expr( + dynamic_filter.inner_expr.as_deref(), + ctx, + "inner_expr", + input_schema, + codec, + proto_converter, + )?; + + // Recreate filter from snapshot + let snapshot = DynamicFilterSnapshot::new( + children, + remapped_children, + dynamic_filter.generation, + inner_expr, + dynamic_filter.is_complete, + ); + let base_filter: Arc = + Arc::new(DynamicFilterPhysicalExpr::from(snapshot)); + base_filter + } ExprType::Extension(extension) => { let inputs: Vec> = extension .inputs .iter() - .map(|e| parse_physical_expr(e, ctx, input_schema, codec)) + .map(|e| { + proto_converter.proto_to_physical_expr(e, ctx, input_schema, codec) + }) .collect::>()?; - (codec.try_decode_expr(extension.expr.as_slice(), &inputs)?) as _ + codec.try_decode_expr(extension.expr.as_slice(), &inputs)? as _ } }; @@ -433,8 +561,9 @@ fn parse_required_physical_expr( field: &str, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - expr.map(|e| parse_physical_expr(e, ctx, input_schema, codec)) + expr.map(|e| proto_converter.proto_to_physical_expr(e, ctx, input_schema, codec)) .transpose()? .ok_or_else(|| internal_datafusion_err!("Missing required field {field:?}")) } @@ -444,11 +573,17 @@ pub fn parse_protobuf_hash_partitioning( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { match partitioning { Some(hash_part) => { - let expr = - parse_physical_exprs(&hash_part.hash_expr, ctx, input_schema, codec)?; + let expr = parse_physical_exprs( + &hash_part.hash_expr, + ctx, + input_schema, + codec, + proto_converter, + )?; Ok(Some(Partitioning::Hash( expr, @@ -464,6 +599,7 @@ pub fn parse_protobuf_partitioning( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { match partitioning { Some(protobuf::Partitioning { partition_method }) => match partition_method { @@ -478,6 +614,7 @@ pub fn parse_protobuf_partitioning( ctx, input_schema, codec, + proto_converter, ) } Some(protobuf::partitioning::PartitionMethod::Unknown(partition_count)) => { @@ -532,6 +669,7 @@ pub fn parse_protobuf_file_scan_config( proto: &protobuf::FileScanExecConf, ctx: &TaskContext, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, file_source: Arc, ) -> Result { let schema: Arc = parse_protobuf_file_scan_schema(proto)?; @@ -557,6 +695,7 @@ pub fn parse_protobuf_file_scan_config( ctx, &schema, codec, + proto_converter, )?; output_ordering.extend(LexOrdering::new(sort_exprs)); } @@ -567,7 +706,7 @@ pub fn parse_protobuf_file_scan_config( .projections .iter() .map(|proto_expr| { - let expr = parse_physical_expr( + let expr = proto_converter.proto_to_physical_expr( proto_expr.expr.as_ref().ok_or_else(|| { internal_datafusion_err!("ProjectionExpr missing expr field") })?, @@ -745,12 +884,13 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { #[cfg(test)] mod tests { - use super::*; use chrono::{TimeZone, Utc}; use datafusion_datasource::PartitionedFile; use object_store::ObjectMeta; use object_store::path::Path; + use super::*; + #[test] fn partitioned_file_path_roundtrip_percent_encoded() { let path_str = "foo/foo%2Fbar/baz%252Fqux"; diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 4ff90b61eed9c..58c8bc27e2d7f 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -1,3 +1,4 @@ +use std::any::Any; // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information @@ -15,33 +16,14 @@ // specific language governing permissions and limitations // under the License. +use std::cell::RefCell; +use std::collections::HashMap; use std::fmt::Debug; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; -use self::from_proto::parse_protobuf_partitioning; -use self::to_proto::{serialize_partitioning, serialize_physical_expr}; -use crate::common::{byte_to_string, str_to_byte}; -use crate::physical_plan::from_proto::{ - parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs, - parse_physical_window_expr, parse_protobuf_file_scan_config, parse_record_batches, - parse_table_schema_from_proto, -}; -use crate::physical_plan::to_proto::{ - serialize_file_scan_config, serialize_maybe_filter, serialize_physical_aggr_expr, - serialize_physical_sort_exprs, serialize_physical_window_expr, - serialize_record_batches, -}; -use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; -use crate::protobuf::physical_expr_node::ExprType; -use crate::protobuf::physical_plan_node::PhysicalPlanType; -use crate::protobuf::{ - self, ListUnnest as ProtoListUnnest, SortExprNode, SortMergeJoinExecNode, - proto_error, window_agg_exec_node, -}; -use crate::{convert_required, into_required}; - use arrow::compute::SortOptions; -use arrow::datatypes::{IntervalMonthDayNanoType, SchemaRef}; +use arrow::datatypes::{IntervalMonthDayNanoType, Schema, SchemaRef}; use datafusion_catalog::memory::MemorySourceConfig; use datafusion_common::config::CsvOptions; use datafusion_common::{ @@ -68,12 +50,14 @@ use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use datafusion_functions_table::generate_series::{ Empty, GenSeriesArgs, GenerateSeriesTable, GenericSeriesState, TimestampValue, }; -use datafusion_physical_expr::aggregate::AggregateExprBuilder; -use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; +use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; +use datafusion_physical_expr::expressions::DynamicFilterPhysicalExpr; use datafusion_physical_expr::{LexOrdering, LexRequirement, PhysicalExprRef}; use datafusion_physical_plan::aggregates::AggregateMode; use datafusion_physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; use datafusion_physical_plan::analyze::AnalyzeExec; +use datafusion_physical_plan::async_func::AsyncFuncExec; use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::coop::CooperativeExec; @@ -83,10 +67,9 @@ use datafusion_physical_plan::expressions::PhysicalSortExpr; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use datafusion_physical_plan::joins::{ - CrossJoinExec, NestedLoopJoinExec, SortMergeJoinExec, StreamJoinPartitionMode, - SymmetricHashJoinExec, + CrossJoinExec, DynamicFilterRoutingMode, HashJoinExec, NestedLoopJoinExec, + PartitionMode, SortMergeJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec, }; -use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::memory::LazyMemoryExec; use datafusion_physical_plan::metrics::MetricType; @@ -99,12 +82,31 @@ use datafusion_physical_plan::union::{InterleaveExec, UnionExec}; use datafusion_physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion_physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion_physical_plan::{ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr}; - -use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; -use datafusion_physical_plan::async_func::AsyncFuncExec; use prost::Message; use prost::bytes::BufMut; +use self::from_proto::parse_protobuf_partitioning; +use self::to_proto::serialize_partitioning; +use crate::common::{byte_to_string, str_to_byte}; +use crate::physical_plan::from_proto::{ + parse_physical_expr_with_converter, parse_physical_sort_expr, + parse_physical_sort_exprs, parse_physical_window_expr, + parse_protobuf_file_scan_config, parse_record_batches, parse_table_schema_from_proto, +}; +use crate::physical_plan::to_proto::{ + serialize_file_scan_config, serialize_maybe_filter, serialize_physical_aggr_expr, + serialize_physical_expr_with_converter, serialize_physical_sort_exprs, + serialize_physical_window_expr, serialize_record_batches, +}; +use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; +use crate::protobuf::physical_expr_node::ExprType; +use crate::protobuf::physical_plan_node::PhysicalPlanType; +use crate::protobuf::{ + self, ListUnnest as ProtoListUnnest, SortExprNode, SortMergeJoinExecNode, + proto_error, window_agg_exec_node, +}; +use crate::{convert_required, into_required}; + pub mod from_proto; pub mod to_proto; @@ -131,8 +133,37 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { fn try_into_physical_plan( &self, ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + self.try_into_physical_plan_with_converter( + ctx, + codec, + &DefaultPhysicalProtoConverter {}, + ) + } + + fn try_from_physical_plan( + plan: Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + Self::try_from_physical_plan_with_converter( + plan, + codec, + &DefaultPhysicalProtoConverter {}, + ) + } +} + +impl protobuf::PhysicalPlanNode { + pub fn try_into_physical_plan_with_converter( + &self, + ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let plan = self.physical_plan_type.as_ref().ok_or_else(|| { proto_error(format!( @@ -141,125 +172,149 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { })?; match plan { PhysicalPlanType::Explain(explain) => { - self.try_into_explain_physical_plan(explain, ctx, extension_codec) - } - PhysicalPlanType::Projection(projection) => { - self.try_into_projection_physical_plan(projection, ctx, extension_codec) + self.try_into_explain_physical_plan(explain, ctx, codec, proto_converter) } + PhysicalPlanType::Projection(projection) => self + .try_into_projection_physical_plan( + projection, + ctx, + codec, + proto_converter, + ), PhysicalPlanType::Filter(filter) => { - self.try_into_filter_physical_plan(filter, ctx, extension_codec) + self.try_into_filter_physical_plan(filter, ctx, codec, proto_converter) } PhysicalPlanType::CsvScan(scan) => { - self.try_into_csv_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_csv_scan_physical_plan(scan, ctx, codec, proto_converter) } PhysicalPlanType::JsonScan(scan) => { - self.try_into_json_scan_physical_plan(scan, ctx, extension_codec) - } - PhysicalPlanType::ParquetScan(scan) => { - self.try_into_parquet_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_json_scan_physical_plan(scan, ctx, codec, proto_converter) } + PhysicalPlanType::ParquetScan(scan) => self + .try_into_parquet_scan_physical_plan(scan, ctx, codec, proto_converter), PhysicalPlanType::AvroScan(scan) => { - self.try_into_avro_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_avro_scan_physical_plan(scan, ctx, codec, proto_converter) } PhysicalPlanType::MemoryScan(scan) => { - self.try_into_memory_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_memory_scan_physical_plan(scan, ctx, codec, proto_converter) } PhysicalPlanType::CoalesceBatches(coalesce_batches) => self .try_into_coalesce_batches_physical_plan( coalesce_batches, ctx, - extension_codec, + codec, + proto_converter, ), PhysicalPlanType::Merge(merge) => { - self.try_into_merge_physical_plan(merge, ctx, extension_codec) - } - PhysicalPlanType::Repartition(repart) => { - self.try_into_repartition_physical_plan(repart, ctx, extension_codec) - } - PhysicalPlanType::GlobalLimit(limit) => { - self.try_into_global_limit_physical_plan(limit, ctx, extension_codec) - } - PhysicalPlanType::LocalLimit(limit) => { - self.try_into_local_limit_physical_plan(limit, ctx, extension_codec) - } - PhysicalPlanType::Window(window_agg) => { - self.try_into_window_physical_plan(window_agg, ctx, extension_codec) - } - PhysicalPlanType::Aggregate(hash_agg) => { - self.try_into_aggregate_physical_plan(hash_agg, ctx, extension_codec) - } - PhysicalPlanType::HashJoin(hashjoin) => { - self.try_into_hash_join_physical_plan(hashjoin, ctx, extension_codec) + self.try_into_merge_physical_plan(merge, ctx, codec, proto_converter) } + PhysicalPlanType::Repartition(repart) => self + .try_into_repartition_physical_plan(repart, ctx, codec, proto_converter), + PhysicalPlanType::GlobalLimit(limit) => self + .try_into_global_limit_physical_plan(limit, ctx, codec, proto_converter), + PhysicalPlanType::LocalLimit(limit) => self + .try_into_local_limit_physical_plan(limit, ctx, codec, proto_converter), + PhysicalPlanType::Window(window_agg) => self.try_into_window_physical_plan( + window_agg, + ctx, + codec, + proto_converter, + ), + PhysicalPlanType::Aggregate(hash_agg) => self + .try_into_aggregate_physical_plan(hash_agg, ctx, codec, proto_converter), + PhysicalPlanType::HashJoin(hashjoin) => self + .try_into_hash_join_physical_plan(hashjoin, ctx, codec, proto_converter), PhysicalPlanType::SymmetricHashJoin(sym_join) => self .try_into_symmetric_hash_join_physical_plan( sym_join, ctx, - extension_codec, + codec, + proto_converter, ), PhysicalPlanType::Union(union) => { - self.try_into_union_physical_plan(union, ctx, extension_codec) - } - PhysicalPlanType::Interleave(interleave) => { - self.try_into_interleave_physical_plan(interleave, ctx, extension_codec) - } - PhysicalPlanType::CrossJoin(crossjoin) => { - self.try_into_cross_join_physical_plan(crossjoin, ctx, extension_codec) + self.try_into_union_physical_plan(union, ctx, codec, proto_converter) } - PhysicalPlanType::Empty(empty) => { - self.try_into_empty_physical_plan(empty, ctx, extension_codec) - } - PhysicalPlanType::PlaceholderRow(placeholder) => self - .try_into_placeholder_row_physical_plan( - placeholder, + PhysicalPlanType::Interleave(interleave) => self + .try_into_interleave_physical_plan( + interleave, ctx, - extension_codec, + codec, + proto_converter, ), - PhysicalPlanType::Sort(sort) => { - self.try_into_sort_physical_plan(sort, ctx, extension_codec) + PhysicalPlanType::CrossJoin(crossjoin) => self + .try_into_cross_join_physical_plan( + crossjoin, + ctx, + codec, + proto_converter, + ), + PhysicalPlanType::Empty(empty) => { + self.try_into_empty_physical_plan(empty, ctx, codec, proto_converter) } - PhysicalPlanType::SortPreservingMerge(sort) => self - .try_into_sort_preserving_merge_physical_plan(sort, ctx, extension_codec), - PhysicalPlanType::Extension(extension) => { - self.try_into_extension_physical_plan(extension, ctx, extension_codec) + PhysicalPlanType::PlaceholderRow(placeholder) => { + self.try_into_placeholder_row_physical_plan(placeholder, ctx, codec) } - PhysicalPlanType::NestedLoopJoin(join) => { - self.try_into_nested_loop_join_physical_plan(join, ctx, extension_codec) + PhysicalPlanType::Sort(sort) => { + self.try_into_sort_physical_plan(sort, ctx, codec, proto_converter) } + PhysicalPlanType::SortPreservingMerge(sort) => self + .try_into_sort_preserving_merge_physical_plan( + sort, + ctx, + codec, + proto_converter, + ), + PhysicalPlanType::Extension(extension) => self + .try_into_extension_physical_plan(extension, ctx, codec, proto_converter), + PhysicalPlanType::NestedLoopJoin(join) => self + .try_into_nested_loop_join_physical_plan( + join, + ctx, + codec, + proto_converter, + ), PhysicalPlanType::Analyze(analyze) => { - self.try_into_analyze_physical_plan(analyze, ctx, extension_codec) + self.try_into_analyze_physical_plan(analyze, ctx, codec, proto_converter) } PhysicalPlanType::JsonSink(sink) => { - self.try_into_json_sink_physical_plan(sink, ctx, extension_codec) + self.try_into_json_sink_physical_plan(sink, ctx, codec, proto_converter) } PhysicalPlanType::CsvSink(sink) => { - self.try_into_csv_sink_physical_plan(sink, ctx, extension_codec) + self.try_into_csv_sink_physical_plan(sink, ctx, codec, proto_converter) } #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] - PhysicalPlanType::ParquetSink(sink) => { - self.try_into_parquet_sink_physical_plan(sink, ctx, extension_codec) - } + PhysicalPlanType::ParquetSink(sink) => self + .try_into_parquet_sink_physical_plan(sink, ctx, codec, proto_converter), PhysicalPlanType::Unnest(unnest) => { - self.try_into_unnest_physical_plan(unnest, ctx, extension_codec) - } - PhysicalPlanType::Cooperative(cooperative) => { - self.try_into_cooperative_physical_plan(cooperative, ctx, extension_codec) + self.try_into_unnest_physical_plan(unnest, ctx, codec, proto_converter) } + PhysicalPlanType::Cooperative(cooperative) => self + .try_into_cooperative_physical_plan( + cooperative, + ctx, + codec, + proto_converter, + ), PhysicalPlanType::GenerateSeries(generate_series) => { self.try_into_generate_series_physical_plan(generate_series) } PhysicalPlanType::SortMergeJoin(sort_join) => { - self.try_into_sort_join(sort_join, ctx, extension_codec) - } - PhysicalPlanType::AsyncFunc(async_func) => { - self.try_into_async_func_physical_plan(async_func, ctx, extension_codec) + self.try_into_sort_join(sort_join, ctx, codec, proto_converter) } + PhysicalPlanType::AsyncFunc(async_func) => self + .try_into_async_func_physical_plan( + async_func, + ctx, + codec, + proto_converter, + ), } } - fn try_from_physical_plan( + pub fn try_from_physical_plan_with_converter( plan: Arc, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result where Self: Sized, @@ -268,107 +323,112 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { let plan = plan.as_any(); if let Some(exec) = plan.downcast_ref::() { - return protobuf::PhysicalPlanNode::try_from_explain_exec( - exec, - extension_codec, - ); + return protobuf::PhysicalPlanNode::try_from_explain_exec(exec, codec); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_projection_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_analyze_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_filter_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(limit) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_global_limit_exec( limit, - extension_codec, + codec, + proto_converter, ); } if let Some(limit) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_local_limit_exec( limit, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_hash_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_symmetric_hash_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_sort_merge_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_cross_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_aggregate_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(empty) = plan.downcast_ref::() { - return protobuf::PhysicalPlanNode::try_from_empty_exec( - empty, - extension_codec, - ); + return protobuf::PhysicalPlanNode::try_from_empty_exec(empty, codec); } if let Some(empty) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_placeholder_row_exec( - empty, - extension_codec, + empty, codec, ); } if let Some(coalesce_batches) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_coalesce_batches_exec( coalesce_batches, - extension_codec, + codec, + proto_converter, ); } if let Some(data_source_exec) = plan.downcast_ref::() && let Some(node) = protobuf::PhysicalPlanNode::try_from_data_source_exec( data_source_exec, - extension_codec, + codec, + proto_converter, )? { return Ok(node); @@ -377,67 +437,80 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_coalesce_partitions_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_repartition_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { - return protobuf::PhysicalPlanNode::try_from_sort_exec(exec, extension_codec); + return protobuf::PhysicalPlanNode::try_from_sort_exec( + exec, + codec, + proto_converter, + ); } if let Some(union) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_union_exec( union, - extension_codec, + codec, + proto_converter, ); } if let Some(interleave) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_interleave_exec( interleave, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_sort_preserving_merge_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_nested_loop_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_window_agg_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_bounded_window_agg_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() && let Some(node) = protobuf::PhysicalPlanNode::try_from_data_sink_exec( exec, - extension_codec, + codec, + proto_converter, )? { return Ok(node); @@ -446,14 +519,16 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_unnest_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_cooperative_exec( exec, - extension_codec, + codec, + proto_converter, ); } @@ -467,21 +542,23 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_async_func_exec( exec, - extension_codec, + codec, + proto_converter, ); } let mut buf: Vec = vec![]; - match extension_codec.try_encode(Arc::clone(&plan_clone), &mut buf) { + match codec.try_encode(Arc::clone(&plan_clone), &mut buf, proto_converter) { Ok(_) => { let inputs: Vec = plan_clone .children() .into_iter() .cloned() .map(|i| { - protobuf::PhysicalPlanNode::try_from_physical_plan( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( i, - extension_codec, + codec, + proto_converter, ) }) .collect::>()?; @@ -505,7 +582,8 @@ impl protobuf::PhysicalPlanNode { explain: &protobuf::ExplainExecNode, _ctx: &TaskContext, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { Ok(Arc::new(ExplainExec::new( Arc::new(explain.schema.as_ref().unwrap().try_into()?), @@ -523,21 +601,22 @@ impl protobuf::PhysicalPlanNode { projection: &protobuf::ProjectionExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&projection.input, ctx, extension_codec)?; + into_physical_plan(&projection.input, ctx, codec, proto_converter)?; let exprs = projection .expr .iter() .zip(projection.expr_name.iter()) .map(|(expr, name)| { Ok(( - parse_physical_expr( + proto_converter.proto_to_physical_expr( expr, ctx, input.schema().as_ref(), - extension_codec, + codec, )?, name.to_string(), )) @@ -555,16 +634,22 @@ impl protobuf::PhysicalPlanNode { filter: &protobuf::FilterExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&filter.input, ctx, extension_codec)?; + into_physical_plan(&filter.input, ctx, codec, proto_converter)?; let predicate = filter .expr .as_ref() .map(|expr| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter.proto_to_physical_expr( + expr, + ctx, + input.schema().as_ref(), + codec, + ) }) .transpose()? .ok_or_else(|| { @@ -603,7 +688,8 @@ impl protobuf::PhysicalPlanNode { scan: &protobuf::CsvScanExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let escape = if let Some(protobuf::csv_scan_exec_node::OptionalEscape::Escape(escape)) = @@ -644,7 +730,8 @@ impl protobuf::PhysicalPlanNode { let conf = FileScanConfigBuilder::from(parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), ctx, - extension_codec, + codec, + proto_converter, source, )?) .with_file_compression_type(FileCompressionType::UNCOMPRESSED) @@ -657,14 +744,16 @@ impl protobuf::PhysicalPlanNode { scan: &protobuf::JsonScanExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let base_conf = scan.base_conf.as_ref().unwrap(); let table_schema = parse_table_schema_from_proto(base_conf)?; let scan_conf = parse_protobuf_file_scan_config( base_conf, ctx, - extension_codec, + codec, + proto_converter, Arc::new(JsonSource::new(table_schema)), )?; Ok(DataSourceExec::from_data_source(scan_conf)) @@ -675,7 +764,8 @@ impl protobuf::PhysicalPlanNode { &self, scan: &protobuf::ParquetScanExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { #[cfg(feature = "parquet")] { @@ -692,7 +782,7 @@ impl protobuf::PhysicalPlanNode { .iter() .map(|&i| schema.field(i as usize).clone()) .collect(); - Arc::new(arrow::datatypes::Schema::new(projected_fields)) + Arc::new(Schema::new(projected_fields)) } else { schema }; @@ -701,11 +791,11 @@ impl protobuf::PhysicalPlanNode { .predicate .as_ref() .map(|expr| { - parse_physical_expr( + proto_converter.proto_to_physical_expr( expr, ctx, predicate_schema.as_ref(), - extension_codec, + codec, ) }) .transpose()?; @@ -727,7 +817,8 @@ impl protobuf::PhysicalPlanNode { let base_config = parse_protobuf_file_scan_config( base_conf, ctx, - extension_codec, + codec, + proto_converter, Arc::new(source), )?; Ok(DataSourceExec::from_data_source(base_config)) @@ -743,7 +834,8 @@ impl protobuf::PhysicalPlanNode { &self, scan: &protobuf::AvroScanExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { #[cfg(feature = "avro")] { @@ -752,7 +844,8 @@ impl protobuf::PhysicalPlanNode { let conf = parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), ctx, - extension_codec, + codec, + proto_converter, Arc::new(AvroSource::new(table_schema)), )?; Ok(DataSourceExec::from_data_source(conf)) @@ -767,7 +860,8 @@ impl protobuf::PhysicalPlanNode { scan: &protobuf::MemoryScanExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let partitions = scan .partitions @@ -797,7 +891,8 @@ impl protobuf::PhysicalPlanNode { &ordering.physical_sort_expr_nodes, ctx, &schema, - extension_codec, + codec, + proto_converter, )?; sort_information.extend(LexOrdering::new(sort_exprs)); } @@ -816,10 +911,11 @@ impl protobuf::PhysicalPlanNode { coalesce_batches: &protobuf::CoalesceBatchesExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&coalesce_batches.input, ctx, extension_codec)?; + into_physical_plan(&coalesce_batches.input, ctx, codec, proto_converter)?; Ok(Arc::new( CoalesceBatchesExec::new(input, coalesce_batches.target_batch_size as usize) .with_fetch(coalesce_batches.fetch.map(|f| f as usize)), @@ -831,10 +927,11 @@ impl protobuf::PhysicalPlanNode { merge: &protobuf::CoalescePartitionsExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&merge.input, ctx, extension_codec)?; + into_physical_plan(&merge.input, ctx, codec, proto_converter)?; Ok(Arc::new( CoalescePartitionsExec::new(input) .with_fetch(merge.fetch.map(|f| f as usize)), @@ -846,15 +943,17 @@ impl protobuf::PhysicalPlanNode { repart: &protobuf::RepartitionExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&repart.input, ctx, extension_codec)?; + into_physical_plan(&repart.input, ctx, codec, proto_converter)?; let partitioning = parse_protobuf_partitioning( repart.partitioning.as_ref(), ctx, input.schema().as_ref(), - extension_codec, + codec, + proto_converter, )?; Ok(Arc::new(RepartitionExec::try_new( input, @@ -867,10 +966,11 @@ impl protobuf::PhysicalPlanNode { limit: &protobuf::GlobalLimitExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&limit.input, ctx, extension_codec)?; + into_physical_plan(&limit.input, ctx, codec, proto_converter)?; let fetch = if limit.fetch >= 0 { Some(limit.fetch as usize) } else { @@ -888,10 +988,11 @@ impl protobuf::PhysicalPlanNode { limit: &protobuf::LocalLimitExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&limit.input, ctx, extension_codec)?; + into_physical_plan(&limit.input, ctx, codec, proto_converter)?; Ok(Arc::new(LocalLimitExec::new(input, limit.fetch as usize))) } @@ -900,10 +1001,11 @@ impl protobuf::PhysicalPlanNode { window_agg: &protobuf::WindowAggExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&window_agg.input, ctx, extension_codec)?; + into_physical_plan(&window_agg.input, ctx, codec, proto_converter)?; let input_schema = input.schema(); let physical_window_expr: Vec> = window_agg @@ -914,7 +1016,8 @@ impl protobuf::PhysicalPlanNode { window_expr, ctx, input_schema.as_ref(), - extension_codec, + codec, + proto_converter, ) }) .collect::, _>>()?; @@ -923,7 +1026,12 @@ impl protobuf::PhysicalPlanNode { .partition_keys .iter() .map(|expr| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter.proto_to_physical_expr( + expr, + ctx, + input.schema().as_ref(), + codec, + ) }) .collect::>>>()?; @@ -958,10 +1066,11 @@ impl protobuf::PhysicalPlanNode { hash_agg: &protobuf::AggregateExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&hash_agg.input, ctx, extension_codec)?; + into_physical_plan(&hash_agg.input, ctx, codec, proto_converter)?; let mode = protobuf::AggregateMode::try_from(hash_agg.mode).map_err(|_| { proto_error(format!( "Received a AggregateNode message with unknown AggregateMode {}", @@ -985,7 +1094,8 @@ impl protobuf::PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter + .proto_to_physical_expr(expr, ctx, input.schema().as_ref(), codec) .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -995,7 +1105,8 @@ impl protobuf::PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter + .proto_to_physical_expr(expr, ctx, input.schema().as_ref(), codec) .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -1024,7 +1135,12 @@ impl protobuf::PhysicalPlanNode { expr.expr .as_ref() .map(|e| { - parse_physical_expr(e, ctx, &physical_schema, extension_codec) + proto_converter.proto_to_physical_expr( + e, + ctx, + &physical_schema, + codec, + ) }) .transpose() }) @@ -1045,11 +1161,11 @@ impl protobuf::PhysicalPlanNode { .expr .iter() .map(|e| { - parse_physical_expr( + proto_converter.proto_to_physical_expr( e, ctx, &physical_schema, - extension_codec, + codec, ) }) .collect::>>()?; @@ -1061,7 +1177,8 @@ impl protobuf::PhysicalPlanNode { e, ctx, &physical_schema, - extension_codec, + codec, + proto_converter, ) }) .collect::>()?; @@ -1071,11 +1188,11 @@ impl protobuf::PhysicalPlanNode { .map(|func| match func { AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = match &agg_node.fun_definition { - Some(buf) => extension_codec - .try_decode_udaf(udaf_name, buf)?, + Some(buf) => { + codec.try_decode_udaf(udaf_name, buf)? + } None => ctx.udaf(udaf_name).or_else(|_| { - extension_codec - .try_decode_udaf(udaf_name, &[]) + codec.try_decode_udaf(udaf_name, &[]) })?, }; @@ -1107,6 +1224,7 @@ impl protobuf::PhysicalPlanNode { .as_ref() .map(|lit_value| lit_value.limit as usize); + let physical_schema_ref = Arc::clone(&physical_schema); let agg = AggregateExec::try_new( agg_mode, PhysicalGroupBy::new(group_expr, null_expr, groups, has_grouping_set), @@ -1118,6 +1236,24 @@ impl protobuf::PhysicalPlanNode { let agg = agg.with_limit(limit); + let agg = if let Some(dynamic_filter_proto) = &hash_agg.dynamic_filter { + let dynamic_filter_expr = proto_converter.proto_to_physical_expr( + dynamic_filter_proto, + ctx, + physical_schema_ref.as_ref(), + codec, + )?; + if let Ok(df) = (dynamic_filter_expr as Arc) + .downcast::() + { + agg.with_dynamic_filter(df)? + } else { + agg + } + } else { + agg + }; + Ok(Arc::new(agg)) } @@ -1126,29 +1262,30 @@ impl protobuf::PhysicalPlanNode { hashjoin: &protobuf::HashJoinExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let left: Arc = - into_physical_plan(&hashjoin.left, ctx, extension_codec)?; + into_physical_plan(&hashjoin.left, ctx, codec, proto_converter)?; let right: Arc = - into_physical_plan(&hashjoin.right, ctx, extension_codec)?; + into_physical_plan(&hashjoin.right, ctx, codec, proto_converter)?; let left_schema = left.schema(); let right_schema = right.schema(); let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = hashjoin .on .iter() .map(|col| { - let left = parse_physical_expr( + let left = proto_converter.proto_to_physical_expr( &col.left.clone().unwrap(), ctx, left_schema.as_ref(), - extension_codec, + codec, )?; - let right = parse_physical_expr( + let right = proto_converter.proto_to_physical_expr( &col.right.clone().unwrap(), ctx, right_schema.as_ref(), - extension_codec, + codec, )?; Ok((left, right)) }) @@ -1177,12 +1314,12 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? .try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter.proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, ctx, &schema, - extension_codec, + codec, )?; let column_indices = f.column_indices .iter() @@ -1216,6 +1353,24 @@ impl protobuf::PhysicalPlanNode { protobuf::PartitionMode::Partitioned => PartitionMode::Partitioned, protobuf::PartitionMode::Auto => PartitionMode::Auto, }; + let dynamic_filter_routing_mode = + protobuf::HashJoinDynamicFilterRoutingMode::try_from( + hashjoin.dynamic_filter_routing_mode, + ) + .map_err(|_| { + proto_error(format!( + "Received a HashJoinNode message with unknown HashJoinDynamicFilterRoutingMode {}", + hashjoin.dynamic_filter_routing_mode + )) + })?; + let dynamic_filter_routing_mode = match dynamic_filter_routing_mode { + protobuf::HashJoinDynamicFilterRoutingMode::CaseHash => { + DynamicFilterRoutingMode::CaseHash + } + protobuf::HashJoinDynamicFilterRoutingMode::PartitionIndex => { + DynamicFilterRoutingMode::PartitionIndex + } + }; let projection = if !hashjoin.projection.is_empty() { Some( hashjoin @@ -1227,7 +1382,7 @@ impl protobuf::PhysicalPlanNode { } else { None }; - Ok(Arc::new(HashJoinExec::try_new( + let mut hash_join = HashJoinExec::try_new( left, right, on, @@ -1236,7 +1391,25 @@ impl protobuf::PhysicalPlanNode { projection, partition_mode, null_equality.into(), - )?)) + )?; + + if let Some(dynamic_filter_proto) = &hashjoin.dynamic_filter { + let dynamic_filter_expr = proto_converter.proto_to_physical_expr( + dynamic_filter_proto, + ctx, + right_schema.as_ref(), + codec, + )?; + if let Ok(df) = (dynamic_filter_expr as Arc) + .downcast::() + { + hash_join = hash_join.with_dynamic_filter(df)?; + } + } + hash_join = + hash_join.with_dynamic_filter_routing_mode(dynamic_filter_routing_mode); + + Ok(Arc::new(hash_join)) } fn try_into_symmetric_hash_join_physical_plan( @@ -1244,27 +1417,28 @@ impl protobuf::PhysicalPlanNode { sym_join: &protobuf::SymmetricHashJoinExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let left = into_physical_plan(&sym_join.left, ctx, extension_codec)?; - let right = into_physical_plan(&sym_join.right, ctx, extension_codec)?; + let left = into_physical_plan(&sym_join.left, ctx, codec, proto_converter)?; + let right = into_physical_plan(&sym_join.right, ctx, codec, proto_converter)?; let left_schema = left.schema(); let right_schema = right.schema(); let on = sym_join .on .iter() .map(|col| { - let left = parse_physical_expr( + let left = proto_converter.proto_to_physical_expr( &col.left.clone().unwrap(), ctx, left_schema.as_ref(), - extension_codec, + codec, )?; - let right = parse_physical_expr( + let right = proto_converter.proto_to_physical_expr( &col.right.clone().unwrap(), ctx, right_schema.as_ref(), - extension_codec, + codec, )?; Ok((left, right)) }) @@ -1293,12 +1467,12 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? .try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter.proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, ctx, &schema, - extension_codec, + codec, )?; let column_indices = f.column_indices .iter() @@ -1324,7 +1498,8 @@ impl protobuf::PhysicalPlanNode { &sym_join.left_sort_exprs, ctx, &left_schema, - extension_codec, + codec, + proto_converter, )?; let left_sort_exprs = LexOrdering::new(left_sort_exprs); @@ -1332,7 +1507,8 @@ impl protobuf::PhysicalPlanNode { &sym_join.right_sort_exprs, ctx, &right_schema, - extension_codec, + codec, + proto_converter, )?; let right_sort_exprs = LexOrdering::new(right_sort_exprs); @@ -1372,11 +1548,12 @@ impl protobuf::PhysicalPlanNode { union: &protobuf::UnionExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let mut inputs: Vec> = vec![]; for input in &union.inputs { - inputs.push(input.try_into_physical_plan(ctx, extension_codec)?); + inputs.push(proto_converter.proto_to_execution_plan(ctx, codec, input)?); } UnionExec::try_new(inputs) } @@ -1386,11 +1563,12 @@ impl protobuf::PhysicalPlanNode { interleave: &protobuf::InterleaveExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let mut inputs: Vec> = vec![]; for input in &interleave.inputs { - inputs.push(input.try_into_physical_plan(ctx, extension_codec)?); + inputs.push(proto_converter.proto_to_execution_plan(ctx, codec, input)?); } Ok(Arc::new(InterleaveExec::try_new(inputs)?)) } @@ -1400,12 +1578,13 @@ impl protobuf::PhysicalPlanNode { crossjoin: &protobuf::CrossJoinExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let left: Arc = - into_physical_plan(&crossjoin.left, ctx, extension_codec)?; + into_physical_plan(&crossjoin.left, ctx, codec, proto_converter)?; let right: Arc = - into_physical_plan(&crossjoin.right, ctx, extension_codec)?; + into_physical_plan(&crossjoin.right, ctx, codec, proto_converter)?; Ok(Arc::new(CrossJoinExec::new(left, right))) } @@ -1414,7 +1593,8 @@ impl protobuf::PhysicalPlanNode { empty: &protobuf::EmptyExecNode, _ctx: &TaskContext, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let schema = Arc::new(convert_required!(empty.schema)?); Ok(Arc::new(EmptyExec::new(schema))) @@ -1425,7 +1605,7 @@ impl protobuf::PhysicalPlanNode { placeholder: &protobuf::PlaceholderRowExecNode, _ctx: &TaskContext, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, ) -> Result> { let schema = Arc::new(convert_required!(placeholder.schema)?); Ok(Arc::new(PlaceholderRowExec::new(schema))) @@ -1436,9 +1616,10 @@ impl protobuf::PhysicalPlanNode { sort: &protobuf::SortExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sort.input, ctx, extension_codec)?; + let input = into_physical_plan(&sort.input, ctx, codec, proto_converter)?; let exprs = sort .expr .iter() @@ -1459,7 +1640,7 @@ impl protobuf::PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec)?, + expr: proto_converter.proto_to_physical_expr(expr, ctx, input.schema().as_ref(), codec)?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -1480,6 +1661,24 @@ impl protobuf::PhysicalPlanNode { .with_fetch(fetch) .with_preserve_partitioning(sort.preserve_partitioning); + let new_sort = if let Some(dynamic_filter_proto) = &sort.dynamic_filter { + let dynamic_filter_expr = proto_converter.proto_to_physical_expr( + dynamic_filter_proto, + ctx, + new_sort.input().schema().as_ref(), + codec, + )?; + if let Ok(df) = (dynamic_filter_expr as Arc) + .downcast::() + { + new_sort.with_dynamic_filter(df)? + } else { + new_sort + } + } else { + new_sort + }; + Ok(Arc::new(new_sort)) } @@ -1488,9 +1687,10 @@ impl protobuf::PhysicalPlanNode { sort: &protobuf::SortPreservingMergeExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sort.input, ctx, extension_codec)?; + let input = into_physical_plan(&sort.input, ctx, codec, proto_converter)?; let exprs = sort .expr .iter() @@ -1511,11 +1711,11 @@ impl protobuf::PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr( + expr: proto_converter.proto_to_physical_expr( expr, ctx, input.schema().as_ref(), - extension_codec, + codec, )?, options: SortOptions { descending: !sort_expr.asc, @@ -1541,16 +1741,17 @@ impl protobuf::PhysicalPlanNode { extension: &protobuf::PhysicalExtensionNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let inputs: Vec> = extension .inputs .iter() - .map(|i| i.try_into_physical_plan(ctx, extension_codec)) + .map(|i| proto_converter.proto_to_execution_plan(ctx, codec, i)) .collect::>()?; let extension_node = - extension_codec.try_decode(extension.node.as_slice(), &inputs, ctx)?; + codec.try_decode(extension.node.as_slice(), &inputs, ctx, proto_converter)?; Ok(extension_node) } @@ -1560,12 +1761,13 @@ impl protobuf::PhysicalPlanNode { join: &protobuf::NestedLoopJoinExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let left: Arc = - into_physical_plan(&join.left, ctx, extension_codec)?; + into_physical_plan(&join.left, ctx, codec, proto_converter)?; let right: Arc = - into_physical_plan(&join.right, ctx, extension_codec)?; + into_physical_plan(&join.right, ctx, codec, proto_converter)?; let join_type = protobuf::JoinType::try_from(join.join_type).map_err(|_| { proto_error(format!( "Received a NestedLoopJoinExecNode message with unknown JoinType {}", @@ -1582,12 +1784,12 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? .try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter.proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, ctx, &schema, - extension_codec, + codec, )?; let column_indices = f.column_indices .iter() @@ -1634,10 +1836,11 @@ impl protobuf::PhysicalPlanNode { analyze: &protobuf::AnalyzeExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&analyze.input, ctx, extension_codec)?; + into_physical_plan(&analyze.input, ctx, codec, proto_converter)?; Ok(Arc::new(AnalyzeExec::new( analyze.verbose, analyze.show_statistics, @@ -1652,9 +1855,10 @@ impl protobuf::PhysicalPlanNode { sink: &protobuf::JsonSinkExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + let input = into_physical_plan(&sink.input, ctx, codec, proto_converter)?; let data_sink: JsonSink = sink .sink @@ -1670,7 +1874,8 @@ impl protobuf::PhysicalPlanNode { &collection.physical_sort_expr_nodes, ctx, &sink_schema, - extension_codec, + codec, + proto_converter, ) .map(|sort_exprs| { LexRequirement::new(sort_exprs.into_iter().map(Into::into)) @@ -1690,9 +1895,10 @@ impl protobuf::PhysicalPlanNode { sink: &protobuf::CsvSinkExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + let input = into_physical_plan(&sink.input, ctx, codec, proto_converter)?; let data_sink: CsvSink = sink .sink @@ -1708,7 +1914,8 @@ impl protobuf::PhysicalPlanNode { &collection.physical_sort_expr_nodes, ctx, &sink_schema, - extension_codec, + codec, + proto_converter, ) .map(|sort_exprs| { LexRequirement::new(sort_exprs.into_iter().map(Into::into)) @@ -1729,11 +1936,12 @@ impl protobuf::PhysicalPlanNode { sink: &protobuf::ParquetSinkExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { #[cfg(feature = "parquet")] { - let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + let input = into_physical_plan(&sink.input, ctx, codec, proto_converter)?; let data_sink: ParquetSink = sink .sink @@ -1749,7 +1957,8 @@ impl protobuf::PhysicalPlanNode { &collection.physical_sort_expr_nodes, ctx, &sink_schema, - extension_codec, + codec, + proto_converter, ) .map(|sort_exprs| { LexRequirement::new(sort_exprs.into_iter().map(Into::into)) @@ -1772,9 +1981,10 @@ impl protobuf::PhysicalPlanNode { unnest: &protobuf::UnnestExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&unnest.input, ctx, extension_codec)?; + let input = into_physical_plan(&unnest.input, ctx, codec, proto_converter)?; Ok(Arc::new(UnnestExec::new( input, @@ -1803,11 +2013,12 @@ impl protobuf::PhysicalPlanNode { sort_join: &SortMergeJoinExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let left = into_physical_plan(&sort_join.left, ctx, extension_codec)?; + let left = into_physical_plan(&sort_join.left, ctx, codec, proto_converter)?; let left_schema = left.schema(); - let right = into_physical_plan(&sort_join.right, ctx, extension_codec)?; + let right = into_physical_plan(&sort_join.right, ctx, codec, proto_converter)?; let right_schema = right.schema(); let filter = sort_join @@ -1820,13 +2031,13 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? .try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter.proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, ctx, &schema, - extension_codec, + codec, )?; let column_indices = f .column_indices @@ -1883,17 +2094,17 @@ impl protobuf::PhysicalPlanNode { .on .iter() .map(|col| { - let left = parse_physical_expr( + let left = proto_converter.proto_to_physical_expr( &col.left.clone().unwrap(), ctx, left_schema.as_ref(), - extension_codec, + codec, )?; - let right = parse_physical_expr( + let right = proto_converter.proto_to_physical_expr( &col.right.clone().unwrap(), ctx, right_schema.as_ref(), - extension_codec, + codec, )?; Ok((left, right)) }) @@ -1980,9 +2191,10 @@ impl protobuf::PhysicalPlanNode { field_stream: &protobuf::CooperativeExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&field_stream.input, ctx, extension_codec)?; + let input = into_physical_plan(&field_stream.input, ctx, codec, proto_converter)?; Ok(Arc::new(CooperativeExec::new(input))) } @@ -1990,10 +2202,11 @@ impl protobuf::PhysicalPlanNode { &self, async_func: &protobuf::AsyncFuncExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&async_func.input, ctx, extension_codec)?; + into_physical_plan(&async_func.input, ctx, codec, proto_converter)?; if async_func.async_exprs.len() != async_func.async_expr_names.len() { return internal_err!( @@ -2006,11 +2219,11 @@ impl protobuf::PhysicalPlanNode { .iter() .zip(async_func.async_expr_names.iter()) .map(|(expr, name)| { - let physical_expr = parse_physical_expr( + let physical_expr = proto_converter.proto_to_physical_expr( expr, ctx, input.schema().as_ref(), - extension_codec, + codec, )?; Ok(Arc::new(AsyncFuncExpr::try_new( @@ -2026,7 +2239,7 @@ impl protobuf::PhysicalPlanNode { fn try_from_explain_exec( exec: &ExplainExec, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, ) -> Result { Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Explain( @@ -2045,16 +2258,20 @@ impl protobuf::PhysicalPlanNode { fn try_from_projection_exec( exec: &ProjectionExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let expr = exec .expr() .iter() - .map(|proj_expr| serialize_physical_expr(&proj_expr.expr, extension_codec)) + .map(|proj_expr| { + proto_converter.physical_expr_to_proto(&proj_expr.expr, codec) + }) .collect::>>()?; let expr_name = exec .expr() @@ -2074,11 +2291,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_analyze_exec( exec: &AnalyzeExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Analyze(Box::new( @@ -2094,20 +2313,22 @@ impl protobuf::PhysicalPlanNode { fn try_from_filter_exec( exec: &FilterExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Filter(Box::new( protobuf::FilterExecNode { input: Some(Box::new(input)), - expr: Some(serialize_physical_expr( - exec.predicate(), - extension_codec, - )?), + expr: Some( + proto_converter + .physical_expr_to_proto(exec.predicate(), codec)?, + ), default_filter_selectivity: exec.default_selectivity() as u32, projection: exec.projection().as_ref().map_or_else(Vec::new, |v| { v.iter().map(|x| *x as u32).collect::>() @@ -2119,11 +2340,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_global_limit_exec( limit: &GlobalLimitExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( limit.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { @@ -2142,11 +2365,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_local_limit_exec( limit: &LocalLimitExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( limit.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::LocalLimit(Box::new( @@ -2160,22 +2385,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_hash_join_exec( exec: &HashJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; let on: Vec = exec .on() .iter() .map(|tuple| { - let l = serialize_physical_expr(&tuple.0, extension_codec)?; - let r = serialize_physical_expr(&tuple.1, extension_codec)?; + let l = proto_converter.physical_expr_to_proto(&tuple.0, codec)?; + let r = proto_converter.physical_expr_to_proto(&tuple.1, codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -2189,7 +2417,7 @@ impl protobuf::PhysicalPlanNode { .as_ref() .map(|f| { let expression = - serialize_physical_expr(f.expression(), extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), codec)?; let column_indices = f .column_indices() .iter() @@ -2216,6 +2444,23 @@ impl protobuf::PhysicalPlanNode { PartitionMode::Auto => protobuf::PartitionMode::Auto, }; + let dynamic_filter = exec + .dynamic_filter() + .map(|df| { + let df_expr: Arc = + Arc::clone(df) as Arc; + proto_converter.physical_expr_to_proto(&df_expr, codec) + }) + .transpose()?; + let dynamic_filter_routing_mode = match exec.dynamic_filter_routing_mode() { + DynamicFilterRoutingMode::CaseHash => { + protobuf::HashJoinDynamicFilterRoutingMode::CaseHash + } + DynamicFilterRoutingMode::PartitionIndex => { + protobuf::HashJoinDynamicFilterRoutingMode::PartitionIndex + } + }; + Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new( protobuf::HashJoinExecNode { @@ -2229,6 +2474,8 @@ impl protobuf::PhysicalPlanNode { projection: exec.projection.as_ref().map_or_else(Vec::new, |v| { v.iter().map(|x| *x as u32).collect::>() }), + dynamic_filter, + dynamic_filter_routing_mode: dynamic_filter_routing_mode.into(), }, ))), }) @@ -2236,22 +2483,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_symmetric_hash_join_exec( exec: &SymmetricHashJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; let on = exec .on() .iter() .map(|tuple| { - let l = serialize_physical_expr(&tuple.0, extension_codec)?; - let r = serialize_physical_expr(&tuple.1, extension_codec)?; + let l = proto_converter.physical_expr_to_proto(&tuple.0, codec)?; + let r = proto_converter.physical_expr_to_proto(&tuple.1, codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -2265,7 +2515,7 @@ impl protobuf::PhysicalPlanNode { .as_ref() .map(|f| { let expression = - serialize_physical_expr(f.expression(), extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), codec)?; let column_indices = f .column_indices() .iter() @@ -2302,10 +2552,10 @@ impl protobuf::PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -2322,10 +2572,10 @@ impl protobuf::PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -2354,22 +2604,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_sort_merge_join_exec( exec: &SortMergeJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; let on = exec .on() .iter() .map(|tuple| { - let l = serialize_physical_expr(&tuple.0, extension_codec)?; - let r = serialize_physical_expr(&tuple.1, extension_codec)?; + let l = proto_converter.physical_expr_to_proto(&tuple.0, codec)?; + let r = proto_converter.physical_expr_to_proto(&tuple.1, codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -2383,7 +2636,7 @@ impl protobuf::PhysicalPlanNode { .as_ref() .map(|f| { let expression = - serialize_physical_expr(f.expression(), extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), codec)?; let column_indices = f .column_indices() .iter() @@ -2423,7 +2676,7 @@ impl protobuf::PhysicalPlanNode { Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::SortMergeJoin(Box::new( - protobuf::SortMergeJoinExecNode { + SortMergeJoinExecNode { left: Some(Box::new(left)), right: Some(Box::new(right)), on, @@ -2438,15 +2691,18 @@ impl protobuf::PhysicalPlanNode { fn try_from_cross_join_exec( exec: &CrossJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CrossJoin(Box::new( @@ -2460,7 +2716,8 @@ impl protobuf::PhysicalPlanNode { fn try_from_aggregate_exec( exec: &AggregateExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let groups: Vec = exec .group_expr() @@ -2480,13 +2737,15 @@ impl protobuf::PhysicalPlanNode { let filter = exec .filter_expr() .iter() - .map(|expr| serialize_maybe_filter(expr.to_owned(), extension_codec)) + .map(|expr| serialize_maybe_filter(expr.to_owned(), codec, proto_converter)) .collect::>>()?; let agg = exec .aggr_expr() .iter() - .map(|expr| serialize_physical_aggr_expr(expr.to_owned(), extension_codec)) + .map(|expr| { + serialize_physical_aggr_expr(expr.to_owned(), codec, proto_converter) + }) .collect::>>()?; let agg_names = exec @@ -2505,23 +2764,24 @@ impl protobuf::PhysicalPlanNode { } }; let input_schema = exec.input_schema(); - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let null_expr = exec .group_expr() .null_expr() .iter() - .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) + .map(|expr| proto_converter.physical_expr_to_proto(&expr.0, codec)) .collect::>>()?; let group_expr = exec .group_expr() .expr() .iter() - .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) + .map(|expr| proto_converter.physical_expr_to_proto(&expr.0, codec)) .collect::>>()?; let limit = exec.limit().map(|value| protobuf::AggLimit { @@ -2543,6 +2803,14 @@ impl protobuf::PhysicalPlanNode { groups, limit, has_grouping_set: exec.group_expr().has_grouping_set(), + dynamic_filter: exec + .dynamic_filter() + .map(|df| { + let df_expr: Arc = + Arc::clone(df) as Arc; + proto_converter.physical_expr_to_proto(&df_expr, codec) + }) + .transpose()?, }, ))), }) @@ -2550,7 +2818,7 @@ impl protobuf::PhysicalPlanNode { fn try_from_empty_exec( empty: &EmptyExec, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, ) -> Result { let schema = empty.schema().as_ref().try_into()?; Ok(protobuf::PhysicalPlanNode { @@ -2562,7 +2830,7 @@ impl protobuf::PhysicalPlanNode { fn try_from_placeholder_row_exec( empty: &PlaceholderRowExec, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, ) -> Result { let schema = empty.schema().as_ref().try_into()?; Ok(protobuf::PhysicalPlanNode { @@ -2576,11 +2844,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_coalesce_batches_exec( coalesce_batches: &CoalesceBatchesExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( coalesce_batches.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CoalesceBatches(Box::new( @@ -2595,7 +2865,8 @@ impl protobuf::PhysicalPlanNode { fn try_from_data_source_exec( data_source_exec: &DataSourceExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let data_source = data_source_exec.data_source(); if let Some(maybe_csv) = data_source.as_any().downcast_ref::() { @@ -2606,7 +2877,8 @@ impl protobuf::PhysicalPlanNode { protobuf::CsvScanExecNode { base_conf: Some(serialize_file_scan_config( maybe_csv, - extension_codec, + codec, + proto_converter, )?), has_header: csv_config.has_header(), delimiter: byte_to_string( @@ -2647,7 +2919,8 @@ impl protobuf::PhysicalPlanNode { protobuf::JsonScanExecNode { base_conf: Some(serialize_file_scan_config( scan_conf, - extension_codec, + codec, + proto_converter, )?), }, )), @@ -2661,14 +2934,15 @@ impl protobuf::PhysicalPlanNode { { let predicate = conf .filter() - .map(|pred| serialize_physical_expr(&pred, extension_codec)) + .map(|pred| proto_converter.physical_expr_to_proto(&pred, codec)) .transpose()?; return Ok(Some(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetScan( protobuf::ParquetScanExecNode { base_conf: Some(serialize_file_scan_config( maybe_parquet, - extension_codec, + codec, + proto_converter, )?), predicate, parquet_options: Some(conf.table_parquet_options().try_into()?), @@ -2686,7 +2960,8 @@ impl protobuf::PhysicalPlanNode { protobuf::AvroScanExecNode { base_conf: Some(serialize_file_scan_config( maybe_avro, - extension_codec, + codec, + proto_converter, )?), }, )), @@ -2719,7 +2994,8 @@ impl protobuf::PhysicalPlanNode { .map(|ordering| { let sort_exprs = serialize_physical_sort_exprs( ordering.to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok::<_, DataFusionError>(protobuf::PhysicalSortExprNodeCollection { physical_sort_expr_nodes: sort_exprs, @@ -2746,11 +3022,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_coalesce_partitions_exec( exec: &CoalescePartitionsExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Merge(Box::new( @@ -2764,15 +3042,17 @@ impl protobuf::PhysicalPlanNode { fn try_from_repartition_exec( exec: &RepartitionExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let pb_partitioning = - serialize_partitioning(exec.partitioning(), extension_codec)?; + serialize_partitioning(exec.partitioning(), codec, proto_converter)?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Repartition(Box::new( @@ -2786,29 +3066,36 @@ impl protobuf::PhysicalPlanNode { fn try_from_sort_exec( exec: &SortExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input().to_owned(), - extension_codec, - )?; + let input = proto_converter.execution_plan_to_proto(exec.input(), codec)?; let expr = exec .expr() .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(ExprType::Sort(sort_expr)), }) }) .collect::>>()?; + let dynamic_filter = exec + .dynamic_filter() + .map(|df| { + let df_expr: Arc = df as Arc; + proto_converter.physical_expr_to_proto(&df_expr, codec) + }) + .transpose()?; + Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Sort(Box::new( protobuf::SortExecNode { @@ -2819,6 +3106,7 @@ impl protobuf::PhysicalPlanNode { _ => -1, }, preserve_partitioning: exec.preserve_partitioning(), + dynamic_filter, }, ))), }) @@ -2826,14 +3114,18 @@ impl protobuf::PhysicalPlanNode { fn try_from_union_exec( union: &UnionExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let mut inputs: Vec = vec![]; for input in union.inputs() { - inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( - input.to_owned(), - extension_codec, - )?); + inputs.push( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + input.to_owned(), + codec, + proto_converter, + )?, + ); } Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Union(protobuf::UnionExecNode { @@ -2844,14 +3136,18 @@ impl protobuf::PhysicalPlanNode { fn try_from_interleave_exec( interleave: &InterleaveExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let mut inputs: Vec = vec![]; for input in interleave.inputs() { - inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( - input.to_owned(), - extension_codec, - )?); + inputs.push( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + input.to_owned(), + codec, + proto_converter, + )?, + ); } Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Interleave( @@ -2862,25 +3158,28 @@ impl protobuf::PhysicalPlanNode { fn try_from_sort_preserving_merge_exec( exec: &SortPreservingMergeExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let expr = exec .expr() .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(ExprType::Sort(sort_expr)), }) }) @@ -2898,15 +3197,18 @@ impl protobuf::PhysicalPlanNode { fn try_from_nested_loop_join_exec( exec: &NestedLoopJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); @@ -2915,7 +3217,7 @@ impl protobuf::PhysicalPlanNode { .as_ref() .map(|f| { let expression = - serialize_physical_expr(f.expression(), extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), codec)?; let column_indices = f .column_indices() .iter() @@ -2953,23 +3255,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_window_agg_exec( exec: &WindowAggExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let window_expr = exec .window_expr() .iter() - .map(|e| serialize_physical_window_expr(e, extension_codec)) + .map(|e| serialize_physical_window_expr(e, codec, proto_converter)) .collect::>>()?; let partition_keys = exec .partition_keys() .iter() - .map(|e| serialize_physical_expr(e, extension_codec)) + .map(|e| proto_converter.physical_expr_to_proto(e, codec)) .collect::>>()?; Ok(protobuf::PhysicalPlanNode { @@ -2986,23 +3290,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_bounded_window_agg_exec( exec: &BoundedWindowAggExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let window_expr = exec .window_expr() .iter() - .map(|e| serialize_physical_window_expr(e, extension_codec)) + .map(|e| serialize_physical_window_expr(e, codec, proto_converter)) .collect::>>()?; let partition_keys = exec .partition_keys() .iter() - .map(|e| serialize_physical_expr(e, extension_codec)) + .map(|e| proto_converter.physical_expr_to_proto(e, codec)) .collect::>>()?; let input_order_mode = match &exec.input_order_mode { @@ -3035,12 +3341,14 @@ impl protobuf::PhysicalPlanNode { fn try_from_data_sink_exec( exec: &DataSinkExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let sort_order = match exec.sort_order() { Some(requirements) => { @@ -3049,10 +3357,10 @@ impl protobuf::PhysicalPlanNode { .map(|requirement| { let expr: PhysicalSortExpr = requirement.to_owned().into(); let sort_expr = protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }; @@ -3112,11 +3420,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_unnest_exec( exec: &UnnestExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { @@ -3145,11 +3455,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_cooperative_exec( exec: &CooperativeExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { @@ -3278,18 +3590,21 @@ impl protobuf::PhysicalPlanNode { fn try_from_async_func_exec( exec: &AsyncFuncExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( Arc::clone(exec.input()), - extension_codec, + codec, + proto_converter, )?; let mut async_exprs = vec![]; let mut async_expr_names = vec![]; for async_expr in exec.async_exprs() { - async_exprs.push(serialize_physical_expr(&async_expr.func, extension_codec)?); + async_exprs + .push(proto_converter.physical_expr_to_proto(&async_expr.func, codec)?); async_expr_names.push(async_expr.name.clone()) } @@ -3319,12 +3634,12 @@ pub trait AsExecutionPlan: Debug + Send + Sync + Clone { &self, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, ) -> Result>; fn try_from_physical_plan( plan: Arc, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, ) -> Result where Self: Sized; @@ -3336,9 +3651,15 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync { buf: &[u8], inputs: &[Arc], ctx: &TaskContext, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result>; - fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()>; + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result<()>; fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { not_impl_err!("PhysicalExtensionCodec is not provided for scalar function {name}") @@ -3392,6 +3713,7 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { _buf: &[u8], _inputs: &[Arc], _ctx: &TaskContext, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { not_impl_err!("PhysicalExtensionCodec is not provided") } @@ -3400,11 +3722,44 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { &self, _node: Arc, _buf: &mut Vec, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result<()> { not_impl_err!("PhysicalExtensionCodec is not provided") } } +/// Controls the conversion of physical plans and expressions to and from their +/// Protobuf variants. Using this trait, users can perform optimizations on the +/// conversion process or collect performance metrics. +pub trait PhysicalProtoConverterExtension { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto: &protobuf::PhysicalPlanNode, + ) -> Result>; + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result; + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result>; + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result; +} + /// DataEncoderTuple captures the position of the encoder /// in the codec list that was used to encode the data and actual encoded data #[derive(Clone, PartialEq, prost::Message)] @@ -3418,6 +3773,310 @@ struct DataEncoderTuple { pub blob: Vec, } +pub struct DefaultPhysicalProtoConverter; +impl PhysicalProtoConverterExtension for DefaultPhysicalProtoConverter { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto: &protobuf::PhysicalPlanNode, + ) -> Result> { + proto.try_into_physical_plan_with_converter(ctx, codec, self) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + codec, + self, + ) + } + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> + where + Self: Sized, + { + // Default implementation calls the free function + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + serialize_physical_expr_with_converter(expr, codec, self) + } +} + +/// Internal serializer that adds expr_id to expressions. +/// Created fresh for each serialization operation. +struct DeduplicatingSerializer { + /// Random salt combined with pointer addresses and process ID to create globally unique expr_ids. + session_id: u64, +} + +impl DeduplicatingSerializer { + fn new() -> Self { + Self { + session_id: rand::random(), + } + } + + fn hash(&self, ptr: u64) -> u64 { + // Hash session_id, pointer address, and process ID together to create expr_id. + // - session_id: random per serializer, prevents collisions when merging serializations + // - ptr: unique address per Arc within a process + // - pid: prevents collisions if serializer is shared across processes + let mut hasher = DefaultHasher::new(); + self.session_id.hash(&mut hasher); + ptr.hash(&mut hasher); + std::process::id().hash(&mut hasher); + hasher.finish() + } +} + +impl PhysicalProtoConverterExtension for DeduplicatingSerializer { + fn proto_to_execution_plan( + &self, + _ctx: &TaskContext, + _codec: &dyn PhysicalExtensionCodec, + _proto: &protobuf::PhysicalPlanNode, + ) -> Result> { + internal_err!("DeduplicatingSerializer cannot deserialize execution plans") + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + codec, + self, + ) + } + + fn proto_to_physical_expr( + &self, + _proto: &protobuf::PhysicalExprNode, + _ctx: &TaskContext, + _input_schema: &Schema, + _codec: &dyn PhysicalExtensionCodec, + ) -> Result> + where + Self: Sized, + { + internal_err!("DeduplicatingSerializer cannot deserialize physical expressions") + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let mut proto = serialize_physical_expr_with_converter(expr, codec, self)?; + // Special case for dynamic filters. Two expressions may live in separate Arcs but + // point to the same inner dynamic filter state. This inner state must be deduplicated. + if let Some(dynamic_filter) = + expr.as_any().downcast_ref::() + { + proto.dynamic_filter_inner_id = Some(self.hash(dynamic_filter.inner_id())) + } + proto.expr_id = Some(self.hash(Arc::as_ptr(expr) as *const () as u64)); + + Ok(proto) + } +} + +/// Internal deserializer that caches expressions by expr_id. +/// Created fresh for each deserialization operation. +#[derive(Default)] +struct DeduplicatingDeserializer { + /// Cache mapping expr_id to deserialized expressions. + cache: RefCell>>, + /// Cache mapping dynamic_filter_inner_id to the first deserialized DynamicFilterPhysicalExpr. + /// This ensures that multiple dynamic filters with the same dynamic_filter_inner_id + /// can share the same inner state after deserialization. + dynamic_filter_cache: RefCell>>, +} + +impl PhysicalProtoConverterExtension for DeduplicatingDeserializer { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto: &protobuf::PhysicalPlanNode, + ) -> Result> { + proto.try_into_physical_plan_with_converter(ctx, codec, self) + } + + fn execution_plan_to_proto( + &self, + _plan: &Arc, + _codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + internal_err!("DeduplicatingDeserializer cannot serialize execution plans") + } + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> + where + Self: Sized, + { + // The entire expr is cached, so re-use it. + if let Some(expr_id) = proto.expr_id + && let Some(cached) = self.cache.borrow().get(&expr_id) + { + return Ok(Arc::clone(cached)); + } + + // Cache miss, we must deserialize the expr. + let mut expr = + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self)?; + + // Check if we need to share inner state with a cached dynamic filter + if let Some(dynamic_filter_id) = proto.dynamic_filter_inner_id { + if let Some(cached_filter) = + self.dynamic_filter_cache.borrow().get(&dynamic_filter_id) + { + // Get the base filter's structure + let Some(cached_df) = cached_filter + .as_any() + .downcast_ref::() + else { + return internal_err!( + "dynamic filter cache returned an expression that is not a DynamicFilterPhysicalExpr" + ); + }; + + // Get the base filter's structure + let dynamic_filter_expr = (expr as Arc).downcast::() + .map_err(|_| internal_datafusion_err!("dynamic_filter_id present in proto, but the expression was not a DynamicFilterPhysicalExpr"))?; + expr = Arc::new(dynamic_filter_expr.new_from_source(cached_df)?) + as Arc; + } else { + // Cache it + self.dynamic_filter_cache + .borrow_mut() + .insert(dynamic_filter_id, Arc::clone(&expr)); + } + }; + + // Cache it if the cache key is available. + if let Some(expr_id) = proto.expr_id { + self.cache.borrow_mut().insert(expr_id, Arc::clone(&expr)); + }; + + Ok(expr) + } + + fn physical_expr_to_proto( + &self, + _expr: &Arc, + _codec: &dyn PhysicalExtensionCodec, + ) -> Result { + internal_err!("DeduplicatingDeserializer cannot serialize physical expressions") + } +} + +/// A proto converter that adds expression deduplication during serialization +/// and deserialization. +/// +/// During serialization, each expression's Arc pointer address is XORed with a +/// random session_id to create a salted `expr_id`. This prevents cross-process +/// collisions when serialized plans are merged. +/// +/// During deserialization, expressions with the same `expr_id` share the same +/// Arc, reducing memory usage for plans with duplicate expressions (e.g., large +/// IN lists) and supporting correctly linking [`DynamicFilterPhysicalExpr`] instances. +/// +/// This converter is stateless - it creates internal serializers/deserializers +/// on demand for each operation. +/// +/// [`DynamicFilterPhysicalExpr`]: https://docs.rs/datafusion-physical-expr/latest/datafusion_physical_expr/expressions/struct.DynamicFilterPhysicalExpr.html +#[derive(Debug, Default, Clone, Copy)] +pub struct DeduplicatingProtoConverter {} + +impl PhysicalProtoConverterExtension for DeduplicatingProtoConverter { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto: &protobuf::PhysicalPlanNode, + ) -> Result> { + let deserializer = DeduplicatingDeserializer::default(); + let plan = + proto.try_into_physical_plan_with_converter(ctx, codec, &deserializer)?; + Ok(plan) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + let serializer = DeduplicatingSerializer::new(); + let proto = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + codec, + &serializer, + )?; + Ok(proto) + } + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> + where + Self: Sized, + { + let deserializer = DeduplicatingDeserializer::default(); + deserializer.proto_to_physical_expr(proto, ctx, input_schema, codec) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let serializer = DeduplicatingSerializer::new(); + serializer.physical_expr_to_proto(expr, codec) + } +} + /// A PhysicalExtensionCodec that tries one of multiple inner codecs /// until one works #[derive(Debug)] @@ -3492,12 +4151,22 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { buf: &[u8], inputs: &[Arc], ctx: &TaskContext, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - self.decode_protobuf(buf, |codec, data| codec.try_decode(data, inputs, ctx)) + self.decode_protobuf(buf, |codec, data| { + codec.try_decode(data, inputs, ctx, proto_converter) + }) } - fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { - self.encode_protobuf(buf, |codec, data| codec.try_encode(Arc::clone(&node), data)) + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result<()> { + self.encode_protobuf(buf, |codec, data| { + codec.try_encode(Arc::clone(&node), data, proto_converter) + }) } fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { @@ -3520,10 +4189,11 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { fn into_physical_plan( node: &Option>, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { if let Some(field) = node { - field.try_into_physical_plan(ctx, extension_codec) + proto_converter.proto_to_execution_plan(ctx, codec, field) } else { Err(proto_error("Missing required field in protobuf")) } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 9558effb8a2a6..fb7c5f7b736b6 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -24,8 +24,7 @@ use datafusion_common::{ DataFusionError, Result, internal_datafusion_err, internal_err, not_impl_err, }; use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::file_sink_config::FileSink; -use datafusion_datasource::file_sink_config::FileSinkConfig; +use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; use datafusion_datasource::{FileRange, PartitionedFile}; use datafusion_datasource_csv::file_format::CsvSink; use datafusion_datasource_json::file_format::JsonSink; @@ -36,36 +35,45 @@ use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; -use datafusion_physical_plan::expressions::LikeExpr; use datafusion_physical_plan::expressions::{ - BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, - Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, + BinaryExpr, CaseExpr, CastExpr, Column, DynamicFilterPhysicalExpr, + DynamicFilterSnapshot, InListExpr, IsNotNullExpr, IsNullExpr, LikeExpr, Literal, + NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, }; use datafusion_physical_plan::joins::{HashExpr, HashTableLookupExpr}; use datafusion_physical_plan::udaf::AggregateFunctionExpr; use datafusion_physical_plan::windows::{PlainAggregateWindowExpr, WindowUDFExpr}; use datafusion_physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; +use super::{ + DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + PhysicalProtoConverterExtension, +}; use crate::protobuf::{ self, PhysicalSortExprNode, PhysicalSortExprNodeCollection, physical_aggregate_expr_node, physical_window_expr_node, }; -use super::PhysicalExtensionCodec; - #[expect(clippy::needless_pass_by_value)] pub fn serialize_physical_aggr_expr( aggr_expr: Arc, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let expressions = serialize_physical_exprs(&aggr_expr.expressions(), codec)?; - let order_bys = - serialize_physical_sort_exprs(aggr_expr.order_bys().iter().cloned(), codec)?; + let expressions = + serialize_physical_exprs(&aggr_expr.expressions(), codec, proto_converter)?; + let order_bys = serialize_physical_sort_exprs( + aggr_expr.order_bys().iter().cloned(), + codec, + proto_converter, + )?; let name = aggr_expr.fun().name().to_string(); let mut buf = Vec::new(); codec.try_encode_udaf(aggr_expr.fun(), &mut buf)?; Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), @@ -100,6 +108,7 @@ fn serialize_physical_window_aggr_expr( pub fn serialize_physical_window_expr( window_expr: &Arc, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let expr = window_expr.as_any(); let args = window_expr.expressions().to_vec(); @@ -155,9 +164,14 @@ pub fn serialize_physical_window_expr( return not_impl_err!("WindowExpr not supported: {window_expr:?}"); }; - let args = serialize_physical_exprs(&args, codec)?; - let partition_by = serialize_physical_exprs(window_expr.partition_by(), codec)?; - let order_by = serialize_physical_sort_exprs(window_expr.order_by().to_vec(), codec)?; + let args = serialize_physical_exprs(&args, codec, proto_converter)?; + let partition_by = + serialize_physical_exprs(window_expr.partition_by(), codec, proto_converter)?; + let order_by = serialize_physical_sort_exprs( + window_expr.order_by().to_vec(), + codec, + proto_converter, + )?; let window_frame: protobuf::WindowFrame = window_frame .as_ref() .try_into() @@ -179,22 +193,24 @@ pub fn serialize_physical_window_expr( pub fn serialize_physical_sort_exprs( sort_exprs: I, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> where I: IntoIterator, { sort_exprs .into_iter() - .map(|sort_expr| serialize_physical_sort_expr(sort_expr, codec)) + .map(|sort_expr| serialize_physical_sort_expr(sort_expr, codec, proto_converter)) .collect() } pub fn serialize_physical_sort_expr( sort_expr: PhysicalSortExpr, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let PhysicalSortExpr { expr, options } = sort_expr; - let expr = serialize_physical_expr(&expr, codec)?; + let expr = proto_converter.physical_expr_to_proto(&expr, codec)?; Ok(PhysicalSortExprNode { expr: Some(Box::new(expr)), asc: !options.descending, @@ -205,13 +221,14 @@ pub fn serialize_physical_sort_expr( pub fn serialize_physical_exprs<'a, I>( values: I, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> where I: IntoIterator>, { values .into_iter() - .map(|value| serialize_physical_expr(value, codec)) + .map(|value| proto_converter.physical_expr_to_proto(value, codec)) .collect() } @@ -223,6 +240,65 @@ pub fn serialize_physical_expr( value: &Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { + serialize_physical_expr_with_converter( + value, + codec, + &DefaultPhysicalProtoConverter {}, + ) +} + +/// Serialize a `PhysicalExpr` to default protobuf representation. +/// +/// If required, a [`PhysicalExtensionCodec`] can be provided which can handle +/// serialization of udfs requiring specialized serialization (see [`PhysicalExtensionCodec::try_encode_udf`]). +/// A [`PhysicalProtoConverterExtension`] can be provided to handle the +/// conversion process (see [`PhysicalProtoConverterExtension::physical_expr_to_proto`]). +pub fn serialize_physical_expr_with_converter( + value: &Arc, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, +) -> Result { + // Check for DynamicFilterPhysicalExpr before snapshotting. + // We need to handle it before snapshot_physical_expr because snapshot() + // replaces the DynamicFilterPhysicalExpr with its inner expression. + if let Some(df) = value.as_any().downcast_ref::() { + // Capture all state atomically + let snapshot = DynamicFilterSnapshot::from(df); + + let children = snapshot + .children() + .iter() + .map(|child| proto_converter.physical_expr_to_proto(child, codec)) + .collect::>>()?; + + let remapped_children = if let Some(remapped) = snapshot.remapped_children() { + remapped + .iter() + .map(|child| proto_converter.physical_expr_to_proto(child, codec)) + .collect::>>()? + } else { + vec![] + }; + + let inner_expr = Box::new( + proto_converter.physical_expr_to_proto(snapshot.inner_expr(), codec)?, + ); + + return Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::DynamicFilter( + Box::new(protobuf::PhysicalDynamicFilterNode { + children, + remapped_children, + generation: snapshot.generation(), + inner_expr: Some(inner_expr), + is_complete: snapshot.is_complete(), + }), + )), + }); + } + // Snapshot the expr in case it has dynamic predicate state so // it can be serialized let value = snapshot_physical_expr(Arc::clone(value))?; @@ -248,12 +324,16 @@ pub fn serialize_physical_expr( )), }; return Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Literal(value)), }); } if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Column( protobuf::PhysicalColumn { name: expr.name().to_string(), @@ -263,6 +343,8 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::UnknownColumn( protobuf::UnknownColumn { name: expr.name().to_string(), @@ -271,18 +353,26 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { - l: Some(Box::new(serialize_physical_expr(expr.left(), codec)?)), - r: Some(Box::new(serialize_physical_expr(expr.right(), codec)?)), + l: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.left(), codec)?, + )), + r: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.right(), codec)?, + )), op: format!("{:?}", expr.op()), }); Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( binary_expr, )), }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some( protobuf::physical_expr_node::ExprType::Case( Box::new( @@ -290,14 +380,21 @@ pub fn serialize_physical_expr( expr: expr .expr() .map(|exp| { - serialize_physical_expr(exp, codec).map(Box::new) + proto_converter + .physical_expr_to_proto(exp, codec) + .map(Box::new) }) .transpose()?, when_then_expr: expr .when_then_expr() .iter() .map(|(when_expr, then_expr)| { - serialize_when_then_expr(when_expr, then_expr, codec) + serialize_when_then_expr( + when_expr, + then_expr, + codec, + proto_converter, + ) }) .collect::, @@ -305,7 +402,11 @@ pub fn serialize_physical_expr( >>()?, else_expr: expr .else_expr() - .map(|a| serialize_physical_expr(a, codec).map(Box::new)) + .map(|a| { + proto_converter + .physical_expr_to_proto(a, codec) + .map(Box::new) + }) .transpose()?, }, ), @@ -314,66 +415,96 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr(Box::new( protobuf::PhysicalNot { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.arg(), codec)?, + )), }, ))), }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( Box::new(protobuf::PhysicalIsNull { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.arg(), codec)?, + )), }), )), }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( Box::new(protobuf::PhysicalIsNotNull { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.arg(), codec)?, + )), }), )), }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::InList(Box::new( protobuf::PhysicalInListNode { - expr: Some(Box::new(serialize_physical_expr(expr.expr(), codec)?)), - list: serialize_physical_exprs(expr.list(), codec)?, + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.expr(), codec)?, + )), + list: serialize_physical_exprs(expr.list(), codec, proto_converter)?, negated: expr.negated(), }, ))), }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Negative(Box::new( protobuf::PhysicalNegativeNode { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.arg(), codec)?, + )), }, ))), }) } else if let Some(lit) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( lit.value().try_into()?, )), }) } else if let Some(cast) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( protobuf::PhysicalCastNode { - expr: Some(Box::new(serialize_physical_expr(cast.expr(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(cast.expr(), codec)?, + )), arrow_type: Some(cast.cast_type().try_into()?), }, ))), }) } else if let Some(cast) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( protobuf::PhysicalTryCastNode { - expr: Some(Box::new(serialize_physical_expr(cast.expr(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(cast.expr(), codec)?, + )), arrow_type: Some(cast.cast_type().try_into()?), }, ))), @@ -382,10 +513,12 @@ pub fn serialize_physical_expr( let mut buf = Vec::new(); codec.try_encode_udf(expr.fun(), &mut buf)?; Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( protobuf::PhysicalScalarUdfNode { name: expr.name().to_string(), - args: serialize_physical_exprs(expr.args(), codec)?, + args: serialize_physical_exprs(expr.args(), codec, proto_converter)?, fun_definition: (!buf.is_empty()).then_some(buf), return_type: Some(expr.return_type().try_into()?), nullable: expr.nullable(), @@ -398,24 +531,33 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr(Box::new( protobuf::PhysicalLikeExprNode { negated: expr.negated(), case_insensitive: expr.case_insensitive(), - expr: Some(Box::new(serialize_physical_expr(expr.expr(), codec)?)), - pattern: Some(Box::new(serialize_physical_expr( - expr.pattern(), - codec, - )?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.expr(), codec)?, + )), + pattern: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.pattern(), codec)?, + )), }, ))), }) } else if let Some(expr) = expr.downcast_ref::() { let (s0, s1, s2, s3) = expr.seeds(); Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::HashExpr( protobuf::PhysicalHashExprNode { - on_columns: serialize_physical_exprs(expr.on_columns(), codec)?, + on_columns: serialize_physical_exprs( + expr.on_columns(), + codec, + proto_converter, + )?, seed0: s0, seed1: s1, seed2: s2, @@ -431,9 +573,11 @@ pub fn serialize_physical_expr( let inputs: Vec = value .children() .into_iter() - .map(|e| serialize_physical_expr(e, codec)) + .map(|e| proto_converter.physical_expr_to_proto(e, codec)) .collect::>()?; Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Extension( protobuf::PhysicalExtensionExprNode { expr: buf, inputs }, )), @@ -449,6 +593,7 @@ pub fn serialize_physical_expr( pub fn serialize_partitioning( partitioning: &Partitioning, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let serialized_partitioning = match partitioning { Partitioning::RoundRobinBatch(partition_count) => protobuf::Partitioning { @@ -457,7 +602,8 @@ pub fn serialize_partitioning( )), }, Partitioning::Hash(exprs, partition_count) => { - let serialized_exprs = serialize_physical_exprs(exprs, codec)?; + let serialized_exprs = + serialize_physical_exprs(exprs, codec, proto_converter)?; protobuf::Partitioning { partition_method: Some(protobuf::partitioning::PartitionMethod::Hash( protobuf::PhysicalHashRepartition { @@ -480,10 +626,11 @@ fn serialize_when_then_expr( when_expr: &Arc, then_expr: &Arc, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { Ok(protobuf::PhysicalWhenThen { - when_expr: Some(serialize_physical_expr(when_expr, codec)?), - then_expr: Some(serialize_physical_expr(then_expr, codec)?), + when_expr: Some(proto_converter.physical_expr_to_proto(when_expr, codec)?), + then_expr: Some(proto_converter.physical_expr_to_proto(then_expr, codec)?), }) } @@ -539,6 +686,7 @@ impl TryFrom<&[PartitionedFile]> for protobuf::FileGroup { pub fn serialize_file_scan_config( conf: &FileScanConfig, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let file_groups = conf .file_groups @@ -548,7 +696,8 @@ pub fn serialize_file_scan_config( let mut output_orderings = vec![]; for order in &conf.output_ordering { - let ordering = serialize_physical_sort_exprs(order.to_vec(), codec)?; + let ordering = + serialize_physical_sort_exprs(order.to_vec(), codec, proto_converter)?; output_orderings.push(ordering) } @@ -563,8 +712,7 @@ pub fn serialize_file_scan_config( fields.extend(conf.table_partition_cols().iter().cloned()); let schema = Arc::new( - arrow::datatypes::Schema::new(fields.clone()) - .with_metadata(conf.file_schema().metadata.clone()), + Schema::new(fields.clone()).with_metadata(conf.file_schema().metadata.clone()), ); let projection_exprs = conf @@ -579,7 +727,10 @@ pub fn serialize_file_scan_config( .map(|expr| { Ok(protobuf::ProjectionExpr { alias: expr.alias.to_string(), - expr: Some(serialize_physical_expr(&expr.expr, codec)?), + expr: Some( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + ), }) }) .collect::>>()?, @@ -614,11 +765,12 @@ pub fn serialize_file_scan_config( pub fn serialize_maybe_filter( expr: Option>, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { match expr { None => Ok(protobuf::MaybeFilter { expr: None }), Some(expr) => Ok(protobuf::MaybeFilter { - expr: Some(serialize_physical_expr(&expr, codec)?), + expr: Some(proto_converter.physical_expr_to_proto(&expr, codec)?), }), } } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index aa5458849330f..73eb5c1a2e511 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -18,28 +18,12 @@ use std::any::Any; use std::collections::HashMap; use std::fmt::{Display, Formatter}; - -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use std::vec; -use crate::cases::{ - CustomUDWF, CustomUDWFNode, MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, - MyRegexUdfNode, -}; - use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; use arrow::datatypes::{Fields, TimeUnit}; -use datafusion::physical_expr::aggregate::AggregateExprBuilder; -use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion::physical_plan::metrics::MetricType; -use datafusion_datasource::TableSchema; -use datafusion_expr::dml::InsertOp; -use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf; -use datafusion_functions_aggregate::array_agg::array_agg_udaf; -use datafusion_functions_aggregate::min_max::max_udaf; -use prost::Message; - use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; @@ -63,15 +47,19 @@ use datafusion::functions_aggregate::sum::sum_udaf; use datafusion::functions_window::nth_value::nth_value_udwf; use datafusion::functions_window::row_number::row_number_udwf; use datafusion::logical_expr::{JoinType, Operator, Volatility, create_udf}; +use datafusion::physical_expr::aggregate::AggregateExprBuilder; use datafusion::physical_expr::expressions::Literal; use datafusion::physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion::physical_expr::{ LexOrdering, PhysicalSortRequirement, ScalarFunctionExpr, }; +use datafusion::physical_optimizer::PhysicalOptimizerRule; +use datafusion::physical_optimizer::filter_pushdown::FilterPushdown; use datafusion::physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion::physical_plan::analyze::AnalyzeExec; +use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ @@ -79,10 +67,11 @@ use datafusion::physical_plan::expressions::{ }; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::joins::{ - HashJoinExec, HashTableLookupExpr, NestedLoopJoinExec, PartitionMode, - SortMergeJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec, + DynamicFilterRoutingMode, HashJoinExec, HashTableLookupExpr, NestedLoopJoinExec, + PartitionMode, SortMergeJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec, }; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion::physical_plan::metrics::MetricType; use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion::physical_plan::repartition::RepartitionExec; @@ -94,7 +83,8 @@ use datafusion::physical_plan::windows::{ create_udwf_window_expr, }; use datafusion::physical_plan::{ - ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, Statistics, displayable, + DisplayAs, DisplayFormatType, ExecutionPlan, InputOrderMode, Partitioning, + PhysicalExpr, SendableRecordBatchStream, Statistics, displayable, }; use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion::scalar::ScalarValue; @@ -104,23 +94,50 @@ use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{ - DataFusionError, NullEquality, Result, UnnestOptions, internal_datafusion_err, - internal_err, not_impl_err, + DataFusionError, NullEquality, Result, UnnestOptions, exec_datafusion_err, + internal_datafusion_err, internal_err, not_impl_err, }; +use datafusion_datasource::TableSchema; +use datafusion_datasource::file::FileSource; use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; +use datafusion_expr::dml::InsertOp; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, WindowUDF, }; +use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf; +use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_functions_aggregate::average::avg_udaf; +use datafusion_functions_aggregate::min_max::max_udaf; use datafusion_functions_aggregate::nth_value::nth_value_udaf; use datafusion_functions_aggregate::string_agg::string_agg_udaf; use datafusion_physical_plan::joins::join_hash_map::JoinHashMapU32; +use datafusion_proto::bytes::{ + physical_plan_from_bytes_with_proto_converter, + physical_plan_to_bytes_with_proto_converter, +}; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr_with_converter; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr_with_converter; use datafusion_proto::physical_plan::{ - AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, + AsExecutionPlan, DeduplicatingProtoConverter, DefaultPhysicalExtensionCodec, + DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + PhysicalProtoConverterExtension, +}; +use datafusion_proto::protobuf; +use datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType; +use datafusion_proto::protobuf::{PhysicalExprNode, PhysicalPlanNode}; +use prost::Message; + +use crate::cases::{ + CustomUDWF, CustomUDWFNode, MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, + MyRegexUdfNode, +}; + +use datafusion_physical_expr::expressions::{ + DynamicFilterPhysicalExpr, DynamicFilterSnapshot, DynamicFilterUpdate, }; -use datafusion_proto::protobuf::{self, PhysicalPlanNode}; +use datafusion_physical_expr::utils::reassign_expr_columns; /// Perform a serde roundtrip and assert that the string representation of the before and after plans /// are identical. Note that this often isn't sufficient to guarantee that no information is @@ -128,7 +145,8 @@ use datafusion_proto::protobuf::{self, PhysicalPlanNode}; fn roundtrip_test(exec_plan: Arc) -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; - roundtrip_test_and_return(exec_plan, &ctx, &codec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(exec_plan, &ctx, &codec, &proto_converter)?; Ok(()) } @@ -142,13 +160,19 @@ fn roundtrip_test_and_return( exec_plan: Arc, ctx: &SessionContext, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let proto: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), codec) - .expect("to proto"); - let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx.task_ctx(), codec) - .expect("from proto"); + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan), + codec, + proto_converter, + )?; + let result_exec_plan = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + codec, + proto_converter, + )?; pretty_assertions::assert_eq!( format!("{exec_plan:?}"), @@ -168,7 +192,8 @@ fn roundtrip_test_with_context( ctx: &SessionContext, ) -> Result<()> { let codec = DefaultPhysicalExtensionCodec {}; - roundtrip_test_and_return(exec_plan, ctx, &codec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(exec_plan, ctx, &codec, &proto_converter)?; Ok(()) } @@ -176,9 +201,10 @@ fn roundtrip_test_with_context( /// query results are identical. async fn roundtrip_test_sql_with_context(sql: &str, ctx: &SessionContext) -> Result<()> { let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; let initial_plan = ctx.sql(sql).await?.create_physical_plan().await?; - roundtrip_test_and_return(initial_plan, ctx, &codec)?; + roundtrip_test_and_return(initial_plan, ctx, &codec, &proto_converter)?; Ok(()) } @@ -985,7 +1011,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { } impl Display for CustomPredicateExpr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CustomPredicateExpr") } } @@ -1031,6 +1057,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { _buf: &[u8], _inputs: &[Arc], _ctx: &TaskContext, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { unreachable!() } @@ -1039,6 +1066,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { &self, _node: Arc, _buf: &mut Vec, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result<()> { unreachable!() } @@ -1078,7 +1106,12 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { let exec_plan = DataSourceExec::from_data_source(scan_config); let ctx = SessionContext::new(); - roundtrip_test_and_return(exec_plan, &ctx, &CustomPhysicalExtensionCodec {})?; + roundtrip_test_and_return( + exec_plan, + &ctx, + &CustomPhysicalExtensionCodec {}, + &DefaultPhysicalProtoConverter {}, + )?; Ok(()) } @@ -1139,6 +1172,7 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { _buf: &[u8], _inputs: &[Arc], _ctx: &TaskContext, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { not_impl_err!("No extension codec provided") } @@ -1147,6 +1181,7 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { &self, _node: Arc, _buf: &mut Vec, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result<()> { not_impl_err!("No extension codec provided") } @@ -1284,7 +1319,8 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1331,7 +1367,8 @@ fn roundtrip_udwf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - roundtrip_test_and_return(window, &ctx, &UDFExtensionCodec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(window, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1402,7 +1439,8 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1526,12 +1564,14 @@ fn roundtrip_csv_sink() -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; + let roundtrip_plan = roundtrip_test_and_return( Arc::new(DataSinkExec::new(input, data_sink, Some(sort_order))), &ctx, &codec, - ) - .unwrap(); + &proto_converter, + )?; let roundtrip_plan = roundtrip_plan .as_any() @@ -1972,6 +2012,7 @@ async fn test_serialize_deserialize_tpch_queries() -> Result<()> { // serialize the physical plan let codec = DefaultPhysicalExtensionCodec {}; + let proto = PhysicalPlanNode::try_from_physical_plan(physical_plan.clone(), &codec)?; @@ -2093,6 +2134,7 @@ async fn test_tpch_part_in_list_query_with_real_parquet_data() -> Result<()> { // Serialize the physical plan - bug may happen here already but not necessarily manifests let codec = DefaultPhysicalExtensionCodec {}; + let proto = PhysicalPlanNode::try_from_physical_plan(physical_plan.clone(), &codec)?; // This will fail with the bug, but should succeed when fixed @@ -2353,8 +2395,9 @@ fn roundtrip_hash_table_lookup_expr_to_lit() -> Result<()> { // Serialize let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; - let proto: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan(filter.clone(), &codec) + + let proto: PhysicalPlanNode = + PhysicalPlanNode::try_from_physical_plan(filter.clone(), &codec) .expect("serialization should succeed"); // Deserialize @@ -2404,3 +2447,1337 @@ fn roundtrip_hash_expr() -> Result<()> { ); roundtrip_test(filter) } + +#[test] +fn custom_proto_converter_intercepts() -> Result<()> { + #[derive(Default)] + struct CustomConverterInterceptor { + num_proto_plans: RwLock, + num_physical_plans: RwLock, + num_proto_exprs: RwLock, + num_physical_exprs: RwLock, + } + + impl PhysicalProtoConverterExtension for CustomConverterInterceptor { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto: &protobuf::PhysicalPlanNode, + ) -> Result> { + { + let mut counter = self + .num_proto_plans + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + proto.try_into_physical_plan_with_converter(ctx, codec, self) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + { + let mut counter = self + .num_physical_plans + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + codec, + self, + ) + } + + fn proto_to_physical_expr( + &self, + proto: &PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> + where + Self: Sized, + { + { + let mut counter = self + .num_proto_exprs + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + { + let mut counter = self + .num_physical_exprs + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + serialize_physical_expr_with_converter(expr, codec, self) + } + } + + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let sort_exprs = [ + PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: true, + nulls_first: false, + }, + }, + PhysicalSortExpr { + expr: col("b", &schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }, + ] + .into(); + + let exec_plan = Arc::new(SortExec::new(sort_exprs, Arc::new(EmptyExec::new(schema)))); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = CustomConverterInterceptor::default(); + roundtrip_test_and_return(exec_plan, &ctx, &codec, &proto_converter)?; + + assert_eq!(*proto_converter.num_proto_exprs.read().unwrap(), 2); + assert_eq!(*proto_converter.num_physical_exprs.read().unwrap(), 2); + assert_eq!(*proto_converter.num_proto_plans.read().unwrap(), 2); + assert_eq!(*proto_converter.num_physical_plans.read().unwrap(), 2); + + Ok(()) +} + +/// Test that expression deduplication works during deserialization. +/// When the same expression Arc is serialized multiple times, it should be +/// deduplicated on deserialization (sharing the same Arc). +#[test] +fn test_expression_deduplication() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a shared expression that will be used multiple times + let shared_col: Arc = Arc::new(Column::new("a", 0)); + + // Create an InList expression that uses the same column Arc multiple times + // This simulates a real-world scenario where expressions are shared + let in_list_expr = in_list( + Arc::clone(&shared_col), + vec![lit(1i64), lit(2i64), lit(3i64)], + &false, + &schema, + )?; + + // Create a binary expression that uses the shared column and the in_list result + let binary_expr: Arc = Arc::new(BinaryExpr::new( + Arc::clone(&shared_col), + Operator::Eq, + lit(42i64), + )); + + // Create a plan that has both expressions (they share the `shared_col` Arc) + let input = Arc::new(EmptyExec::new(schema.clone())); + let filter = Arc::new(FilterExec::try_new(in_list_expr, input)?); + let projection_exprs = vec![ProjectionExpr { + expr: binary_expr, + alias: "result".to_string(), + }]; + let exec_plan = Arc::new(ProjectionExec::try_new(projection_exprs, filter)?); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // Perform roundtrip + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan) as Arc, + &codec, + &proto_converter, + )?; + + // Create a new converter for deserialization (fresh cache) + let deser_converter = DeduplicatingProtoConverter {}; + let result_plan = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &deser_converter, + )?; + + // Verify the plan structure is correct + pretty_assertions::assert_eq!(format!("{exec_plan:?}"), format!("{result_plan:?}")); + + Ok(()) +} + +/// Test that expression deduplication correctly shares Arcs for identical expressions. +/// This test verifies the core deduplication behavior. +#[test] +fn test_expression_deduplication_arc_sharing() -> Result<()> { + use datafusion_proto::bytes::{ + physical_plan_from_bytes_with_proto_converter, + physical_plan_to_bytes_with_proto_converter, + }; + + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a column expression + let col_expr: Arc = Arc::new(Column::new("a", 0)); + + // Create a projection that uses the SAME Arc twice + // After roundtrip, both should point to the same Arc + let projection_exprs = vec![ + ProjectionExpr { + expr: Arc::clone(&col_expr), + alias: "a1".to_string(), + }, + ProjectionExpr { + expr: Arc::clone(&col_expr), // Same Arc! + alias: "a2".to_string(), + }, + ]; + + let input = Arc::new(EmptyExec::new(schema)); + let exec_plan = Arc::new(ProjectionExec::try_new(projection_exprs, input)?); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // Serialize + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan) as Arc, + &codec, + &proto_converter, + )?; + + // Deserialize with a fresh converter + let deser_converter = DeduplicatingProtoConverter {}; + let result_plan = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &deser_converter, + )?; + + // Get the projection from the result + let projection = result_plan + .as_any() + .downcast_ref::() + .expect("Expected ProjectionExec"); + + let exprs: Vec<_> = projection.expr().iter().collect(); + assert_eq!(exprs.len(), 2); + + // The key test: both expressions should point to the same Arc after deduplication + // This is because they were the same Arc before serialization + assert!( + Arc::ptr_eq(&exprs[0].expr, &exprs[1].expr), + "Expected both expressions to share the same Arc after deduplication" + ); + + Ok(()) +} + +/// Test backward compatibility: protos without expr_id should still deserialize correctly. +#[test] +fn test_backward_compatibility_no_expr_id() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Manually create a proto without expr_id set + let proto = PhysicalExprNode { + expr_id: None, // Simulating old proto without this field + dynamic_filter_inner_id: None, + expr_type: Some( + datafusion_proto::protobuf::physical_expr_node::ExprType::Column( + datafusion_proto::protobuf::PhysicalColumn { + name: "a".to_string(), + index: 0, + }, + ), + ), + }; + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; + + // Should deserialize without error + let result = proto_converter.proto_to_physical_expr( + &proto, + ctx.task_ctx().as_ref(), + &schema, + &codec, + )?; + + // Verify the result is correct + let col = result + .as_any() + .downcast_ref::() + .expect("Expected Column"); + assert_eq!(col.name(), "a"); + assert_eq!(col.index(), 0); + + Ok(()) +} + +/// Test that deduplication works within a single plan deserialization and that +/// separate deserializations produce independent expressions (no cross-operation sharing). +#[test] +fn test_deduplication_within_plan_deserialization() -> Result<()> { + use datafusion_proto::bytes::{ + physical_plan_from_bytes_with_proto_converter, + physical_plan_to_bytes_with_proto_converter, + }; + + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a plan with expressions that will be deduplicated + let col_expr: Arc = Arc::new(Column::new("a", 0)); + let projection_exprs = vec![ + ProjectionExpr { + expr: Arc::clone(&col_expr), + alias: "a1".to_string(), + }, + ProjectionExpr { + expr: Arc::clone(&col_expr), // Same Arc - will be deduplicated + alias: "a2".to_string(), + }, + ]; + let exec_plan = Arc::new(ProjectionExec::try_new( + projection_exprs, + Arc::new(EmptyExec::new(schema)), + )?); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // Serialize + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan) as Arc, + &codec, + &proto_converter, + )?; + + // First deserialization + let plan1 = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &proto_converter, + )?; + + // Check that the plan was deserialized correctly with deduplication + let projection1 = plan1 + .as_any() + .downcast_ref::() + .expect("Expected ProjectionExec"); + let exprs1: Vec<_> = projection1.expr().iter().collect(); + assert_eq!(exprs1.len(), 2); + assert!( + Arc::ptr_eq(&exprs1[0].expr, &exprs1[1].expr), + "Expected both expressions to share the same Arc after deduplication" + ); + + // Second deserialization + let plan2 = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &proto_converter, + )?; + + // Check that the second plan was also deserialized correctly + let projection2 = plan2 + .as_any() + .downcast_ref::() + .expect("Expected ProjectionExec"); + let exprs2: Vec<_> = projection2.expr().iter().collect(); + assert_eq!(exprs2.len(), 2); + assert!( + Arc::ptr_eq(&exprs2[0].expr, &exprs2[1].expr), + "Expected both expressions to share the same Arc after deduplication" + ); + + // Check that there was no deduplication across deserializations + assert!( + !Arc::ptr_eq(&exprs1[0].expr, &exprs2[0].expr), + "Expected expressions from different deserializations to be different Arcs" + ); + assert!( + !Arc::ptr_eq(&exprs1[1].expr, &exprs2[1].expr), + "Expected expressions from different deserializations to be different Arcs" + ); + + Ok(()) +} + +/// Test that deduplication works within direct expression deserialization and that +/// separate deserializations produce independent expressions (no cross-operation sharing). +#[test] +fn test_deduplication_within_expr_deserialization() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a binary expression where both sides are the same Arc + // This allows us to test deduplication within a single deserialization + let col_expr: Arc = Arc::new(Column::new("a", 0)); + let binary_expr: Arc = Arc::new(BinaryExpr::new( + Arc::clone(&col_expr), + Operator::Plus, + Arc::clone(&col_expr), // Same Arc - will be deduplicated + )); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // Serialize the expression + let proto = proto_converter.physical_expr_to_proto(&binary_expr, &codec)?; + + // First expression deserialization + let expr1 = proto_converter.proto_to_physical_expr( + &proto, + ctx.task_ctx().as_ref(), + &schema, + &codec, + )?; + + // Check that deduplication worked within the deserialization + let binary1 = expr1 + .as_any() + .downcast_ref::() + .expect("Expected BinaryExpr"); + assert!( + Arc::ptr_eq(binary1.left(), binary1.right()), + "Expected both sides to share the same Arc after deduplication" + ); + + // Second expression deserialization + let expr2 = proto_converter.proto_to_physical_expr( + &proto, + ctx.task_ctx().as_ref(), + &schema, + &codec, + )?; + + // Check that the second expression was also deserialized correctly + let binary2 = expr2 + .as_any() + .downcast_ref::() + .expect("Expected BinaryExpr"); + assert!( + Arc::ptr_eq(binary2.left(), binary2.right()), + "Expected both sides to share the same Arc after deduplication" + ); + + // Check that there was no deduplication across deserializations + assert!( + !Arc::ptr_eq(binary1.left(), binary2.left()), + "Expected expressions from different deserializations to be different Arcs" + ); + assert!( + !Arc::ptr_eq(binary1.right(), binary2.right()), + "Expected expressions from different deserializations to be different Arcs" + ); + + Ok(()) +} + +#[test] +fn test_dynamic_filters_different_filter_same_inner_state() { + let filter_expr_1 = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("a", 0)) as Arc], + lit(true), + )) as Arc; + + // Column "a" is now at index 1, which creates a new filter. + let schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Int64, false), + Field::new("a", DataType::Int64, false), + ])); + let filter_expr_2 = + reassign_expr_columns(Arc::clone(&filter_expr_1), &schema).unwrap(); + + // Meta-assertion: ensure this test is testing the case where the inner state is the same but + // the exprs are different + let (outer_equal, inner_equal) = + dynamic_filter_outer_inner_equal(&filter_expr_1, &filter_expr_2); + assert!(!outer_equal); + assert!(inner_equal); + test_deduplication_of_dynamic_filter_expression(filter_expr_1, filter_expr_2, schema) + .unwrap(); +} + +#[test] +fn test_dynamic_filters_same_filter() { + let filter_expr_1 = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("a", 0)) as Arc], + lit(true), + )) as Arc; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + + let filter_expr_2 = Arc::clone(&filter_expr_1); + + // Ensure this test is testing the case where the inner state is the same and the exprs are the same + let (outer_equal, inner_equal) = + dynamic_filter_outer_inner_equal(&filter_expr_1, &filter_expr_2); + assert!(outer_equal); + assert!(inner_equal); + test_deduplication_of_dynamic_filter_expression(filter_expr_1, filter_expr_2, schema) + .unwrap(); +} + +#[test] +fn test_dynamic_filters_different_filter() { + let filter_expr_1 = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("a", 0)) as Arc], + lit(true), + )) as Arc; + + let filter_expr_2 = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("a", 0)) as Arc], + lit(true), + )) as Arc; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + + // Ensure this test is testing the case where the inner state is the different and the outer exprs are different + let (outer_equal, inner_equal) = + dynamic_filter_outer_inner_equal(&filter_expr_1, &filter_expr_2); + assert!(!outer_equal); + assert!(!inner_equal); + test_deduplication_of_dynamic_filter_expression(filter_expr_1, filter_expr_2, schema) + .unwrap(); +} + +#[test] +fn test_dynamic_filter_roundtrip() -> Result<()> { + // Create a dynamic filter with base children + let filter_expr = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("a", 0)) as Arc], + lit(true), + )) as Arc; + + // Add remapped columns by reassigning to a schema where "a" is at index 1 + let schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Int64, false), + Field::new("a", DataType::Int64, false), + ])); + let filter_expr = reassign_expr_columns(filter_expr, &schema).unwrap(); + + // Update its internal state + let df = filter_expr + .as_any() + .downcast_ref::() + .unwrap(); + df.update(DynamicFilterUpdate::Global(lit(42)))?; + df.update(DynamicFilterUpdate::Global(lit(100)))?; + df.mark_complete(); + + // Serialize + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let proto = converter.physical_expr_to_proto(&filter_expr, &codec)?; + + // Deserialize + let ctx = SessionContext::new(); + let deserialized_filter = converter.proto_to_physical_expr( + &proto, + ctx.task_ctx().as_ref(), + &schema, + &codec, + )?; + + let deserialized_df = deserialized_filter + .as_any() + .downcast_ref::() + .expect("Should be DynamicFilterPhysicalExpr"); + + assert_eq!( + DynamicFilterSnapshot::from(df).to_string(), + DynamicFilterSnapshot::from(deserialized_df).to_string(), + "Snapshots should be equal" + ); + + Ok(()) +} + +/// Returns (outer_equal, inner_equal) +/// +/// outer_equal is true if the two arcs point to the same data. +/// inner_equal is true if the two dynamic filters have the same inner arc. +fn dynamic_filter_outer_inner_equal( + filter_expr_1: &Arc, + filter_expr_2: &Arc, +) -> (bool, bool) { + ( + std::ptr::addr_eq(Arc::as_ptr(filter_expr_1), Arc::as_ptr(filter_expr_2)), + filter_expr_1 + .as_any() + .downcast_ref::() + .unwrap() + .inner_id() + == filter_expr_2 + .as_any() + .downcast_ref::() + .unwrap() + .inner_id(), + ) +} + +fn test_deduplication_of_dynamic_filter_expression( + filter_expr_1: Arc, + filter_expr_2: Arc, + schema: Arc, +) -> Result<()> { + let (outer_equal, inner_equal) = + dynamic_filter_outer_inner_equal(&filter_expr_1, &filter_expr_2); + + // Create execution plan: FilterExec(filter2) -> FilterExec(filter1) -> EmptyExec + let empty_exec = Arc::new(EmptyExec::new(schema)) as Arc; + let filter_exec1 = + Arc::new(FilterExec::try_new(Arc::clone(&filter_expr_1), empty_exec)?) + as Arc; + let filter_exec2 = Arc::new(FilterExec::try_new( + Arc::clone(&filter_expr_2), + filter_exec1, + )?) as Arc; + + // Serialize the plan + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let proto = converter.execution_plan_to_proto(&filter_exec2, &codec)?; + + let outer_filter = match &proto.physical_plan_type { + Some(PhysicalPlanType::Filter(outer_filter)) => outer_filter, + _ => panic!("Expected PhysicalPlanType::Filter"), + }; + + let inner_filter = match &outer_filter.input { + Some(inner_input) => match &inner_input.physical_plan_type { + Some(PhysicalPlanType::Filter(inner_filter)) => inner_filter, + _ => panic!("Expected PhysicalPlanType::Filter"), + }, + _ => panic!("Expected inner input"), + }; + + let filter1_proto = inner_filter + .expr + .as_ref() + .expect("Should have filter expression"); + + let filter2_proto = outer_filter + .expr + .as_ref() + .expect("Should have filter expression"); + + // Both should have dynamic_filter_inner_id set + let filter1_dynamic_id = filter1_proto + .dynamic_filter_inner_id + .expect("Filter1 should have dynamic_filter_inner_id"); + let filter2_dynamic_id = filter2_proto + .dynamic_filter_inner_id + .expect("Filter2 should have dynamic_filter_inner_id"); + + assert_eq!( + inner_equal, + filter1_dynamic_id == filter2_dynamic_id, + "Dynamic filters sharing the same inner state should have the same dynamic_filter_inner_id" + ); + + let filter1_expr_id = filter1_proto.expr_id.expect("Should have expr_id"); + let filter2_expr_id = filter2_proto.expr_id.expect("Should have expr_id"); + assert_eq!( + outer_equal, + filter1_expr_id == filter2_expr_id, + "Different filters have different expr ids" + ); + + // Test deserialization - verify that filters with same dynamic_filter_inner_id share state + let ctx = SessionContext::new(); + let deserialized_plan = + converter.proto_to_execution_plan(ctx.task_ctx().as_ref(), &codec, &proto)?; + + // Extract the two filter expressions from the deserialized plan + let outer_filter = deserialized_plan + .as_any() + .downcast_ref::() + .expect("Should be FilterExec"); + let filter2_deserialized = outer_filter.predicate(); + + let inner_filter = outer_filter.children()[0] + .as_any() + .downcast_ref::() + .expect("Inner should be FilterExec"); + let filter1_deserialized = inner_filter.predicate(); + + // The Arcs should be different (different outer wrappers) + assert_eq!( + outer_equal, + Arc::ptr_eq(filter1_deserialized, filter2_deserialized), + "Deserialized filters should be different Arcs" + ); + + // Check if they're DynamicFilterPhysicalExpr (they might be snapshotted to Literal) + let (df1, df2) = match ( + filter1_deserialized + .as_any() + .downcast_ref::(), + filter2_deserialized + .as_any() + .downcast_ref::(), + ) { + (Some(df1), Some(df2)) => (df1, df2), + _ => panic!("Should be DynamicFilterPhysicalExpr"), + }; + + // But they should have the same inner_id (shared inner state) + assert_eq!( + inner_equal, + df1.inner_id() == df2.inner_id(), + "Deserialized filters should share inner state" + ); + + // Ensure the children and remapped children are equal after the roundtrip + let filter_1_before_roundtrip = filter_expr_1 + .as_any() + .downcast_ref::() + .unwrap(); + let filter_2_before_roundtrip = filter_expr_2 + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!( + DynamicFilterSnapshot::from(filter_1_before_roundtrip).to_string(), + DynamicFilterSnapshot::from(df1).to_string() + ); + assert_eq!( + DynamicFilterSnapshot::from(filter_2_before_roundtrip).to_string(), + DynamicFilterSnapshot::from(df2).to_string() + ); + + Ok(()) +} + +/// Test that session_id rotates between top-level serialization operations. +/// This verifies that each top-level serialization gets a fresh session_id, +/// which prevents cross-process collisions when serialized plans are merged. +#[test] +fn test_session_id_rotation_between_serializations() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let _schema = Arc::new(Schema::new(vec![field_a])); + + // Create a simple expression + let col_expr: Arc = Arc::new(Column::new("a", 0)); + + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // First serialization + let proto1 = proto_converter.physical_expr_to_proto(&col_expr, &codec)?; + let expr_id1 = proto1.expr_id.expect("Expected expr_id to be set"); + + // Second serialization with the same converter + // The session_id should have rotated, so the expr_id should be different + // even though we're serializing the same expression (same pointer address) + let proto2 = proto_converter.physical_expr_to_proto(&col_expr, &codec)?; + let expr_id2 = proto2.expr_id.expect("Expected expr_id to be set"); + + // The expr_ids should be different because session_id rotated + assert_ne!( + expr_id1, expr_id2, + "Expected different expr_ids due to session_id rotation between serializations" + ); + + // Also test that serializing the same expression multiple times within + // the same top-level operation would give the same expr_id (not testable + // here directly since each physical_expr_to_proto is a top-level operation, + // but the deduplication tests verify this indirectly) + + Ok(()) +} + +/// Test that session_id rotation works correctly with execution plans. +/// This verifies the end-to-end behavior with plan serialization. +#[test] +fn test_session_id_rotation_with_execution_plans() -> Result<()> { + use datafusion_proto::bytes::physical_plan_to_bytes_with_proto_converter; + + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a simple plan + let col_expr: Arc = Arc::new(Column::new("a", 0)); + let projection_exprs = vec![ProjectionExpr { + expr: Arc::clone(&col_expr), + alias: "a1".to_string(), + }]; + let exec_plan = Arc::new(ProjectionExec::try_new( + projection_exprs.clone(), + Arc::new(EmptyExec::new(Arc::clone(&schema))), + )?); + + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // First serialization + let bytes1 = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan) as Arc, + &codec, + &proto_converter, + )?; + + // Second serialization with the same converter + let bytes2 = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan) as Arc, + &codec, + &proto_converter, + )?; + + // The serialized bytes should be different due to different session_ids + // (specifically, the expr_id values embedded in the protobuf will differ) + assert_ne!( + bytes1.as_ref(), + bytes2.as_ref(), + "Expected different serialized bytes due to session_id rotation" + ); + + // But both should deserialize correctly + let ctx = SessionContext::new(); + let deser_converter = DeduplicatingProtoConverter {}; + + let plan1 = datafusion_proto::bytes::physical_plan_from_bytes_with_proto_converter( + bytes1.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &deser_converter, + )?; + + let plan2 = datafusion_proto::bytes::physical_plan_from_bytes_with_proto_converter( + bytes2.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &deser_converter, + )?; + + // Verify both plans have the expected structure + assert_eq!(plan1.schema(), plan2.schema()); + + Ok(()) +} + +/// Create a DataSourceExec backed by a ParquetSource that accepts filter pushdown, +/// along with a ConfigOptions that enables all dynamic filter pushdown options. +fn datasource_for_dynamic_filter_pushdown( + schema: &Arc, +) -> (Arc, ConfigOptions) { + let mut parquet_options = TableParquetOptions::new(); + parquet_options.global.pushdown_filters = true; + let source = Arc::new( + ParquetSource::new(Arc::clone(schema)) + .with_table_parquet_options(parquet_options), + ); + let scan_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_file(PartitionedFile::new("/path/to/file.parquet", 1024)) + .build(); + + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_join_dynamic_filter_pushdown = true; + config.optimizer.enable_aggregate_dynamic_filter_pushdown = true; + config.optimizer.enable_topk_dynamic_filter_pushdown = true; + + (DataSourceExec::from_data_source(scan_config), config) +} + +/// Test that plan containing a HashJoinExec with dynamic filter pushdown +/// can be serialized and deserialized while preserving references to the dynamic filter. +#[test] +fn test_hash_join_with_dynamic_filter_roundtrip() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Int64, false)])); + + let left_child = Arc::new(EmptyExec::new(Arc::clone(&schema))); + let (right_child, config) = datasource_for_dynamic_filter_pushdown(&schema); + + let on: Vec<(Arc, Arc)> = vec![( + Arc::new(Column::new("col", 0)), + Arc::new(Column::new("col", 0)), + )]; + + let hash_join = Arc::new( + HashJoinExec::try_new( + left_child, + right_child, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + )? + .with_dynamic_filter_routing_mode(DynamicFilterRoutingMode::PartitionIndex), + ) as Arc; + + // Run the optimizer rule for filter pushdown. + let optimizer = FilterPushdown::new_post_optimization(); + let plan = optimizer.optimize(hash_join, &config)?; + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let deserialized = roundtrip_test_and_return(plan, &ctx, &codec, &converter)?; + + // Extract the deserialized HashJoinExec and its dynamic filter. + let deserialized_join = deserialized + .as_any() + .downcast_ref::() + .expect("Should be HashJoinExec"); + let deserialized_hash_join_df = deserialized_join + .dynamic_filter() + .expect("HashJoinExec should have a dynamic filter after roundtrip"); + + // Extract the dynamic filter from the probe side's ParquetSource. + let deserialized_data_source = deserialized_join + .right() + .as_any() + .downcast_ref::() + .expect("Right child should be DataSourceExec"); + let (_, deserialized_parquet_source) = deserialized_data_source + .downcast_to_file_source::() + .expect("Should be ParquetSource"); + let deserialized_predicate = deserialized_parquet_source + .filter() + .expect("ParquetSource should have a predicate after roundtrip"); + let deserialized_predicate_df = deserialized_predicate + .as_any() + .downcast_ref::() + .expect("ParquetSource predicate should contain a DynamicFilterPhysicalExpr"); + + // After roundtrip, the HashJoinExec's dynamic filter and the ParquetSource's + // predicate should share the same inner state. + assert_eq!( + deserialized_hash_join_df.inner_id(), + deserialized_predicate_df.inner_id(), + "HashJoinExec's dynamic filter should share inner state with the probe side's predicate" + ); + assert_eq!( + deserialized_join.dynamic_filter_routing_mode(), + DynamicFilterRoutingMode::PartitionIndex, + "HashJoinExec should preserve dynamic filter routing mode after roundtrip", + ); + + Ok(()) +} + +/// Test that plan containing a AggregateExec with dynamic filter pushdown +/// can be serialized and deserialized while preserving references to the dynamic filter. +#[test] +fn test_aggregate_with_dynamic_filter_roundtrip() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let col_a: Arc = Arc::new(Column::new("a", 0)); + + let (child, config) = datasource_for_dynamic_filter_pushdown(&schema); + + let agg = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![]), + vec![ + AggregateExprBuilder::new( + datafusion::functions_aggregate::min_max::min_udaf(), + vec![Arc::clone(&col_a)], + ) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build() + .map(Arc::new)?, + ], + vec![None], + child, + Arc::clone(&schema), + )?) as Arc; + + // Run the optimizer rule for filter pushdown. + let optimizer = FilterPushdown::new_post_optimization(); + let plan = optimizer.optimize(agg, &config)?; + + // Roundtrip with deduplication. + // + // Note: We don't use `roundtrip_test_and_return` here because there's a + // pre-existing issue with PhysicalGroupBy serialization where empty groups + // `[[]]` become `[]` after roundtrip. This behavior is unrelated to this test. + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&plan), + &codec, + &converter, + )?; + let deserialized = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &converter, + )?; + + // Extract the deserialized AggregateExec and its dynamic filter. + let deserialized_agg = deserialized + .as_any() + .downcast_ref::() + .expect("Should be AggregateExec"); + let deserialized_agg_df = deserialized_agg + .dynamic_filter() + .expect("AggregateExec should have a dynamic filter after roundtrip"); + + // Extract the dynamic filter from the child DataSourceExec. + let deserialized_data_source = deserialized_agg + .input() + .as_any() + .downcast_ref::() + .expect("Child should be DataSourceExec"); + let (_, deserialized_parquet_source) = deserialized_data_source + .downcast_to_file_source::() + .expect("Should be ParquetSource"); + let deserialized_predicate = deserialized_parquet_source + .filter() + .expect("ParquetSource should have a predicate after roundtrip"); + let deserialized_predicate_df = deserialized_predicate + .as_any() + .downcast_ref::() + .expect("ParquetSource predicate should contain a DynamicFilterPhysicalExpr"); + + // The AggregateExec's dynamic filter and the child's predicate should + // share the same inner state after roundtrip. + assert_eq!( + deserialized_agg_df.inner_id(), + deserialized_predicate_df.inner_id(), + "AggregateExec's dynamic filter should share inner state with child's predicate" + ); + + Ok(()) +} + +/// Test that plan containing a SortExec with dynamic filter pushdown +/// can be serialized and deserialized while preserving references to the dynamic filter. +#[test] +fn test_sort_topk_with_dynamic_filter_roundtrip() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let col_a: Arc = Arc::new(Column::new("a", 0)); + + let (child, config) = datasource_for_dynamic_filter_pushdown(&schema); + + let sort = Arc::new( + SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: SortOptions::default(), + }]) + .unwrap(), + child, + ) + .with_fetch(Some(10)), + ) as Arc; + + // Verify the optimizer kept the dynamic filter on the AggregateExec. + let optimizer = FilterPushdown::new_post_optimization(); + let plan = optimizer.optimize(sort, &config)?; + + // Roundtrip with deduplication. + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let deserialized = roundtrip_test_and_return(plan, &ctx, &codec, &converter)?; + + // Extract the deserialized SortExec and its dynamic filter. + let deserialized_sort = deserialized + .as_any() + .downcast_ref::() + .expect("Should be SortExec"); + let deserialized_sort_df = deserialized_sort + .dynamic_filter() + .expect("SortExec should have a dynamic filter after roundtrip"); + + // Extract the dynamic filter from the child DataSourceExec. + let deserialized_data_source = deserialized_sort + .input() + .as_any() + .downcast_ref::() + .expect("Child should be DataSourceExec"); + let (_, deserialized_parquet_source) = deserialized_data_source + .downcast_to_file_source::() + .expect("Should be ParquetSource"); + let deserialized_predicate = deserialized_parquet_source + .filter() + .expect("ParquetSource should have a predicate after roundtrip"); + let deserialized_predicate_df = deserialized_predicate + .as_any() + .downcast_ref::() + .expect("ParquetSource predicate should contain a DynamicFilterPhysicalExpr"); + + // The SortExec's dynamic filter and the child's predicate should + // share the same inner state after roundtrip. + assert_eq!( + deserialized_sort_df.inner_id(), + deserialized_predicate_df.inner_id(), + "SortExec's dynamic filter should share inner state with child's predicate" + ); + + Ok(()) +} + +/// A custom ExecutionPlan that stores `Vec>` fields, +/// simulating a custom scan node (like NumpangFileSource) that has dynamic +/// filters pushed down during optimization. +struct CustomExecWithExprs { + exprs: Vec>, + child: Arc, + properties: datafusion::physical_plan::PlanProperties, +} + +impl std::fmt::Debug for CustomExecWithExprs { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CustomExecWithExprs") + .field("exprs", &self.exprs) + .field("child", &self.child) + .finish() + } +} + +impl CustomExecWithExprs { + fn new(exprs: Vec>, child: Arc) -> Self { + use datafusion_physical_expr::equivalence::EquivalenceProperties; + use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; + let properties = datafusion::physical_plan::PlanProperties::new( + EquivalenceProperties::new(child.schema()), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ); + Self { + exprs, + child, + properties, + } + } +} + +impl DisplayAs for CustomExecWithExprs { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "CustomExecWithExprs") + } +} + +impl ExecutionPlan for CustomExecWithExprs { + fn name(&self) -> &str { + "CustomExecWithExprs" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + unreachable!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unreachable!() + } +} + +/// A PhysicalExtensionCodec that uses proto_converter to serialize/deserialize +/// the PhysicalExpr fields stored in CustomExecWithExprs. +#[derive(Debug)] +struct CustomExecWithExprsCodec { + /// The schema used for expression deserialization (shared between encode/decode). + schema: Arc, +} + +impl PhysicalExtensionCodec for CustomExecWithExprsCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + ctx: &TaskContext, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + let num_exprs = u32::from_le_bytes(buf[0..4].try_into().unwrap()) as usize; + let mut offset = 4; + + let mut exprs = Vec::with_capacity(num_exprs); + for _ in 0..num_exprs { + let expr_len = + u32::from_le_bytes(buf[offset..offset + 4].try_into().unwrap()) as usize; + offset += 4; + let expr_proto = PhysicalExprNode::decode(&buf[offset..offset + expr_len]) + .map_err(|e| { + internal_datafusion_err!("Failed to decode expression: {e}") + })?; + let expr = proto_converter.proto_to_physical_expr( + &expr_proto, + ctx, + &self.schema, + self, + )?; + exprs.push(expr); + offset += expr_len; + } + + Ok(Arc::new(CustomExecWithExprs::new(exprs, inputs[0].clone()))) + } + + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result<()> { + let custom = node + .as_any() + .downcast_ref::() + .ok_or_else(|| internal_datafusion_err!("Expected CustomExecWithExprs"))?; + + buf.extend_from_slice(&(custom.exprs.len() as u32).to_le_bytes()); + for expr in &custom.exprs { + let expr_proto = proto_converter.physical_expr_to_proto(expr, self)?; + let expr_bytes = expr_proto.encode_to_vec(); + buf.extend_from_slice(&(expr_bytes.len() as u32).to_le_bytes()); + buf.extend_from_slice(&expr_bytes); + } + + Ok(()) + } +} + +/// Tests that a custom ExecutionPlan node storing PhysicalExpr fields can +/// use the proto_converter parameter to serialize/deserialize those expressions +/// with deduplication support. When the same DynamicFilterPhysicalExpr is shared +/// between the custom node and another part of the plan (e.g., a FilterExec), +/// the inner_id should be preserved after roundtrip (proving dedup works). +#[test] +fn test_custom_node_with_dynamic_filter_dedup_roundtrip() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + + // Create a dynamic filter expression + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("a", 0)) as Arc], + lit(true), + )); + let dynamic_filter_expr = dynamic_filter.clone() as Arc; + + // Create the plan: + // FilterExec(dynamic_filter) + // -> CustomExecWithExprs(exprs: [dynamic_filter]) + // -> EmptyExec + // + // The same dynamic_filter Arc is stored in both the FilterExec and the custom node. + let empty = Arc::new(EmptyExec::new(Arc::clone(&schema))); + let custom_exec = Arc::new(CustomExecWithExprs::new( + vec![Arc::clone(&dynamic_filter_expr)], + empty, + )); + let filter_exec = Arc::new(FilterExec::try_new( + Arc::clone(&dynamic_filter_expr), + custom_exec, + )?) as Arc; + + // Roundtrip with DeduplicatingProtoConverter + let codec = CustomExecWithExprsCodec { + schema: Arc::clone(&schema), + }; + let converter = DeduplicatingProtoConverter {}; + + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&filter_exec), + &codec, + &converter, + )?; + + let ctx = SessionContext::new(); + let deser_converter = DeduplicatingProtoConverter {}; + let deserialized = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &deser_converter, + )?; + + // Extract the deserialized FilterExec's dynamic filter + let deser_filter = deserialized + .as_any() + .downcast_ref::() + .expect("Top-level should be FilterExec"); + let deser_filter_df = deser_filter + .predicate() + .as_any() + .downcast_ref::() + .expect("FilterExec predicate should be DynamicFilterPhysicalExpr"); + + // Extract the deserialized custom node's dynamic filter + let deser_custom = deser_filter + .input() + .as_any() + .downcast_ref::() + .expect("FilterExec child should be CustomExecWithExprs"); + assert_eq!(deser_custom.exprs.len(), 1, "Should have one expression"); + let deser_custom_df = deser_custom.exprs[0] + .as_any() + .downcast_ref::() + .expect("Custom node expr should be DynamicFilterPhysicalExpr"); + + // After roundtrip with deduplication, both references should share the same + // inner state (same inner_id), proving that proto_converter was used correctly + // within the custom codec and deduplication was preserved across the plan boundary. + assert_eq!( + deser_filter_df.inner_id(), + deser_custom_df.inner_id(), + "FilterExec's dynamic filter and CustomExecWithExprs's dynamic filter \ + should share the same inner state after roundtrip with deduplication" + ); + + Ok(()) +} diff --git a/datafusion/sqllogictest/test_files/cast.slt b/datafusion/sqllogictest/test_files/cast.slt index 916895b8be1eb..3466354e54d71 100644 --- a/datafusion/sqllogictest/test_files/cast.slt +++ b/datafusion/sqllogictest/test_files/cast.slt @@ -89,39 +89,3 @@ select * from t0 where v0<1e100; statement ok drop table t0; - - -# ensure that automatically casting with "datafusion.optimizer.expand_views_at_output" does not -# change the column name - -statement ok -create table t(a int, b varchar); - -statement ok -set datafusion.optimizer.expand_views_at_output = true; - -query TT -explain select * from t; ----- -logical_plan -01)Projection: t.a, CAST(t.b AS LargeUtf8) AS b -02)--TableScan: t projection=[a, b] -physical_plan -01)ProjectionExec: expr=[a@0 as a, CAST(b@1 AS LargeUtf8) as b] -02)--DataSourceExec: partitions=1, partition_sizes=[0] - -query TT -explain select b from t; ----- -logical_plan -01)Projection: CAST(t.b AS LargeUtf8) AS b -02)--TableScan: t projection=[b] -physical_plan -01)ProjectionExec: expr=[CAST(b@0 AS LargeUtf8) as b] -02)--DataSourceExec: partitions=1, partition_sizes=[0] - -statement ok -set datafusion.optimizer.expand_views_at_output = false; - -statement ok -drop table t; diff --git a/datafusion/sqllogictest/test_files/preserve_file_partitioning.slt b/datafusion/sqllogictest/test_files/preserve_file_partitioning.slt index 34c5fd97b51f3..210945bc2efa2 100644 --- a/datafusion/sqllogictest/test_files/preserve_file_partitioning.slt +++ b/datafusion/sqllogictest/test_files/preserve_file_partitioning.slt @@ -101,6 +101,29 @@ STORED AS PARQUET; ---- 4 +# Create hive-partitioned dimension table (3 partitions matching fact_table) +# For testing Partitioned joins with matching partition counts +query I +COPY (SELECT 'dev' as env, 'log' as service) +TO 'test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=A/data.parquet' +STORED AS PARQUET; +---- +1 + +query I +COPY (SELECT 'prod' as env, 'log' as service) +TO 'test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=B/data.parquet' +STORED AS PARQUET; +---- +1 + +query I +COPY (SELECT 'staging' as env, 'trace' as service) +TO 'test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=C/data.parquet' +STORED AS PARQUET; +---- +1 + # Create high-cardinality fact table (5 partitions > 3 target_partitions) # For testing partition merging with consistent hashing query I @@ -173,6 +196,13 @@ CREATE EXTERNAL TABLE dimension_table (d_dkey STRING, env STRING, service STRING STORED AS PARQUET LOCATION 'test_files/scratch/preserve_file_partitioning/dimension/'; +# Hive-partitioned dimension table (3 partitions matching fact_table for Partitioned join tests) +statement ok +CREATE EXTERNAL TABLE dimension_table_partitioned (env STRING, service STRING) +STORED AS PARQUET +PARTITIONED BY (d_dkey STRING) +LOCATION 'test_files/scratch/preserve_file_partitioning/dimension_partitioned/'; + # 'High'-cardinality fact table (5 partitions > 3 target_partitions) statement ok CREATE EXTERNAL TABLE high_cardinality_table (timestamp TIMESTAMP, value DOUBLE) @@ -579,6 +609,224 @@ C 1 300 D 1 400 E 1 500 +########## +# TEST 11: Partitioned Join with Matching Partition Counts - Without Optimization +# fact_table (3 partitions) joins dimension_table_partitioned (3 partitions) +# Shows RepartitionExec added when preserve_file_partitions is disabled +########## + +statement ok +set datafusion.optimizer.preserve_file_partitions = 0; + +# Force Partitioned join mode (not CollectLeft) +statement ok +set datafusion.optimizer.hash_join_single_partition_threshold = 0; + +statement ok +set datafusion.optimizer.hash_join_single_partition_threshold_rows = 0; + +query TT +EXPLAIN SELECT f.f_dkey, d.env, sum(f.value) +FROM fact_table f +INNER JOIN dimension_table_partitioned d ON f.f_dkey = d.d_dkey +GROUP BY f.f_dkey, d.env; +---- +logical_plan +01)Aggregate: groupBy=[[f.f_dkey, d.env]], aggr=[[sum(f.value)]] +02)--Projection: f.value, f.f_dkey, d.env +03)----Inner Join: f.f_dkey = d.d_dkey +04)------SubqueryAlias: f +05)--------TableScan: fact_table projection=[value, f_dkey] +06)------SubqueryAlias: d +07)--------TableScan: dimension_table_partitioned projection=[env, d_dkey] +physical_plan +01)AggregateExec: mode=FinalPartitioned, gby=[f_dkey@0 as f_dkey, env@1 as env], aggr=[sum(f.value)] +02)--RepartitionExec: partitioning=Hash([f_dkey@0, env@1], 3), input_partitions=3 +03)----AggregateExec: mode=Partial, gby=[f_dkey@1 as f_dkey, env@2 as env], aggr=[sum(f.value)] +04)------ProjectionExec: expr=[value@1 as value, f_dkey@2 as f_dkey, env@0 as env] +05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(d_dkey@1, f_dkey@1)], projection=[env@0, value@2, f_dkey@3] +06)----------RepartitionExec: partitioning=Hash([d_dkey@1], 3), input_partitions=3 +07)------------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=C/data.parquet]]}, projection=[env, d_dkey], file_type=parquet +08)----------RepartitionExec: partitioning=Hash([f_dkey@1], 3), input_partitions=3 +09)------------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=C/data.parquet]]}, projection=[value, f_dkey], file_type=parquet, predicate=DynamicFilter [ empty ] + +query TTR rowsort +SELECT f.f_dkey, d.env, sum(f.value) +FROM fact_table f +INNER JOIN dimension_table_partitioned d ON f.f_dkey = d.d_dkey +GROUP BY f.f_dkey, d.env; +---- +A dev 772.4 +B prod 614.4 +C staging 2017.6 + +########## +# TEST 12: Partitioned Join with Matching Partition Counts - With Optimization +# Both tables have 3 partitions matching target_partitions=3 +# No RepartitionExec needed for join - partitions already satisfy the requirement +# Dynamic filter pushdown uses partition-indexed routing because preserve_file_partitions +# reports Hash partitioning for Hive-style file groups that are not hash-routed. +########## + +statement ok +set datafusion.optimizer.preserve_file_partitions = 1; + +query TT +EXPLAIN SELECT f.f_dkey, d.env, sum(f.value) +FROM fact_table f +INNER JOIN dimension_table_partitioned d ON f.f_dkey = d.d_dkey +GROUP BY f.f_dkey, d.env; +---- +logical_plan +01)Aggregate: groupBy=[[f.f_dkey, d.env]], aggr=[[sum(f.value)]] +02)--Projection: f.value, f.f_dkey, d.env +03)----Inner Join: f.f_dkey = d.d_dkey +04)------SubqueryAlias: f +05)--------TableScan: fact_table projection=[value, f_dkey] +06)------SubqueryAlias: d +07)--------TableScan: dimension_table_partitioned projection=[env, d_dkey] +physical_plan +01)AggregateExec: mode=FinalPartitioned, gby=[f_dkey@0 as f_dkey, env@1 as env], aggr=[sum(f.value)] +02)--RepartitionExec: partitioning=Hash([f_dkey@0, env@1], 3), input_partitions=3 +03)----AggregateExec: mode=Partial, gby=[f_dkey@1 as f_dkey, env@2 as env], aggr=[sum(f.value)] +04)------ProjectionExec: expr=[value@1 as value, f_dkey@2 as f_dkey, env@0 as env] +05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(d_dkey@1, f_dkey@1)], projection=[env@0, value@2, f_dkey@3] +06)----------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=C/data.parquet]]}, projection=[env, d_dkey], file_type=parquet +07)----------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=C/data.parquet]]}, projection=[value, f_dkey], file_type=parquet, predicate=DynamicFilter [ empty ] + +query TTR rowsort +SELECT f.f_dkey, d.env, sum(f.value) +FROM fact_table f +INNER JOIN dimension_table_partitioned d ON f.f_dkey = d.d_dkey +GROUP BY f.f_dkey, d.env; +---- +A dev 772.4 +B prod 614.4 +C staging 2017.6 + +########## +# TEST 13: Nested Joins with Different Partitioning Expressions - Without Optimization +# Join structure: +# Inner join: fact_table (partitioned on f_dkey) JOIN dimension_partitioned (partitioned on d_dkey) ON f_dkey = d_dkey +# Repartition: result repartitioned on env (different column than f_dkey) +# Outer join: repartitioned_result JOIN dimension_table ON env = env +########## + +statement ok +set datafusion.optimizer.preserve_file_partitions = 0; + +# Force Partitioned join mode +statement ok +set datafusion.optimizer.hash_join_single_partition_threshold = 0; + +statement ok +set datafusion.optimizer.hash_join_single_partition_threshold_rows = 0; + +query TT +EXPLAIN SELECT f.f_dkey, d1.env, d2.service, sum(f.value) as total_value +FROM fact_table f +INNER JOIN dimension_table_partitioned d1 ON f.f_dkey = d1.d_dkey +INNER JOIN dimension_table d2 ON d1.env = d2.env +GROUP BY f.f_dkey, d1.env, d2.service; +---- +logical_plan +01)Projection: f.f_dkey, d1.env, d2.service, sum(f.value) AS total_value +02)--Aggregate: groupBy=[[f.f_dkey, d1.env, d2.service]], aggr=[[sum(f.value)]] +03)----Projection: f.value, f.f_dkey, d1.env, d2.service +04)------Inner Join: d1.env = d2.env +05)--------Projection: f.value, f.f_dkey, d1.env +06)----------Inner Join: f.f_dkey = d1.d_dkey +07)------------SubqueryAlias: f +08)--------------TableScan: fact_table projection=[value, f_dkey] +09)------------SubqueryAlias: d1 +10)--------------TableScan: dimension_table_partitioned projection=[env, d_dkey] +11)--------SubqueryAlias: d2 +12)----------TableScan: dimension_table projection=[env, service] +physical_plan +01)ProjectionExec: expr=[f_dkey@0 as f_dkey, env@1 as env, service@2 as service, sum(f.value)@3 as total_value] +02)--AggregateExec: mode=FinalPartitioned, gby=[f_dkey@0 as f_dkey, env@1 as env, service@2 as service], aggr=[sum(f.value)] +03)----RepartitionExec: partitioning=Hash([f_dkey@0, env@1, service@2], 3), input_partitions=3 +04)------AggregateExec: mode=Partial, gby=[f_dkey@1 as f_dkey, env@2 as env, service@3 as service], aggr=[sum(f.value)] +05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(env@2, env@0)], projection=[value@0, f_dkey@1, env@2, service@4] +06)----------RepartitionExec: partitioning=Hash([env@2], 3), input_partitions=3 +07)------------ProjectionExec: expr=[value@1 as value, f_dkey@2 as f_dkey, env@0 as env] +08)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(d_dkey@1, f_dkey@1)], projection=[env@0, value@2, f_dkey@3] +09)----------------RepartitionExec: partitioning=Hash([d_dkey@1], 3), input_partitions=3 +10)------------------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=C/data.parquet]]}, projection=[env, d_dkey], file_type=parquet +11)----------------RepartitionExec: partitioning=Hash([f_dkey@1], 3), input_partitions=3 +12)------------------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=C/data.parquet]]}, projection=[value, f_dkey], file_type=parquet, predicate=DynamicFilter [ empty ] +13)----------RepartitionExec: partitioning=Hash([env@0], 3), input_partitions=1 +14)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension/data.parquet]]}, projection=[env, service], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify results without optimization +query TTTR rowsort +SELECT f.f_dkey, d1.env, d2.service, sum(f.value) as total_value +FROM fact_table f +INNER JOIN dimension_table_partitioned d1 ON f.f_dkey = d1.d_dkey +INNER JOIN dimension_table d2 ON d1.env = d2.env +GROUP BY f.f_dkey, d1.env, d2.service; +---- +A dev log 772.4 +B prod log 1228.8 +B prod trace 614.4 + +########## +# TEST 14: Nested Joins with Different Partitioning Expressions - With Optimization +# Same query as TEST 13, but with preserve_file_partitions enabled +# Key difference: No RepartitionExec before the first join (both sides already aligned on f_dkey/d_dkey) +# RepartitionExec only appears after the first join to change partitioning key from f_dkey to env +# Shows that remap_children correctly propagates dynamic filters through the repartition +########## + +statement ok +set datafusion.optimizer.preserve_file_partitions = 1; + +query TT +EXPLAIN SELECT f.f_dkey, d1.env, d2.service, sum(f.value) as total_value +FROM fact_table f +INNER JOIN dimension_table_partitioned d1 ON f.f_dkey = d1.d_dkey +INNER JOIN dimension_table d2 ON d1.env = d2.env +GROUP BY f.f_dkey, d1.env, d2.service; +---- +logical_plan +01)Projection: f.f_dkey, d1.env, d2.service, sum(f.value) AS total_value +02)--Aggregate: groupBy=[[f.f_dkey, d1.env, d2.service]], aggr=[[sum(f.value)]] +03)----Projection: f.value, f.f_dkey, d1.env, d2.service +04)------Inner Join: d1.env = d2.env +05)--------Projection: f.value, f.f_dkey, d1.env +06)----------Inner Join: f.f_dkey = d1.d_dkey +07)------------SubqueryAlias: f +08)--------------TableScan: fact_table projection=[value, f_dkey] +09)------------SubqueryAlias: d1 +10)--------------TableScan: dimension_table_partitioned projection=[env, d_dkey] +11)--------SubqueryAlias: d2 +12)----------TableScan: dimension_table projection=[env, service] +physical_plan +01)ProjectionExec: expr=[f_dkey@0 as f_dkey, env@1 as env, service@2 as service, sum(f.value)@3 as total_value] +02)--AggregateExec: mode=FinalPartitioned, gby=[f_dkey@0 as f_dkey, env@1 as env, service@2 as service], aggr=[sum(f.value)] +03)----RepartitionExec: partitioning=Hash([f_dkey@0, env@1, service@2], 3), input_partitions=3 +04)------AggregateExec: mode=Partial, gby=[f_dkey@1 as f_dkey, env@2 as env, service@3 as service], aggr=[sum(f.value)] +05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(env@2, env@0)], projection=[value@0, f_dkey@1, env@2, service@4] +06)----------RepartitionExec: partitioning=Hash([env@2], 3), input_partitions=3 +07)------------ProjectionExec: expr=[value@1 as value, f_dkey@2 as f_dkey, env@0 as env] +08)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(d_dkey@1, f_dkey@1)], projection=[env@0, value@2, f_dkey@3] +09)----------------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=C/data.parquet]]}, projection=[env, d_dkey], file_type=parquet +10)----------------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=C/data.parquet]]}, projection=[value, f_dkey], file_type=parquet, predicate=DynamicFilter [ empty ] +11)----------RepartitionExec: partitioning=Hash([env@0], 3), input_partitions=1 +12)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension/data.parquet]]}, projection=[env, service], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify results with optimization match results without optimization +query TTTR rowsort +SELECT f.f_dkey, d1.env, d2.service, sum(f.value) as total_value +FROM fact_table f +INNER JOIN dimension_table_partitioned d1 ON f.f_dkey = d1.d_dkey +INNER JOIN dimension_table d2 ON d1.env = d2.env +GROUP BY f.f_dkey, d1.env, d2.service; +---- +A dev log 772.4 +B prod log 1228.8 +B prod trace 614.4 + ########## # CLEANUP ########## @@ -592,5 +840,8 @@ DROP TABLE fact_table_ordered; statement ok DROP TABLE dimension_table; +statement ok +DROP TABLE dimension_table_partitioned; + statement ok DROP TABLE high_cardinality_table; diff --git a/dev/changelog/52.1.0.md b/dev/changelog/52.1.0.md deleted file mode 100644 index 97a1435c41a44..0000000000000 --- a/dev/changelog/52.1.0.md +++ /dev/null @@ -1,46 +0,0 @@ - - -# Apache DataFusion 52.1.0 Changelog - -This release consists of 3 commits from 3 contributors. See credits at the end of this changelog for more information. - -See the [upgrade guide](https://datafusion.apache.org/library-user-guide/upgrading.html) for information on how to upgrade from previous versions. - -**Documentation updates:** - -- [branch-52] Fix Internal error: Assertion failed: !self.finished: LimitedBatchCoalescer (#19785) [#19836](https://github.com/apache/datafusion/pull/19836) (alamb) - -**Other:** - -- [branch-52] fix: expose `ListFilesEntry` [#19818](https://github.com/apache/datafusion/pull/19818) (lonless9) -- [branch 52] Fix grouping set subset satisfaction [#19855](https://github.com/apache/datafusion/pull/19855) (gabotechs) -- Add BatchAdapter to simplify using PhysicalExprAdapter / Projector [#19877](https://github.com/apache/datafusion/pull/19877) (alamb) - -## Credits - -Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. - -``` - 1 Andrew Lamb - 1 Gabriel - 1 XL Liang -``` - -Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md index 61246f00dfe74..d58066c7fca35 100644 --- a/docs/source/library-user-guide/upgrading.md +++ b/docs/source/library-user-guide/upgrading.md @@ -19,6 +19,234 @@ # Upgrade Guides +## DataFusion `53.0.0` + +**Note:** DataFusion `53.0.0` has not been released yet. The information provided in this section pertains to features and changes that have already been merged to the main branch and are awaiting release in this version. + +### `SimplifyInfo` trait removed, `SimplifyContext` now uses builder-style API + +The `SimplifyInfo` trait has been removed and replaced with the concrete `SimplifyContext` struct. This simplifies the expression simplification API and removes the need for trait objects. + +**Who is affected:** + +- Users who implemented custom `SimplifyInfo` implementations +- Users who implemented `ScalarUDFImpl::simplify()` for custom scalar functions +- Users who directly use `SimplifyContext` or `ExprSimplifier` + +**Breaking changes:** + +1. The `SimplifyInfo` trait has been removed entirely +2. `SimplifyContext` no longer takes `&ExecutionProps` - it now uses a builder-style API with direct fields +3. `ScalarUDFImpl::simplify()` now takes `&SimplifyContext` instead of `&dyn SimplifyInfo` +4. Time-dependent function simplification (e.g., `now()`) is now optional - if `query_execution_start_time` is `None`, these functions won't be simplified + +**Migration guide:** + +If you implemented a custom `SimplifyInfo`: + +**Before:** + +```rust,ignore +impl SimplifyInfo for MySimplifyInfo { + fn is_boolean_type(&self, expr: &Expr) -> Result { ... } + fn nullable(&self, expr: &Expr) -> Result { ... } + fn execution_props(&self) -> &ExecutionProps { ... } + fn get_data_type(&self, expr: &Expr) -> Result { ... } +} +``` + +**After:** + +Use `SimplifyContext` directly with the builder-style API: + +```rust,ignore +let context = SimplifyContext::default() + .with_schema(schema) + .with_config_options(config_options) + .with_query_execution_start_time(Some(Utc::now())); // or use .with_current_time() +``` + +If you implemented `ScalarUDFImpl::simplify()`: + +**Before:** + +```rust,ignore +fn simplify( + &self, + args: Vec, + info: &dyn SimplifyInfo, +) -> Result { + let now_ts = info.execution_props().query_execution_start_time; + // ... +} +``` + +**After:** + +```rust,ignore +fn simplify( + &self, + args: Vec, + info: &SimplifyContext, +) -> Result { + // query_execution_start_time is now Option> + // Return Original if time is not set (simplification skipped) + let Some(now_ts) = info.query_execution_start_time() else { + return Ok(ExprSimplifyResult::Original(args)); + }; + // ... +} +``` + +If you created `SimplifyContext` from `ExecutionProps`: + +**Before:** + +```rust,ignore +let props = ExecutionProps::new(); +let context = SimplifyContext::new(&props).with_schema(schema); +``` + +**After:** + +```rust,ignore +let context = SimplifyContext::default() + .with_schema(schema) + .with_config_options(config_options) + .with_current_time(); // Sets query_execution_start_time to Utc::now() +``` + +See [`SimplifyContext` documentation](https://docs.rs/datafusion-expr/latest/datafusion_expr/simplify/struct.SimplifyContext.html) for more details. + +### `FilterExec` builder methods deprecated + +The following methods on `FilterExec` have been deprecated in favor of using `FilterExecBuilder`: + +- `with_projection()` +- `with_batch_size()` + +**Who is affected:** + +- Users who create `FilterExec` instances and use these methods to configure them + +**Migration guide:** + +Use `FilterExecBuilder` instead of chaining method calls on `FilterExec`: + +**Before:** + +```rust,ignore +let filter = FilterExec::try_new(predicate, input)? + .with_projection(Some(vec![0, 2]))? + .with_batch_size(8192)?; +``` + +**After:** + +```rust,ignore +let filter = FilterExecBuilder::new(predicate, input) + .with_projection(Some(vec![0, 2])) + .with_batch_size(8192) + .build()?; +``` + +The builder pattern is more efficient as it computes properties once during `build()` rather than recomputing them for each method call. + +Note: `with_default_selectivity()` is not deprecated as it simply updates a field value and does not require the overhead of the builder pattern. + +### Protobuf conversion trait added + +A new trait, `PhysicalProtoConverterExtension`, has been added to the `datafusion-proto` +crate. This is used for controlling the process of conversion of physical plans and +expressions to and from their protobuf equivalents. The methods for conversion now +require an additional parameter. + +The primary APIs for interacting with this crate have not been modified, so most users +should not need to make any changes. If you do require this trait, you can use the +`DefaultPhysicalProtoConverter` implementation. + +For example, to convert a sort expression protobuf node you can make the following +updates: + +**Before:** + +```rust,ignore +let sort_expr = parse_physical_sort_expr( + sort_proto, + ctx, + input_schema, + codec, +); +``` + +**After:** + +```rust,ignore +let converter = DefaultPhysicalProtoConverter {}; +let sort_expr = parse_physical_sort_expr( + sort_proto, + ctx, + input_schema, + codec, + &converter +); +``` + +Similarly to convert from a physical sort expression into a protobuf node: + +**Before:** + +```rust,ignore +let sort_proto = serialize_physical_sort_expr( + sort_expr, + codec, +); +``` + +**After:** + +```rust,ignore +let converter = DefaultPhysicalProtoConverter {}; +let sort_proto = serialize_physical_sort_expr( + sort_expr, + codec, + &converter, +); +``` + +### `generate_series` and `range` table functions changed + +The `generate_series` and `range` table functions now return an empty set when the interval is invalid, instead of an error. +This behavior is consistent with systems like PostgreSQL. + +Before: + +```sql +> select * from generate_series(0, -1); +Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series + +> select * from range(0, -1); +Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +``` + +Now: + +```sql +> select * from generate_series(0, -1); ++-------+ +| value | ++-------+ ++-------+ +0 row(s) fetched. + +> select * from range(0, -1); ++-------+ +| value | ++-------+ ++-------+ +0 row(s) fetched. +``` + ## DataFusion `52.0.0` **Note:** DataFusion `52.0.0` has not been released yet. The information provided in this section pertains to features and changes that have already been merged to the main branch and are awaiting release in this version. diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 76acd42ac901d..c9a1b6c684c9b 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -99,7 +99,7 @@ The following configuration settings are available: | datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | (writing) Sets best effort maximum dictionary page size, in bytes | | datafusion.execution.parquet.statistics_enabled | page | (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_row_group_size | 1048576 | (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | -| datafusion.execution.parquet.created_by | datafusion version 52.1.0 | (writing) Sets "created by" property | +| datafusion.execution.parquet.created_by | datafusion version 52.0.0 | (writing) Sets "created by" property | | datafusion.execution.parquet.column_index_truncate_length | 64 | (writing) Sets column index truncate length | | datafusion.execution.parquet.statistics_truncate_length | 64 | (writing) Sets statistics truncate length. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.data_page_row_count_limit | 20000 | (writing) Sets best effort maximum number of rows in data page |