diff --git a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs index bd8a9d9a99b53..d6c4d8e5fac6f 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs @@ -15,12 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::logical_plan::producer::{SubstraitProducer, to_substrait_literal_expr}; +use crate::logical_plan::producer::{ + SubstraitProducer, to_substrait_literal_expr, to_substrait_type, +}; use datafusion::common::{DFSchemaRef, ScalarValue, not_impl_err}; -use datafusion::logical_expr::{Between, BinaryExpr, Expr, Like, Operator, expr}; +use datafusion::logical_expr::{ + Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, expr, +}; use substrait::proto::expression::{RexType, ScalarFunction}; use substrait::proto::function_argument::ArgType; -use substrait::proto::{Expression, FunctionArgument}; +use substrait::proto::{Expression, FunctionArgument, Type}; pub fn from_scalar_function( producer: &mut impl SubstraitProducer, @@ -36,13 +40,20 @@ pub fn from_scalar_function( let arguments = custom_argument_handler(fun.name(), arguments); + let (_, output_field) = Expr::ScalarFunction(fun.clone()).to_field(schema)?; + let output_type = to_substrait_type( + producer, + output_field.data_type(), + output_field.is_nullable(), + )?; + let function_anchor = producer.register_function(fun.name().to_string()); #[expect(deprecated)] Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, arguments, - output_type: None, + output_type: Some(output_type), options: vec![], args: vec![], })), @@ -86,7 +97,13 @@ pub fn from_unary_expr( Expr::Negative(arg) => ("negate", arg), expr => not_impl_err!("Unsupported expression: {expr:?}")?, }; - to_substrait_unary_scalar_fn(producer, fn_name, arg, schema) + let (_, output_field) = expr.to_field(schema)?; + let output_type = to_substrait_type( + producer, + output_field.data_type(), + output_field.is_nullable(), + )?; + to_substrait_unary_scalar_fn(producer, fn_name, arg, schema, &output_type) } pub fn from_binary_expr( @@ -97,7 +114,19 @@ pub fn from_binary_expr( let BinaryExpr { left, op, right } = expr; let l = producer.handle_expr(left, schema)?; let r = producer.handle_expr(right, schema)?; - Ok(make_binary_op_scalar_func(producer, &l, &r, *op)) + let (_, output_field) = Expr::BinaryExpr(expr.clone()).to_field(schema)?; + let output_type = to_substrait_type( + producer, + output_field.data_type(), + output_field.is_nullable(), + )?; + Ok(make_binary_op_scalar_func( + producer, + &l, + &r, + *op, + &output_type, + )) } pub fn from_like( @@ -192,6 +221,7 @@ fn to_substrait_unary_scalar_fn( fn_name: &str, arg: &Expr, schema: &DFSchemaRef, + output_type: &Type, ) -> datafusion::common::Result { let function_anchor = producer.register_function(fn_name.to_string()); let substrait_expr = producer.handle_expr(arg, schema)?; @@ -202,7 +232,7 @@ fn to_substrait_unary_scalar_fn( arguments: vec![FunctionArgument { arg_type: Some(ArgType::Value(substrait_expr)), }], - output_type: None, + output_type: Some(output_type.clone()), options: vec![], ..Default::default() })), @@ -215,6 +245,7 @@ pub fn make_binary_op_scalar_func( lhs: &Expression, rhs: &Expression, op: Operator, + output_type: &Type, ) -> Expression { let function_anchor = producer.register_function(operator_to_name(op).to_string()); #[expect(deprecated)] @@ -229,7 +260,7 @@ pub fn make_binary_op_scalar_func( arg_type: Some(ArgType::Value(rhs.clone())), }, ], - output_type: None, + output_type: Some(output_type.clone()), args: vec![], options: vec![], })), @@ -247,57 +278,21 @@ pub fn from_between( low, high, } = between; - if *negated { - // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; - let substrait_low = producer.handle_expr(low.as_ref(), schema)?; - let substrait_high = producer.handle_expr(high.as_ref(), schema)?; - - let l_expr = make_binary_op_scalar_func( - producer, - &substrait_expr, - &substrait_low, - Operator::Lt, - ); - let r_expr = make_binary_op_scalar_func( - producer, - &substrait_high, - &substrait_expr, - Operator::Lt, - ); - Ok(make_binary_op_scalar_func( - producer, - &l_expr, - &r_expr, - Operator::Or, - )) + let expr = if *negated { + // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) + Expr::or( + Expr::lt(*expr.clone(), *low.clone()), + Expr::lt(*high.clone(), *expr.clone()), + ) } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; - let substrait_low = producer.handle_expr(low.as_ref(), schema)?; - let substrait_high = producer.handle_expr(high.as_ref(), schema)?; - - let l_expr = make_binary_op_scalar_func( - producer, - &substrait_low, - &substrait_expr, - Operator::LtEq, - ); - let r_expr = make_binary_op_scalar_func( - producer, - &substrait_expr, - &substrait_high, - Operator::LtEq, - ); - - Ok(make_binary_op_scalar_func( - producer, - &l_expr, - &r_expr, - Operator::And, - )) - } + Expr::and( + Expr::lt_eq(*low.clone(), *expr.clone()), + Expr::lt_eq(*expr.clone(), *high.clone()), + ) + }; + producer.handle_expr(&expr, schema) } pub fn operator_to_name(op: Operator) -> &'static str { @@ -346,3 +341,37 @@ pub fn operator_to_name(op: Operator) -> &'static str { Operator::BitwiseShiftLeft => "bitwise_shift_left", } } + +#[cfg(test)] +mod tests { + use crate::logical_plan::producer::{ + DefaultSubstraitProducer, SubstraitProducer, to_substrait_type, + }; + use datafusion::arrow::datatypes::DataType; + use datafusion::common::{DFSchema, DFSchemaRef}; + use datafusion::execution::SessionStateBuilder; + use datafusion::prelude::lit; + use substrait::proto::Expression; + use substrait::proto::expression::{RexType, ScalarFunction}; + + #[tokio::test] + async fn binary_expr_output_type() -> datafusion::common::Result<()> { + let state = SessionStateBuilder::default().build(); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let mut producer = DefaultSubstraitProducer::new(&state); + + let expr = lit(1i64) + lit(2i64); + let substrait_expr = producer.handle_expr(&expr, &empty_schema)?; + if let Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { output_type, .. })), + } = substrait_expr + { + let expected_type = + to_substrait_type(&mut producer, &DataType::Int64, false)?; + assert_eq!(output_type, Some(expected_type)); + Ok(()) + } else { + panic!("Substrait ScalarFunction expected") + } + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/join.rs b/datafusion/substrait/src/logical_plan/producer/rel/join.rs index cbf5593ffc86c..97e48a79b6ed1 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/join.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/join.rs @@ -15,59 +15,38 @@ // specific language governing permissions and limitations // under the License. -use crate::logical_plan::producer::{SubstraitProducer, make_binary_op_scalar_func}; -use datafusion::common::{ - DFSchemaRef, JoinConstraint, JoinType, NullEquality, not_impl_err, -}; +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::{JoinConstraint, JoinType, NullEquality, not_impl_err}; +use datafusion::logical_expr::utils::conjunction; use datafusion::logical_expr::{Expr, Join, Operator}; +use datafusion::prelude::binary_expr; use std::sync::Arc; use substrait::proto::rel::RelType; -use substrait::proto::{Expression, JoinRel, Rel, join_rel}; +use substrait::proto::{JoinRel, Rel, join_rel}; pub fn from_join( producer: &mut impl SubstraitProducer, join: &Join, ) -> datafusion::common::Result> { - let left = producer.handle_plan(join.left.as_ref())?; - let right = producer.handle_plan(join.right.as_ref())?; - let join_type = to_substrait_jointype(join.join_type); - // we only support basic joins so return an error for anything not yet supported + // only ON constraints are supported right now match join.join_constraint { JoinConstraint::On => {} JoinConstraint::Using => return not_impl_err!("join constraint: `using`"), } - let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?); - - // convert filter if present - let join_filter = match &join.filter { - Some(filter) => Some(producer.handle_expr(filter, &in_join_schema)?), - None => None, - }; - // map the left and right columns to binary expressions in the form `l = r` - // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` - let eq_op = match join.null_equality { - NullEquality::NullEqualsNothing => Operator::Eq, - NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom, - }; - let join_on = to_substrait_join_expr(producer, &join.on, eq_op, &in_join_schema)?; + let left = producer.handle_plan(join.left.as_ref())?; + let right = producer.handle_plan(join.right.as_ref())?; + let join_type = to_substrait_jointype(join.join_type); - // create conjunction between `join_on` and `join_filter` to embed all join conditions, - // whether equal or non-equal in a single expression - let join_expr = match &join_on { - Some(on_expr) => match &join_filter { - Some(filter) => Some(Box::new(make_binary_op_scalar_func( - producer, - on_expr, - filter, - Operator::And, - ))), - None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist - }, - None => match &join_filter { - Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist - None => None, - }, + let join_expr = + to_substrait_join_expr(join.on.clone(), join.null_equality, join.filter.clone()); + let join_expression = match join_expr { + Some(expr) => { + let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?); + let expression = producer.handle_expr(&expr, &in_join_schema)?; + Some(Box::new(expression)) + } + None => None, }; Ok(Box::new(Rel { @@ -76,7 +55,7 @@ pub fn from_join( left: Some(left), right: Some(right), r#type: join_type as i32, - expression: join_expr, + expression: join_expression, post_join_filter: None, advanced_extension: None, }))), @@ -84,25 +63,20 @@ pub fn from_join( } fn to_substrait_join_expr( - producer: &mut impl SubstraitProducer, - join_conditions: &Vec<(Expr, Expr)>, - eq_op: Operator, - join_schema: &DFSchemaRef, -) -> datafusion::common::Result> { - // Only support AND conjunction for each binary expression in join conditions - let mut exprs: Vec = vec![]; - for (left, right) in join_conditions { - let l = producer.handle_expr(left, join_schema)?; - let r = producer.handle_expr(right, join_schema)?; - // AND with existing expression - exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op)); - } - - let join_expr: Option = - exprs.into_iter().reduce(|acc: Expression, e: Expression| { - make_binary_op_scalar_func(producer, &acc, &e, Operator::And) - }); - Ok(join_expr) + join_on: Vec<(Expr, Expr)>, + null_equality: NullEquality, + join_filter: Option, +) -> Option { + // Combine join on and filter conditions into a single Boolean expression (#7611) + let eq_op = match null_equality { + NullEquality::NullEqualsNothing => Operator::Eq, + NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom, + }; + let all_conditions = join_on + .into_iter() + .map(|(left, right)| binary_expr(left, eq_op, right)) + .chain(join_filter); + conjunction(all_conditions) } fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType {