Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 85 additions & 56 deletions datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
// specific language governing permissions and limitations
// under the License.

use crate::logical_plan::producer::{SubstraitProducer, to_substrait_literal_expr};
use crate::logical_plan::producer::{
SubstraitProducer, to_substrait_literal_expr, to_substrait_type,
};
use datafusion::common::{DFSchemaRef, ScalarValue, not_impl_err};
use datafusion::logical_expr::{Between, BinaryExpr, Expr, Like, Operator, expr};
use datafusion::logical_expr::{
Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, expr,
};
use substrait::proto::expression::{RexType, ScalarFunction};
use substrait::proto::function_argument::ArgType;
use substrait::proto::{Expression, FunctionArgument};
use substrait::proto::{Expression, FunctionArgument, Type};

pub fn from_scalar_function(
producer: &mut impl SubstraitProducer,
Expand All @@ -36,13 +40,20 @@ pub fn from_scalar_function(

let arguments = custom_argument_handler(fun.name(), arguments);

let (_, output_field) = Expr::ScalarFunction(fun.clone()).to_field(schema)?;
let output_type = to_substrait_type(
producer,
output_field.data_type(),
output_field.is_nullable(),
)?;

let function_anchor = producer.register_function(fun.name().to_string());
#[expect(deprecated)]
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
arguments,
output_type: None,
output_type: Some(output_type),
options: vec![],
args: vec![],
})),
Expand Down Expand Up @@ -86,7 +97,13 @@ pub fn from_unary_expr(
Expr::Negative(arg) => ("negate", arg),
expr => not_impl_err!("Unsupported expression: {expr:?}")?,
};
to_substrait_unary_scalar_fn(producer, fn_name, arg, schema)
let (_, output_field) = expr.to_field(schema)?;
let output_type = to_substrait_type(
producer,
output_field.data_type(),
output_field.is_nullable(),
)?;
to_substrait_unary_scalar_fn(producer, fn_name, arg, schema, &output_type)
}

pub fn from_binary_expr(
Expand All @@ -97,7 +114,19 @@ pub fn from_binary_expr(
let BinaryExpr { left, op, right } = expr;
let l = producer.handle_expr(left, schema)?;
let r = producer.handle_expr(right, schema)?;
Ok(make_binary_op_scalar_func(producer, &l, &r, *op))
let (_, output_field) = Expr::BinaryExpr(expr.clone()).to_field(schema)?;
let output_type = to_substrait_type(
producer,
output_field.data_type(),
output_field.is_nullable(),
)?;
Ok(make_binary_op_scalar_func(
producer,
&l,
&r,
*op,
&output_type,
))
}

pub fn from_like(
Expand Down Expand Up @@ -192,6 +221,7 @@ fn to_substrait_unary_scalar_fn(
fn_name: &str,
arg: &Expr,
schema: &DFSchemaRef,
output_type: &Type,
) -> datafusion::common::Result<Expression> {
let function_anchor = producer.register_function(fn_name.to_string());
let substrait_expr = producer.handle_expr(arg, schema)?;
Expand All @@ -202,7 +232,7 @@ fn to_substrait_unary_scalar_fn(
arguments: vec![FunctionArgument {
arg_type: Some(ArgType::Value(substrait_expr)),
}],
output_type: None,
output_type: Some(output_type.clone()),
options: vec![],
..Default::default()
})),
Expand All @@ -215,6 +245,7 @@ pub fn make_binary_op_scalar_func(
lhs: &Expression,
rhs: &Expression,
op: Operator,
output_type: &Type,
) -> Expression {
let function_anchor = producer.register_function(operator_to_name(op).to_string());
#[expect(deprecated)]
Expand All @@ -229,7 +260,7 @@ pub fn make_binary_op_scalar_func(
arg_type: Some(ArgType::Value(rhs.clone())),
},
],
output_type: None,
output_type: Some(output_type.clone()),
args: vec![],
options: vec![],
})),
Expand All @@ -247,57 +278,21 @@ pub fn from_between(
low,
high,
} = between;
if *negated {
// `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr)
let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?;
let substrait_low = producer.handle_expr(low.as_ref(), schema)?;
let substrait_high = producer.handle_expr(high.as_ref(), schema)?;

let l_expr = make_binary_op_scalar_func(
producer,
&substrait_expr,
&substrait_low,
Operator::Lt,
);
let r_expr = make_binary_op_scalar_func(
producer,
&substrait_high,
&substrait_expr,
Operator::Lt,
);

Ok(make_binary_op_scalar_func(
producer,
&l_expr,
&r_expr,
Operator::Or,
))
let expr = if *negated {
// `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr)
Expr::or(
Expr::lt(*expr.clone(), *low.clone()),
Expr::lt(*high.clone(), *expr.clone()),
)
} else {
// `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high)
let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?;
let substrait_low = producer.handle_expr(low.as_ref(), schema)?;
let substrait_high = producer.handle_expr(high.as_ref(), schema)?;

let l_expr = make_binary_op_scalar_func(
producer,
&substrait_low,
&substrait_expr,
Operator::LtEq,
);
let r_expr = make_binary_op_scalar_func(
producer,
&substrait_expr,
&substrait_high,
Operator::LtEq,
);

Ok(make_binary_op_scalar_func(
producer,
&l_expr,
&r_expr,
Operator::And,
))
}
Expr::and(
Expr::lt_eq(*low.clone(), *expr.clone()),
Expr::lt_eq(*expr.clone(), *high.clone()),
)
};
producer.handle_expr(&expr, schema)
}

