diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index c6644e008645a..fe2e1a3b0408a 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1660,17 +1660,19 @@ impl TreeNodeRewriter for Simplifier<'_> { left, op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch), right, - }) => Transformed::yes(simplify_regex_expr(left, op, right)?), + }) => simplify_regex_expr(left, op, right)?, // Rules for Like Expr::Like(like) => { // `\` is implicit escape, see https://github.com/apache/datafusion/issues/13291 let escape_char = like.escape_char.unwrap_or('\\'); - match as_string_scalar(&like.pattern) { - Some((data_type, pattern_str)) => { + + match StringScalar::try_from_expr(&like.pattern) { + Some(string_scalar) => { + let pattern_str = string_scalar.as_str(); match pattern_str { None => return Ok(Transformed::yes(lit_bool_null())), - Some(pattern_str) if pattern_str == "%" => { + Some("%") => { // exp LIKE '%' is // - when exp is not NULL, it's true // - when exp is NULL, it's NULL @@ -1702,10 +1704,9 @@ impl TreeNodeRewriter for Simplifier<'_> { .replace_all(pattern_str, "%") .to_string(); Transformed::yes(Expr::Like(Like { - pattern: Box::new(to_string_scalar( - &data_type, - Some(simplified_pattern), - )), + pattern: Box::new( + string_scalar.to_expr(&simplified_pattern), + ), ..like })) } @@ -2126,21 +2127,54 @@ fn is_literal_or_literal_cast(expr: &Expr) -> bool { } } -fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option)> { - match expr { - Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)), - Expr::Literal(ScalarValue::LargeUtf8(s), _) => Some((DataType::LargeUtf8, s)), - Expr::Literal(ScalarValue::Utf8View(s), _) => Some((DataType::Utf8View, s)), - _ => None, - } +/// Helper for working with string scalar values (Utf8, LargeUtf8, Utf8View) +pub(crate) enum StringScalar<'a> { + Utf8(&'a ScalarValue), + LargeUtf8(&'a ScalarValue), + Utf8View(&'a ScalarValue), } -fn to_string_scalar(data_type: &DataType, value: Option) -> Expr { - match data_type { - DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value), None), - DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value), None), - DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value), None), - _ => unreachable!(), +impl<'a> StringScalar<'a> { + /// Create a `StringScalar` view from an `Expr` if it is a supported string literal. + /// Returns `None` if the expression is not a string literal. + pub(crate) fn try_from_expr(expr: &'a Expr) -> Option { + match expr { + Expr::Literal(scalar, _) => Self::try_from_scalar(scalar), + _ => None, + } + } + + /// Create a `StringScalar` view from a `ScalarValue` if it is a supported string type. + /// Returns `None` if the scalar value is not a supported string type. + fn try_from_scalar(scalar: &'a ScalarValue) -> Option { + match scalar { + ScalarValue::Utf8(_) => Some(Self::Utf8(scalar)), + ScalarValue::LargeUtf8(_) => Some(Self::LargeUtf8(scalar)), + ScalarValue::Utf8View(_) => Some(Self::Utf8View(scalar)), + _ => None, + } + } + + /// Returns the underlying string slice. + pub(crate) fn as_str(&self) -> Option<&'a str> { + match self { + Self::Utf8(scalar) | Self::LargeUtf8(scalar) | Self::Utf8View(scalar) => { + scalar.try_as_str().flatten() + } + } + } + + /// Build a new `Expr` of the same string type with the given value. + pub(crate) fn to_expr(&self, val: &str) -> Expr { + match self { + Self::Utf8(_) => Expr::Literal(ScalarValue::Utf8(Some(val.to_owned())), None), + Self::LargeUtf8(_) => { + Expr::Literal(ScalarValue::LargeUtf8(Some(val.to_owned())), None) + } + Self::Utf8View(_) => { + Expr::Literal(ScalarValue::Utf8View(Some(val.to_owned())), None) + } + } } } diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index d388aaf74cdac..6c2492d05404d 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -15,10 +15,13 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::tree_node::Transformed; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{BinaryExpr, Expr, Like, Operator, lit}; use regex_syntax::hir::{Capture, Hir, HirKind, Literal, Look}; +use crate::simplify_expressions::expr_simplifier::StringScalar; + /// Maximum number of regex alternations (`foo|bar|...`) that will be expanded into multiple `LIKE` expressions. const MAX_REGEX_ALTERNATIONS_EXPANSION: usize = 4; @@ -43,52 +46,70 @@ pub fn simplify_regex_expr( left: Box, op: Operator, right: Box, -) -> Result { - let mode = OperatorMode::new(&op); +) -> Result> { + // Check if the right operand is a supported string literal + let Some(string_scalar) = StringScalar::try_from_expr(right.as_ref()) else { + return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))); + }; + let pattern = string_scalar.as_str(); + let Some(pattern) = pattern else { + return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))); + }; - if let Expr::Literal(ScalarValue::Utf8(Some(pattern)), _) = right.as_ref() { - // Handle the special case for ".*" pattern - if pattern == ANY_CHAR_REGEX_PATTERN { - let new_expr = if mode.not { - // not empty - let empty_lit = Box::new(lit("")); - Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right: empty_lit, - }) - } else { - // not null - left.is_not_null() - }; - return Ok(new_expr); - } + let mode = OperatorMode::new(&op); + // Handle the special case for ".*" pattern + if pattern == ANY_CHAR_REGEX_PATTERN { + let new_expr = if mode.not { + // not empty + let empty_lit = Box::new(string_scalar.to_expr("")); + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right: empty_lit, + }) + } else { + // not null + left.is_not_null() + }; + return Ok(Transformed::yes(new_expr)); + } - match regex_syntax::Parser::new().parse(pattern) { - Ok(hir) => { - let kind = hir.kind(); - if let HirKind::Alternation(alts) = kind { - if alts.len() <= MAX_REGEX_ALTERNATIONS_EXPANSION - && let Some(expr) = lower_alt(&mode, &left, alts) - { - return Ok(expr); - } - } else if let Some(expr) = lower_simple(&mode, &left, &hir) { - return Ok(expr); + match regex_syntax::Parser::new().parse(pattern) { + Ok(hir) => { + let kind = hir.kind(); + if let HirKind::Alternation(alts) = kind { + if alts.len() <= MAX_REGEX_ALTERNATIONS_EXPANSION + && let Some(expr) = lower_alt(&mode, &left, alts, &string_scalar) + { + return Ok(Transformed::yes(expr)); } - } - Err(e) => { - // error out early since the execution may fail anyways - return Err(DataFusionError::Context( - "Invalid regex".to_owned(), - Box::new(DataFusionError::External(Box::new(e))), - )); + } else if let Some(expr) = lower_simple(&mode, &left, &hir, &string_scalar) { + return Ok(Transformed::yes(expr)); } } + Err(e) => { + // error out early since the execution may fail anyways + return Err(DataFusionError::Context( + "Invalid regex".to_owned(), + Box::new(DataFusionError::External(Box::new(e))), + )); + } } // Leave untouched if optimization didn't work - Ok(Expr::BinaryExpr(BinaryExpr { left, op, right })) + Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))) } #[derive(Debug)] @@ -117,11 +138,11 @@ impl OperatorMode { } /// Creates an [`LIKE`](Expr::Like) from the given `LIKE` pattern. - fn expr(&self, expr: Box, pattern: String) -> Expr { + fn expr(&self, expr: Box, pattern: Box) -> Expr { let like = Like { negated: self.not, expr, - pattern: Box::new(Expr::Literal(ScalarValue::from(pattern), None)), + pattern, escape_char: None, case_insensitive: self.i, }; @@ -311,14 +332,24 @@ fn anchored_alternation_to_exprs(v: &[Hir]) -> Option> { } /// Tries to lower (transform) a simple regex pattern to a LIKE expression. -fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { +fn lower_simple( + mode: &OperatorMode, + left: &Expr, + hir: &Hir, + string_scalar: &StringScalar, +) -> Option { match hir.kind() { HirKind::Empty => { - return Some(mode.expr(Box::new(left.clone()), "%".to_owned())); + return Some( + mode.expr(Box::new(left.clone()), Box::new(string_scalar.to_expr("%"))), + ); } HirKind::Literal(l) => { let s = like_str_from_literal(l)?; - return Some(mode.expr(Box::new(left.clone()), format!("%{s}%"))); + return Some(mode.expr( + Box::new(left.clone()), + Box::new(string_scalar.to_expr(&format!("%{s}%"))), + )); } HirKind::Concat(inner) if is_anchored_literal(inner) => { return anchored_literal_to_expr(inner).map(|right| { @@ -333,7 +364,10 @@ fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { if let Some(pattern) = partial_anchored_literal_to_like(inner) .or_else(|| collect_concat_to_like_string(inner)) { - return Some(mode.expr(Box::new(left.clone()), pattern)); + return Some(mode.expr( + Box::new(left.clone()), + Box::new(string_scalar.to_expr(&pattern)), + )); } } _ => {} @@ -344,11 +378,16 @@ fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { /// Calls [`lower_simple`] for each alternative and combine the results with `or` or `and` /// based on [`OperatorMode`]. Any fail attempt to lower an alternative will makes this /// function to return `None`. -fn lower_alt(mode: &OperatorMode, left: &Expr, alts: &[Hir]) -> Option { +fn lower_alt( + mode: &OperatorMode, + left: &Expr, + alts: &[Hir], + string_scalar: &StringScalar, +) -> Option { let mut accu: Option = None; for part in alts { - if let Some(expr) = lower_simple(mode, left, part) { + if let Some(expr) = lower_simple(mode, left, part, string_scalar) { accu = match accu { Some(accu) => { if mode.not { diff --git a/datafusion/sqllogictest/test_files/simplify_expr.slt b/datafusion/sqllogictest/test_files/simplify_expr.slt index 99fc9900ef619..f8c219e052f80 100644 --- a/datafusion/sqllogictest/test_files/simplify_expr.slt +++ b/datafusion/sqllogictest/test_files/simplify_expr.slt @@ -34,20 +34,20 @@ query TT explain select b from t where b ~ '.*' ---- logical_plan -01)Filter: t.b ~ Utf8View(".*") +01)Filter: t.b IS NOT NULL 02)--TableScan: t projection=[b] physical_plan -01)FilterExec: b@0 ~ .* +01)FilterExec: b@0 IS NOT NULL 02)--DataSourceExec: partitions=1, partition_sizes=[1] query TT explain select b from t where b !~ '.*' ---- logical_plan -01)Filter: t.b !~ Utf8View(".*") +01)Filter: t.b = Utf8View("") 02)--TableScan: t projection=[b] physical_plan -01)FilterExec: b@0 !~ .* +01)FilterExec: b@0 = 02)--DataSourceExec: partitions=1, partition_sizes=[1] query T diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 13b0aba653efb..4dcc2f663a830 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -1100,7 +1100,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: test.column1_utf8view ~ Utf8View("an") AS c1 +01)Projection: test.column1_utf8view LIKE Utf8View("%an%") AS c1 02)--TableScan: test projection=[column1_utf8view] # `~*` operator (regex match case-insensitive)