diff --git a/Cargo.lock b/Cargo.lock index 5092a860e3c13..b992b12a62750 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2367,6 +2367,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-expr-common", + "datafusion-functions", "datafusion-functions-aggregate", "datafusion-functions-window", "datafusion-functions-window-common", diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 76d3f73f68767..5e2026f05ac2c 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -49,6 +49,7 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } +datafusion-functions = { workspace = true } datafusion-physical-expr = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index c6644e008645a..9adb1d9921f37 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1660,17 +1660,19 @@ impl TreeNodeRewriter for Simplifier<'_> { left, op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch), right, - }) => Transformed::yes(simplify_regex_expr(left, op, right)?), + }) => simplify_regex_expr(left, op, right)?, // Rules for Like Expr::Like(like) => { // `\` is implicit escape, see https://github.com/apache/datafusion/issues/13291 let escape_char = like.escape_char.unwrap_or('\\'); - match as_string_scalar(&like.pattern) { - Some((data_type, pattern_str)) => { + + match StringScalar::try_from_expr(&like.pattern) { + Some(string_scalar) => { + let pattern_str = string_scalar.as_str(); match pattern_str { None => return Ok(Transformed::yes(lit_bool_null())), - Some(pattern_str) if pattern_str == "%" => { + Some("%") => { // exp LIKE '%' is // - when exp is not NULL, it's true // - when exp is NULL, it's NULL @@ -1702,10 +1704,9 @@ impl TreeNodeRewriter for Simplifier<'_> { .replace_all(pattern_str, "%") .to_string(); Transformed::yes(Expr::Like(Like { - pattern: Box::new(to_string_scalar( - &data_type, - Some(simplified_pattern), - )), + pattern: Box::new( + string_scalar.to_expr(&simplified_pattern), + ), ..like })) } @@ -2126,21 +2127,54 @@ fn is_literal_or_literal_cast(expr: &Expr) -> bool { } } -fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option)> { - match expr { - Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)), - Expr::Literal(ScalarValue::LargeUtf8(s), _) => Some((DataType::LargeUtf8, s)), - Expr::Literal(ScalarValue::Utf8View(s), _) => Some((DataType::Utf8View, s)), - _ => None, - } +/// Helper for working with string scalar values (Utf8, LargeUtf8, Utf8View) +pub(crate) enum StringScalar<'a> { + Utf8(&'a ScalarValue), + LargeUtf8(&'a ScalarValue), + Utf8View(&'a ScalarValue), } -fn to_string_scalar(data_type: &DataType, value: Option) -> Expr { - match data_type { - DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value), None), - DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value), None), - DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value), None), - _ => unreachable!(), +impl<'a> StringScalar<'a> { + /// Create a `StringScalar` view from an `Expr` if it is a supported string literal. + /// Returns `None` if the expression is not a string literal. + pub(crate) fn try_from_expr(expr: &'a Expr) -> Option { + match expr { + Expr::Literal(scalar, _) => Self::try_from_scalar(scalar), + _ => None, + } + } + + /// Create a `StringScalar` view from a `ScalarValue` if it is a supported string type. + /// Returns `None` if the scalar value is not a supported string type. + fn try_from_scalar(scalar: &'a ScalarValue) -> Option { + match scalar { + ScalarValue::Utf8(_) => Some(Self::Utf8(scalar)), + ScalarValue::LargeUtf8(_) => Some(Self::LargeUtf8(scalar)), + ScalarValue::Utf8View(_) => Some(Self::Utf8View(scalar)), + _ => None, + } + } + + /// Returns the underlying string slice. + pub(crate) fn as_str(&self) -> Option<&'a str> { + match self { + Self::Utf8(scalar) | Self::LargeUtf8(scalar) | Self::Utf8View(scalar) => { + scalar.try_as_str().flatten() + } + } + } + + /// Build a new `Expr` of the same string type with the given value. + pub(crate) fn to_expr(&self, val: &str) -> Expr { + match self { + Self::Utf8(_) => Expr::Literal(ScalarValue::Utf8(Some(val.to_owned())), None), + Self::LargeUtf8(_) => { + Expr::Literal(ScalarValue::LargeUtf8(Some(val.to_owned())), None) + } + Self::Utf8View(_) => { + Expr::Literal(ScalarValue::Utf8View(Some(val.to_owned())), None) + } + } } } @@ -2332,6 +2366,7 @@ mod tests { interval_arithmetic::Interval, *, }; + use datafusion_functions::expr_fn::contains as contains_fn; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_physical_expr::PhysicalExpr; @@ -3347,6 +3382,19 @@ mod tests { col("c1").like(lit("%foo%")), ); + // regular expression that matches a substring + assert_change( + regex_match(col("c1"), lit(".*foo.*")), + contains_fn(col("c1"), lit("foo")), + ); + + assert_change( + regex_not_match(col("c1"), lit(".*foo.*")), + Expr::Not(Box::new(contains_fn(col("c1"), lit("foo")))), + ); + + assert_change(regex_match(col("c1"), lit(".*.*")), col("c1").is_not_null()); + // regular expressions that match an exact literal assert_change(regex_match(col("c1"), lit("^$")), col("c1").eq(lit(""))); assert_change( diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index d388aaf74cdac..8d19f18330392 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -15,10 +15,14 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::tree_node::Transformed; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{BinaryExpr, Expr, Like, Operator, lit}; +use datafusion_functions::expr_fn::contains; use regex_syntax::hir::{Capture, Hir, HirKind, Literal, Look}; +use crate::simplify_expressions::expr_simplifier::StringScalar; + /// Maximum number of regex alternations (`foo|bar|...`) that will be expanded into multiple `LIKE` expressions. const MAX_REGEX_ALTERNATIONS_EXPANSION: usize = 4; @@ -31,64 +35,99 @@ const ANY_CHAR_REGEX_PATTERN: &str = ".*"; /// /// Typical cases this function can simplify: /// - empty regex pattern to `LIKE '%'` +/// - `EQ .*foo.*` to `contains(left, "foo")` +/// - `NE .*foo.*` to `NOT contains(left, "foo")` /// - literal regex patterns to `LIKE '%foo%'` /// - full anchored regex patterns (e.g. `^foo$`) to `= 'foo'` /// - partial anchored regex patterns (e.g. `^foo`) to `LIKE 'foo%'` /// - combinations (alternatives) of the above, will be concatenated with `OR` or `AND` /// - `EQ .*` to NotNull -/// - `NE .*` means IS EMPTY +/// - `NE .*` to false (.* matches any string, and NULL !~ results in NULL so NOT match can never be true) /// /// Dev note: unit tests of this function are in `expr_simplifier.rs`, case `test_simplify_regex`. pub fn simplify_regex_expr( left: Box, op: Operator, right: Box, -) -> Result { +) -> Result> { + // Check if the right operand is a supported string literal + let Some(string_scalar) = StringScalar::try_from_expr(right.as_ref()) else { + return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))); + }; + let pattern = string_scalar.as_str(); + let Some(pattern) = pattern else { + return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))); + }; + let mode = OperatorMode::new(&op); + // Handle the special case for ".*" pattern + if pattern == ANY_CHAR_REGEX_PATTERN { + let new_expr = if mode.not { + // Always false. + lit(false) + } else { + // not null + left.is_not_null() + }; + return Ok(Transformed::yes(new_expr)); + } - if let Expr::Literal(ScalarValue::Utf8(Some(pattern)), _) = right.as_ref() { - // Handle the special case for ".*" pattern - if pattern == ANY_CHAR_REGEX_PATTERN { - let new_expr = if mode.not { - // not empty - let empty_lit = Box::new(lit("")); - Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right: empty_lit, - }) - } else { - // not null - left.is_not_null() - }; - return Ok(new_expr); - } + // Convert patterns of the form ".*foo.*" to `contains(left, "foo")` + if !mode.i + && let Some(inner) = pattern + // If pattern starts and ends with ".*" + .strip_prefix(ANY_CHAR_REGEX_PATTERN) + .and_then(|rest| rest.strip_suffix(ANY_CHAR_REGEX_PATTERN)) + // If inner is all non-special characters + && inner.chars().all(|x| !is_special_character(x)) + { + let new_expr = match (mode.not, inner.is_empty()) { + // contains(left, inner) + (false, false) => contains(*left, lit(inner)), + (false, true) => left.is_not_null(), + // not (contains(left, inner)) + (true, false) => Expr::Not(Box::new(contains(*left, lit(inner)))), + (true, true) => lit(false), // "!~ '.*'" is always false. + }; + return Ok(Transformed::yes(new_expr)); + } - match regex_syntax::Parser::new().parse(pattern) { - Ok(hir) => { - let kind = hir.kind(); - if let HirKind::Alternation(alts) = kind { - if alts.len() <= MAX_REGEX_ALTERNATIONS_EXPANSION - && let Some(expr) = lower_alt(&mode, &left, alts) - { - return Ok(expr); - } - } else if let Some(expr) = lower_simple(&mode, &left, &hir) { - return Ok(expr); + match regex_syntax::Parser::new().parse(pattern) { + Ok(hir) => { + let kind = hir.kind(); + if let HirKind::Alternation(alts) = kind { + if alts.len() <= MAX_REGEX_ALTERNATIONS_EXPANSION + && let Some(expr) = lower_alt(&mode, &left, alts, &string_scalar) + { + return Ok(Transformed::yes(expr)); } + } else if let Some(expr) = lower_simple(&mode, &left, &hir, &string_scalar) { + return Ok(Transformed::yes(expr)); } - Err(e) => { - // error out early since the execution may fail anyways - return Err(DataFusionError::Context( - "Invalid regex".to_owned(), - Box::new(DataFusionError::External(Box::new(e))), - )); - } + } + Err(e) => { + // error out early since the execution may fail anyways + return Err(DataFusionError::Context( + "Invalid regex".to_owned(), + Box::new(DataFusionError::External(Box::new(e))), + )); } } // Leave untouched if optimization didn't work - Ok(Expr::BinaryExpr(BinaryExpr { left, op, right })) + Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))) } #[derive(Debug)] @@ -117,11 +156,11 @@ impl OperatorMode { } /// Creates an [`LIKE`](Expr::Like) from the given `LIKE` pattern. - fn expr(&self, expr: Box, pattern: String) -> Expr { + fn expr(&self, expr: Box, pattern: Box) -> Expr { let like = Like { negated: self.not, expr, - pattern: Box::new(Expr::Literal(ScalarValue::from(pattern), None)), + pattern, escape_char: None, case_insensitive: self.i, }; @@ -181,6 +220,25 @@ fn is_safe_for_like(c: char) -> bool { (c != '%') && (c != '_') } +fn is_special_character(c: char) -> bool { + matches!( + c, + '.' | '*' + | '+' + | '?' + | '|' + | '(' + | ')' + | '[' + | ']' + | '{' + | '}' + | '^' + | '$' + | '\\' + ) +} + /// Returns true if the elements in a `Concat` pattern are: /// - `[Look::Start, Look::End]` /// - `[Look::Start, Literal(_), Look::End]` @@ -311,14 +369,24 @@ fn anchored_alternation_to_exprs(v: &[Hir]) -> Option> { } /// Tries to lower (transform) a simple regex pattern to a LIKE expression. -fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { +fn lower_simple( + mode: &OperatorMode, + left: &Expr, + hir: &Hir, + string_scalar: &StringScalar, +) -> Option { match hir.kind() { HirKind::Empty => { - return Some(mode.expr(Box::new(left.clone()), "%".to_owned())); + return Some( + mode.expr(Box::new(left.clone()), Box::new(string_scalar.to_expr("%"))), + ); } HirKind::Literal(l) => { let s = like_str_from_literal(l)?; - return Some(mode.expr(Box::new(left.clone()), format!("%{s}%"))); + return Some(mode.expr( + Box::new(left.clone()), + Box::new(string_scalar.to_expr(&format!("%{s}%"))), + )); } HirKind::Concat(inner) if is_anchored_literal(inner) => { return anchored_literal_to_expr(inner).map(|right| { @@ -333,7 +401,10 @@ fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { if let Some(pattern) = partial_anchored_literal_to_like(inner) .or_else(|| collect_concat_to_like_string(inner)) { - return Some(mode.expr(Box::new(left.clone()), pattern)); + return Some(mode.expr( + Box::new(left.clone()), + Box::new(string_scalar.to_expr(&pattern)), + )); } } _ => {} @@ -344,11 +415,16 @@ fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { /// Calls [`lower_simple`] for each alternative and combine the results with `or` or `and` /// based on [`OperatorMode`]. Any fail attempt to lower an alternative will makes this /// function to return `None`. -fn lower_alt(mode: &OperatorMode, left: &Expr, alts: &[Hir]) -> Option { +fn lower_alt( + mode: &OperatorMode, + left: &Expr, + alts: &[Hir], + string_scalar: &StringScalar, +) -> Option { let mut accu: Option = None; for part in alts { - if let Some(expr) = lower_simple(mode, left, part) { + if let Some(expr) = lower_simple(mode, left, part, string_scalar) { accu = match accu { Some(accu) => { if mode.not { diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index f7f100015004a..fb970ad0af996 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -883,17 +883,17 @@ mod tests { " )?; - // Test `!= ".*"` transforms to checking if the column is empty + // Test `!= ".*"` transforms to false (.* matches any string, so NOT match is always false) let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("a"), Operator::RegexNotMatch, lit(".*")))? .build()?; assert_optimized_plan_equal!( plan, - @ r#" - Filter: test.a = Utf8("") + @ r" + Filter: Boolean(false) TableScan: test - "# + " )?; // Test case-insensitive versions @@ -911,17 +911,17 @@ mod tests { " )?; - // Test `!~ ".*"` (case-insensitive) transforms to checking if the column is empty + // Test `!~ ".*"` (case-insensitive) transforms to false (.* matches any string, so NOT match is always false) let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("a"), Operator::RegexNotIMatch, lit(".*")))? .build()?; assert_optimized_plan_equal!( plan, - @ r#" - Filter: test.a = Utf8("") + @ r" + Filter: Boolean(false) TableScan: test - "# + " ) } diff --git a/datafusion/sqllogictest/test_files/simplify_expr.slt b/datafusion/sqllogictest/test_files/simplify_expr.slt index 99fc9900ef619..4f7967bbfa8d6 100644 --- a/datafusion/sqllogictest/test_files/simplify_expr.slt +++ b/datafusion/sqllogictest/test_files/simplify_expr.slt @@ -34,30 +34,77 @@ query TT explain select b from t where b ~ '.*' ---- logical_plan -01)Filter: t.b ~ Utf8View(".*") +01)Filter: t.b IS NOT NULL 02)--TableScan: t projection=[b] physical_plan -01)FilterExec: b@0 ~ .* +01)FilterExec: b@0 IS NOT NULL 02)--DataSourceExec: partitions=1, partition_sizes=[1] query TT explain select b from t where b !~ '.*' ---- +logical_plan EmptyRelation: rows=0 +physical_plan EmptyExec + +query T +select b from t where b ~ '.*' +---- +a +c + +query T +select b from t where b !~ '.*' +---- + +# test regex .*literal.* simplifies to contains() +query TT +explain select b from t where b ~ '.*a.*' +---- logical_plan -01)Filter: t.b !~ Utf8View(".*") +01)Filter: contains(t.b, Utf8("a")) 02)--TableScan: t projection=[b] physical_plan -01)FilterExec: b@0 !~ .* +01)FilterExec: contains(b@0, a) 02)--DataSourceExec: partitions=1, partition_sizes=[1] query T -select b from t where b ~ '.*' +select b from t where b ~ '.*a.*' +---- +a + +query TT +explain select b from t where b !~ '.*a.*' +---- +logical_plan +01)Filter: NOT contains(t.b, Utf8("a")) +02)--TableScan: t projection=[b] +physical_plan +01)FilterExec: NOT contains(b@0, a) +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query T +select b from t where b !~ '.*a.*' +---- +c + +query TT +explain select b from t where b ~ '.*.*' +---- +logical_plan +01)Filter: t.b IS NOT NULL +02)--TableScan: t projection=[b] +physical_plan +01)FilterExec: b@0 IS NOT NULL +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query T +select b from t where b ~ '.*.*' ---- a c query T -select b from t where b !~ '.*' +select b from t where b !~ '.*.*' ---- query TT diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 13b0aba653efb..4dcc2f663a830 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -1100,7 +1100,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: test.column1_utf8view ~ Utf8View("an") AS c1 +01)Projection: test.column1_utf8view LIKE Utf8View("%an%") AS c1 02)--TableScan: test projection=[column1_utf8view] # `~*` operator (regex match case-insensitive)