pub fn operator_to_name(op: Operator) -> &'static str {
Expand Down Expand Up @@ -346,3 +341,37 @@ pub fn operator_to_name(op: Operator) -> &'static str {
Operator::BitwiseShiftLeft => "bitwise_shift_left",
}
}

#[cfg(test)]
mod tests {
    use crate::logical_plan::producer::{
        DefaultSubstraitProducer, SubstraitProducer, to_substrait_type,
    };
    use datafusion::arrow::datatypes::DataType;
    use datafusion::common::{DFSchema, DFSchemaRef};
    use datafusion::execution::SessionStateBuilder;
    use datafusion::prelude::lit;
    use substrait::proto::Expression;
    use substrait::proto::expression::{RexType, ScalarFunction};

    /// Converting `1i64 + 2i64` through the producer must yield a Substrait
    /// scalar function whose `output_type` equals the Substrait type produced
    /// for a non-nullable `Int64`.
    #[tokio::test]
    async fn binary_expr_output_type() -> datafusion::common::Result<()> {
        let state = SessionStateBuilder::default().build();
        // Both operands are literals, so no columns are needed in the schema.
        let schema = DFSchemaRef::new(DFSchema::empty());
        let mut producer = DefaultSubstraitProducer::new(&state);

        let sum = lit(1i64) + lit(2i64);
        // Anything other than a scalar-function expression is a test failure.
        let Expression {
            rex_type: Some(RexType::ScalarFunction(ScalarFunction { output_type, .. })),
        } = producer.handle_expr(&sum, &schema)?
        else {
            panic!("Substrait ScalarFunction expected")
        };

        let expected = to_substrait_type(&mut producer, &DataType::Int64, false)?;
        assert_eq!(output_type, Some(expected));
        Ok(())
    }
}
92 changes: 33 additions & 59 deletions datafusion/substrait/src/logical_plan/producer/rel/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,59 +15,38 @@
// specific language governing permissions and limitations
// under the License.

