diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 828b286407b3..8683141e57ff 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -69,8 +69,8 @@ use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor, }; use datafusion_common::{ - DFSchema, ScalarValue, exec_err, internal_datafusion_err, internal_err, not_impl_err, - plan_err, + DFSchema, DFSchemaRef, ScalarValue, exec_err, internal_datafusion_err, internal_err, + not_impl_err, plan_err, }; use datafusion_common::{ TableReference, assert_eq_or_internal_err, assert_or_internal_err, @@ -157,6 +157,26 @@ pub trait ExtensionPlanner { physical_inputs: &[Arc], session_state: &SessionState, ) -> Result>>; + + /// Create a physical plan for a [`LogicalPlan::TableScan`]. + /// + /// This is useful for planning valid [`TableSource`]s that are not [`TableProvider`]s. + /// + /// Returns: + /// * `Ok(Some(plan))` if the planner knows how to plan the `scan` + /// * `Ok(None)` if the planner does not know how to plan the `scan` and wants to delegate the planning to another [`ExtensionPlanner`] + /// * `Err` if the planner knows how to plan the `scan` but errors while doing so + /// + /// [`TableSource`]: datafusion_expr::TableSource + /// [`TableProvider`]: datafusion_catalog::TableProvider + async fn plan_table_scan( + &self, + _planner: &dyn PhysicalPlanner, + _scan: &TableScan, + _session_state: &SessionState, + ) -> Result>> { + Ok(None) + } } /// Default single node physical query planner that converts a @@ -278,7 +298,8 @@ struct LogicalNode<'a> { impl DefaultPhysicalPlanner { /// Create a physical planner that uses `extension_planners` to - /// plan user-defined logical nodes [`LogicalPlan::Extension`]. + /// plan user-defined logical nodes [`LogicalPlan::Extension`] + /// or user-defined table sources in [`LogicalPlan::TableScan`]. /// The planner uses the first [`ExtensionPlanner`] to return a non-`None` /// plan. pub fn with_extension_planners( @@ -287,6 +308,24 @@ impl DefaultPhysicalPlanner { Self { extension_planners } } + fn ensure_schema_matches( + &self, + logical_schema: &DFSchemaRef, + physical_plan: &Arc, + context: &str, + ) -> Result<()> { + if !logical_schema.matches_arrow_schema(&physical_plan.schema()) { + return plan_err!( + "{} created an ExecutionPlan with mismatched schema. \ + LogicalPlan schema: {:?}, ExecutionPlan schema: {:?}", + context, + logical_schema, + physical_plan.schema() + ); + } + Ok(()) + } + /// Create a physical plan from a logical plan async fn create_initial_plan( &self, @@ -455,25 +494,53 @@ impl DefaultPhysicalPlanner { ) -> Result> { let exec_node: Arc = match node { // Leaves (no children) - LogicalPlan::TableScan(TableScan { - source, - projection, - filters, - fetch, - .. - }) => { - let source = source_as_provider(source)?; - // Remove all qualifiers from the scan as the provider - // doesn't know (nor should care) how the relation was - // referred to in the query - let filters = unnormalize_cols(filters.iter().cloned()); - let filters_vec = filters.into_iter().collect::>(); - let opts = ScanArgs::default() - .with_projection(projection.as_deref()) - .with_filters(Some(&filters_vec)) - .with_limit(*fetch); - let res = source.scan_with_args(session_state, opts).await?; - Arc::clone(res.plan()) + LogicalPlan::TableScan(scan) => { + let TableScan { + source, + projection, + filters, + fetch, + projected_schema, + .. + } = scan; + + if let Ok(source) = source_as_provider(source) { + // Remove all qualifiers from the scan as the provider + // doesn't know (nor should care) how the relation was + // referred to in the query + let filters = unnormalize_cols(filters.iter().cloned()); + let filters_vec = filters.into_iter().collect::>(); + let opts = ScanArgs::default() + .with_projection(projection.as_deref()) + .with_filters(Some(&filters_vec)) + .with_limit(*fetch); + let res = source.scan_with_args(session_state, opts).await?; + Arc::clone(res.plan()) + } else { + let mut maybe_plan = None; + for planner in &self.extension_planners { + if maybe_plan.is_some() { + break; + } + + maybe_plan = + planner.plan_table_scan(self, scan, session_state).await?; + } + + let plan = match maybe_plan { + Some(plan) => plan, + None => { + return plan_err!( + "No installed planner was able to plan TableScan for custom TableSource: {:?}", + scan.table_name + ); + } + }; + let context = + format!("Extension planner for table scan {}", scan.table_name); + self.ensure_schema_matches(projected_schema, &plan, &context)?; + plan + } } LogicalPlan::Values(Values { values, schema }) => { let exprs = values @@ -1616,20 +1683,9 @@ impl DefaultPhysicalPlanner { ), }?; - // Ensure the ExecutionPlan's schema matches the - // declared logical schema to catch and warn about - // logic errors when creating user defined plans. - if !node.schema().matches_arrow_schema(&plan.schema()) { - return plan_err!( - "Extension planner for {:?} created an ExecutionPlan with mismatched schema. \ - LogicalPlan schema: {:?}, ExecutionPlan schema: {:?}", - node, - node.schema(), - plan.schema() - ); - } else { - plan - } + let context = format!("Extension planner for {node:?}"); + self.ensure_schema_matches(node.schema(), &plan, &context)?; + plan } // Other @@ -2889,7 +2945,9 @@ mod tests { use datafusion_execution::TaskContext; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::builder::subquery_alias; - use datafusion_expr::{LogicalPlanBuilder, UserDefinedLogicalNodeCore, col, lit}; + use datafusion_expr::{ + LogicalPlanBuilder, TableSource, UserDefinedLogicalNodeCore, col, lit, + }; use datafusion_functions_aggregate::count::count_all; use datafusion_functions_aggregate::expr_fn::sum; use datafusion_physical_expr::EquivalenceProperties; @@ -4413,4 +4471,76 @@ digraph { assert_contains!(&err_str, "field nullability at index"); assert_contains!(&err_str, "field metadata at index"); } + + #[derive(Debug)] + struct MockTableSource { + schema: SchemaRef, + } + + impl TableSource for MockTableSource { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + } + + struct MockTableScanExtensionPlanner; + + #[async_trait] + impl ExtensionPlanner for MockTableScanExtensionPlanner { + async fn plan_extension( + &self, + _planner: &dyn PhysicalPlanner, + _node: &dyn UserDefinedLogicalNode, + _logical_inputs: &[&LogicalPlan], + _physical_inputs: &[Arc], + _session_state: &SessionState, + ) -> Result>> { + Ok(None) + } + + async fn plan_table_scan( + &self, + _planner: &dyn PhysicalPlanner, + scan: &TableScan, + _session_state: &SessionState, + ) -> Result>> { + if scan.source.as_any().is::() { + Ok(Some(Arc::new(EmptyExec::new(Arc::clone( + scan.projected_schema.inner(), + ))))) + } else { + Ok(None) + } + } + } + + #[tokio::test] + async fn test_table_scan_extension_planner() { + let session_state = make_session_state(); + let planner = Arc::new(MockTableScanExtensionPlanner); + let physical_planner = + DefaultPhysicalPlanner::with_extension_planners(vec![planner]); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + let table_source = Arc::new(MockTableSource { + schema: Arc::clone(&schema), + }); + let logical_plan = LogicalPlanBuilder::scan("test", table_source, None) + .unwrap() + .build() + .unwrap(); + + let plan = physical_planner + .create_physical_plan(&logical_plan, &session_state) + .await + .unwrap(); + + assert_eq!(plan.schema(), schema); + assert!(plan.as_any().is::()); + } }