use crate::logical_plan::producer::{SubstraitProducer, make_binary_op_scalar_func};
use datafusion::common::{
DFSchemaRef, JoinConstraint, JoinType, NullEquality, not_impl_err,
};
use crate::logical_plan::producer::SubstraitProducer;
use datafusion::common::{JoinConstraint, JoinType, NullEquality, not_impl_err};
use datafusion::logical_expr::utils::conjunction;
use datafusion::logical_expr::{Expr, Join, Operator};
use datafusion::prelude::binary_expr;
use std::sync::Arc;
use substrait::proto::rel::RelType;
use substrait::proto::{Expression, JoinRel, Rel, join_rel};
use substrait::proto::{JoinRel, Rel, join_rel};

pub fn from_join(
producer: &mut impl SubstraitProducer,
join: &Join,
) -> datafusion::common::Result<Box<Rel>> {
let left = producer.handle_plan(join.left.as_ref())?;
let right = producer.handle_plan(join.right.as_ref())?;
let join_type = to_substrait_jointype(join.join_type);
// we only support basic joins so return an error for anything not yet supported
// only ON constraints are supported right now
match join.join_constraint {
JoinConstraint::On => {}
JoinConstraint::Using => return not_impl_err!("join constraint: `using`"),
}
let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?);

// convert filter if present
let join_filter = match &join.filter {
Some(filter) => Some(producer.handle_expr(filter, &in_join_schema)?),
None => None,
};

// map the left and right columns to binary expressions in the form `l = r`
// build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b`
let eq_op = match join.null_equality {
NullEquality::NullEqualsNothing => Operator::Eq,
NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom,
};
let join_on = to_substrait_join_expr(producer, &join.on, eq_op, &in_join_schema)?;
let left = producer.handle_plan(join.left.as_ref())?;
let right = producer.handle_plan(join.right.as_ref())?;
let join_type = to_substrait_jointype(join.join_type);

// create conjunction between `join_on` and `join_filter` to embed all join conditions,
// whether equal or non-equal in a single expression
let join_expr = match &join_on {
Some(on_expr) => match &join_filter {
Some(filter) => Some(Box::new(make_binary_op_scalar_func(
producer,
on_expr,
filter,
Operator::And,
))),
None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist
},
None => match &join_filter {
Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist
None => None,
},
let join_expr =
to_substrait_join_expr(join.on.clone(), join.null_equality, join.filter.clone());
let join_expression = match join_expr {
Some(expr) => {
let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?);
let expression = producer.handle_expr(&expr, &in_join_schema)?;
Some(Box::new(expression))
}
None => None,
};

Ok(Box::new(Rel {
Expand All @@ -76,33 +55,28 @@ pub fn from_join(
left: Some(left),
right: Some(right),
r#type: join_type as i32,
expression: join_expr,
expression: join_expression,
post_join_filter: None,
advanced_extension: None,
}))),
}))
}

fn to_substrait_join_expr(
producer: &mut impl SubstraitProducer,
join_conditions: &Vec<(Expr, Expr)>,
eq_op: Operator,
join_schema: &DFSchemaRef,
) -> datafusion::common::Result<Option<Expression>> {
// Only support AND conjunction for each binary expression in join conditions
let mut exprs: Vec<Expression> = vec![];
for (left, right) in join_conditions {
let l = producer.handle_expr(left, join_schema)?;
let r = producer.handle_expr(right, join_schema)?;
// AND with existing expression
exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op));
}

let join_expr: Option<Expression> =
exprs.into_iter().reduce(|acc: Expression, e: Expression| {
make_binary_op_scalar_func(producer, &acc, &e, Operator::And)
});
Ok(join_expr)
join_on: Vec<(Expr, Expr)>,
null_equality: NullEquality,
join_filter: Option<Expr>,
) -> Option<Expr> {
// Combine join on and filter conditions into a single Boolean expression (#7611)
let eq_op = match null_equality {
NullEquality::NullEqualsNothing => Operator::Eq,
NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom,
};
let all_conditions = join_on
.into_iter()
.map(|(left, right)| binary_expr(left, eq_op, right))
.chain(join_filter);
conjunction(all_conditions)
}

fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType {
Expand Down