Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -144,27 +144,28 @@ private Set<FuncDepsItem> findValidItems(Set<Slot> requireOutputs) {
* <p>
* Example:
* Given:
* - Initial slots: {{A, B, C}, {D, E}, {F, G}}
* - Required outputs: {A, D, F}
* - Valid functional dependencies: {A} -> {B}, {D, E} -> {G}, {F} -> {G}
* - Initial slots: {{A}, {B}, {C}, {D}, {E}}
* - Required outputs: {}
* - validItems: {A} -> {B}, {B} -> {C}, {C} -> {D}, {D} -> {A}, {A} -> {E}
*
* Process:
* 1. Start with minSlotSet = {{A, B, C}, {D, E}, {F, G}}
* 1. Start with minSlotSet = {{A}, {B}, {C}, {D}, {E}}
* 2. For {A} -> {B}:
* - Both {A} and {B} are in minSlotSet, so mark {B} for elimination
* 3. For {D, E} -> {G}:
* - Both {D, E} and {G} are in minSlotSet, so mark {G} for elimination
* 4. For {F} -> {G}:
* - Both {F} and {G} are in minSlotSet, but {G} is already marked for elimination
* 5. Remove eliminated slots: {B} and {G}
* 3. For {B} -> {C}:
* - Both {B} and {C} are in minSlotSet, so mark {C} for elimination
* 4. For {C} -> {D}:
* - Both {C} and {D} are in minSlotSet, so mark {D} for elimination
* 5. For {A} -> {E}:
* - Both {A} and {E} are in minSlotSet, so mark {E} for elimination
*
* Result: {{A, C}, {D, E}, {F}}
* Result: {{A}}
* </p>
*
* @param slots the initial set of slot sets to be reduced
* @param requireOutputs the set of slots that must be preserved in the output
* @return the minimal set of slot sets after applying all possible reductions
*/
*/
public Set<Set<Slot>> eliminateDeps(Set<Set<Slot>> slots, Set<Slot> requireOutputs) {
Set<Set<Slot>> minSlotSet = Sets.newHashSet(slots);
Set<Set<Slot>> eliminatedSlots = new HashSet<>();
Expand Down Expand Up @@ -201,16 +202,30 @@ public Map<Set<Slot>, Set<Set<Slot>>> getREdges() {
}

/**
* find the determinants of dependencies
* Finds all slot sets that have a bijective relationship with the given slot set.
* Given edges containing:
* {A} -> {{B}, {C}}
* {B} -> {{A}, {D}}
* {C} -> {{D}}
* When slot = {A}, returns {{B}} because {A} and {B} mutually determine each other.
* {C} is not returned because {C} does not determine {A} (one-way dependency only).
*/
public Set<Set<Slot>> findDeterminats(Set<Slot> dependency) {
Set<Set<Slot>> determinants = new HashSet<>();
for (FuncDepsItem item : items) {
if (item.dependencies.equals(dependency)) {
determinants.add(item.determinants);
public Set<Set<Slot>> findBijectionSlots(Set<Slot> slot) {
Set<Set<Slot>> bijectionSlots = new HashSet<>();
if (!edges.containsKey(slot)) {
return bijectionSlots;
}
for (Set<Slot> dep : edges.get(slot)) {
if (!edges.containsKey(dep)) {
continue;
}
for (Set<Slot> det : edges.get(dep)) {
if (det.equals(slot)) {
bijectionSlots.add(dep);
}
}
}
return determinants;
return bijectionSlots;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.properties.DataTrait;
import org.apache.doris.nereids.properties.FuncDeps;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
Expand Down Expand Up @@ -78,17 +79,38 @@ public List<Rule> buildRules() {
}

/**
 * Rebuilds the aggregate without the group-by expressions that
 * {@link #findCanBeRemovedExpressions} reports as removable.
 *
 * @param agg the aggregate whose group-by keys and outputs are filtered
 * @param requireOutput slots forwarded to {@code findCanBeRemovedExpressions}
 * @return the result of {@code agg.withGroupByAndOutput} applied to the remaining
 *         group-by expressions and output expressions, in their original order
 */
LogicalAggregate<Plan> eliminateGroupByKey(LogicalAggregate<? extends Plan> agg, Set<Slot> requireOutput) {
    // The functional dependencies are read from the trait of the aggregate's child.
    Set<Expression> removable = findCanBeRemovedExpressions(
            agg, requireOutput, agg.child().getLogicalProperties().getTrait());

    // Copy, then drop removable entries; the relative order of the survivors is kept.
    List<Expression> keptGroupBy = new ArrayList<>(agg.getGroupByExpressions());
    keptGroupBy.removeIf(removable::contains);

    List<NamedExpression> keptOutput = new ArrayList<>(agg.getOutputExpressions());
    keptOutput.removeIf(removable::contains);

    return agg.withGroupByAndOutput(keptGroupBy, keptOutput);
}

/**
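* Finds the group-by expressions of {@code agg} that can be removed according to the
* functional dependencies in {@code dataTrait}.
* Returns an empty set when no valid functional dependency is available.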
*/
public static Set<Expression> findCanBeRemovedExpressions(LogicalAggregate<? extends Plan> agg,
Set<Slot> requireOutput, DataTrait dataTrait) {
Map<Expression, Set<Slot>> groupBySlots = new HashMap<>();
Set<Slot> validSlots = new HashSet<>();
for (Expression expression : agg.getGroupByExpressions()) {
groupBySlots.put(expression, expression.getInputSlots());
validSlots.addAll(expression.getInputSlots());
}

FuncDeps funcDeps = agg.child().getLogicalProperties()
.getTrait().getAllValidFuncDeps(validSlots);
FuncDeps funcDeps = dataTrait.getAllValidFuncDeps(validSlots);
if (funcDeps.isEmpty()) {
return null;
return new HashSet<>();
}

Set<Set<Slot>> minGroupBySlots = funcDeps.eliminateDeps(new HashSet<>(groupBySlots.values()), requireOutput);
Expand All @@ -99,19 +121,6 @@ LogicalAggregate<Plan> eliminateGroupByKey(LogicalAggregate<? extends Plan> agg,
removeExpression.add(entry.getKey());
}
}

List<Expression> newGroupExpression = new ArrayList<>();
for (Expression expression : agg.getGroupByExpressions()) {
if (!removeExpression.contains(expression)) {
newGroupExpression.add(expression);
}
}
List<NamedExpression> newOutput = new ArrayList<>();
for (NamedExpression expression : agg.getOutputExpressions()) {
if (!removeExpression.contains(expression)) {
newOutput.add(expression);
}
}
return agg.withGroupByAndOutput(newGroupExpression, newOutput);
return removeExpression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.properties.DataTrait;
import org.apache.doris.nereids.properties.FuncDeps;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
Expand Down Expand Up @@ -125,19 +124,38 @@ public List<Rule> buildRules() {
}

// eliminate the slot of primary plan in agg
// e.g.
// select primary_table_pk, primary_table_other from primary_table join foreign_table on pk = fk
// group by pk, primary_table_other_cols;
private LogicalAggregate<?> eliminatePrimaryOutput(LogicalAggregate<?> agg, Plan child,
Plan primary, Plan foreign) {
Set<Slot> aggInputs = agg.getInputSlots();
if (primary.getOutputSet().stream().noneMatch(aggInputs::contains)) {
return agg;
}
// First, use functional dependencies (fd) to eliminate redundant group-by keys.
// group by pk, primary_table_other_cols;
// -> group by pk;
Set<Expression> removeExpression = EliminateGroupByKey.findCanBeRemovedExpressions(agg,
Sets.intersection(agg.getOutputSet(), foreign.getOutputSet()),
child.getLogicalProperties().getTrait());
List<Expression> minGroupBySlotList = new ArrayList<>();
for (Expression expression : agg.getGroupByExpressions()) {
if (!removeExpression.contains(expression)) {
minGroupBySlotList.add(expression);
}
}

// Second, put each pair of bijective slots into a map: {pk : fk}
// Bijective slots are mutually interchangeable within GROUP BY keys.
// group by pk -> group by fk
Set<Slot> primaryOutputSet = primary.getOutputSet();
Set<Slot> primarySlots = Sets.intersection(aggInputs, primaryOutputSet);
DataTrait dataTrait = child.getLogicalProperties().getTrait();
FuncDeps funcDeps = dataTrait.getAllValidFuncDeps(Sets.union(foreign.getOutputSet(), primary.getOutputSet()));
HashMap<Slot, Slot> primaryToForeignDeps = new HashMap<>();
FuncDeps funcDepsForJoin = child.getLogicalProperties().getTrait()
.getAllValidFuncDeps(Sets.union(primaryOutputSet, foreign.getOutputSet()));
for (Slot slot : primarySlots) {
Set<Set<Slot>> replacedSlotSets = funcDeps.findDeterminats(ImmutableSet.of(slot));
Set<Set<Slot>> replacedSlotSets = funcDepsForJoin.findBijectionSlots(ImmutableSet.of(slot));
for (Set<Slot> replacedSlots : replacedSlotSets) {
if (primaryOutputSet.stream().noneMatch(replacedSlots::contains)
&& replacedSlots.size() == 1) {
Expand All @@ -147,19 +165,23 @@ private LogicalAggregate<?> eliminatePrimaryOutput(LogicalAggregate<?> agg, Plan
}
}

Set<Expression> newGroupBySlots = constructNewGroupBy(agg, primaryOutputSet, primaryToForeignDeps);
// Third, construct the new Agg below the join.
// For the pk-fk join, the foreign table side will not expand rows.
// As a result, executing agg(group by fk) before the join is the same as executing agg(group by fk) after it.
Set<Expression> newGroupBySlots = constructNewGroupBy(minGroupBySlotList, primaryOutputSet,
primaryToForeignDeps);
List<NamedExpression> newOutput = constructNewOutput(
agg, primaryOutputSet, primaryToForeignDeps, funcDeps, primary);
agg, primaryOutputSet, primaryToForeignDeps, funcDepsForJoin, primary);
if (newGroupBySlots == null || newOutput == null) {
return null;
}
return agg.withGroupByAndOutput(ImmutableList.copyOf(newGroupBySlots), ImmutableList.copyOf(newOutput));
}

private @Nullable Set<Expression> constructNewGroupBy(LogicalAggregate<?> agg, Set<Slot> primaryOutputs,
Map<Slot, Slot> primaryToForeignBiDeps) {
private @Nullable Set<Expression> constructNewGroupBy(List<? extends Expression> gbyExpression,
Set<Slot> primaryOutputs, Map<Slot, Slot> primaryToForeignBiDeps) {
Set<Expression> newGroupBySlots = new HashSet<>();
for (Expression expression : agg.getGroupByExpressions()) {
for (Expression expression : gbyExpression) {
if (!(expression instanceof Slot)) {
return null;
}
Expand Down Expand Up @@ -196,9 +218,7 @@ private LogicalAggregate<?> eliminatePrimaryOutput(LogicalAggregate<?> agg, Plan
&& expression.child(0).child(0) instanceof Slot) {
// count(slot) can be rewritten by circle deps
Slot slot = (Slot) expression.child(0).child(0);
if (primaryToForeignDeps.containsKey(slot)
&& funcDeps.isCircleDeps(
ImmutableSet.of(slot), ImmutableSet.of(primaryToForeignDeps.get(slot)))) {
if (primaryToForeignDeps.containsKey(slot)) {
expression = (NamedExpression) expression.rewriteUp(e ->
e instanceof Slot
? primaryToForeignDeps.getOrDefault((Slot) e, (Slot) e)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !not_push_down_shape --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[INNER_JOIN] hashCondition=((store_sales_test.ss_customer_sk = customer_test.c_customer_sk)) otherCondition=()
--------PhysicalOlapScan[store_sales_test]
--------PhysicalOlapScan[customer_test]

-- !not_push_down_result --
Smith John 2024-01-01

-- !push_down_shape --
PhysicalResultSink
--hashJoin[INNER_JOIN] hashCondition=((store_sales_test.ss_customer_sk = customer_test.c_customer_sk)) otherCondition=()
----hashAgg[GLOBAL]
------hashAgg[LOCAL]
--------PhysicalOlapScan[store_sales_test]
----PhysicalOlapScan[customer_test]

-- !push_down_result --
John 1 2024-01-01
John 2 2024-01-01

-- !push_down_with_count_shape --
PhysicalResultSink
--hashJoin[INNER_JOIN] hashCondition=((store_sales_test.ss_customer_sk = customer_test.c_customer_sk)) otherCondition=()
----hashAgg[GLOBAL]
------hashAgg[LOCAL]
--------PhysicalOlapScan[store_sales_test]
----PhysicalOlapScan[customer_test]

-- !push_down_with_count_result --
John 1 2024-01-01 1
John 2 2024-01-01 1

Original file line number Diff line number Diff line change
Expand Up @@ -8,49 +8,46 @@ PhysicalResultSink
----------hashAgg[LOCAL]
------------PhysicalProject
--------------PhysicalIntersect RFV2: RF6[c_last_name->c_last_name] RF7[c_last_name->c_last_name]
----------------PhysicalDistribute[DistributionSpecHash]
------------------PhysicalProject
--------------------hashJoin[INNER_JOIN bucketShuffle] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=()
----------------------hashAgg[GLOBAL]
------------------------PhysicalDistribute[DistributionSpecHash]
--------------------------hashAgg[LOCAL]
----------------------------PhysicalProject
------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[ws_sold_date_sk]
--------------------------------PhysicalProject
----------------------------------PhysicalOlapScan[web_sales] apply RFs: RF0
--------------------------------PhysicalProject
----------------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183))
------------------------------------PhysicalOlapScan[date_dim]
----------------hashAgg[GLOBAL]
------------------PhysicalDistribute[DistributionSpecHash]
--------------------hashAgg[LOCAL]
----------------------PhysicalProject
------------------------PhysicalOlapScan[customer]
----------------PhysicalDistribute[DistributionSpecHash]
------------------PhysicalProject
--------------------hashJoin[INNER_JOIN bucketShuffle] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=()
----------------------hashAgg[GLOBAL]
------------------------PhysicalDistribute[DistributionSpecHash]
--------------------------hashAgg[LOCAL]
----------------------------PhysicalProject
------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF2 d_date_sk->[cs_sold_date_sk]
--------------------------------PhysicalProject
----------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF2
--------------------------------PhysicalProject
----------------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183))
------------------------------------PhysicalOlapScan[date_dim]
------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[ws_sold_date_sk]
--------------------------PhysicalProject
----------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=()
------------------------------PhysicalProject
--------------------------------PhysicalOlapScan[web_sales] apply RFs: RF1
------------------------------PhysicalProject
--------------------------------PhysicalOlapScan[customer]
--------------------------PhysicalProject
----------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183))
------------------------------PhysicalOlapScan[date_dim]
----------------hashAgg[GLOBAL]
------------------PhysicalDistribute[DistributionSpecHash]
--------------------hashAgg[LOCAL]
----------------------PhysicalProject
------------------------PhysicalOlapScan[customer] RFV2: RF6
----------------PhysicalDistribute[DistributionSpecHash]
------------------PhysicalProject
--------------------hashJoin[INNER_JOIN bucketShuffle] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=()
----------------------hashAgg[GLOBAL]
------------------------PhysicalDistribute[DistributionSpecHash]
--------------------------hashAgg[LOCAL]
----------------------------PhysicalProject
------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ss_sold_date_sk]
--------------------------------PhysicalProject
----------------------------------PhysicalOlapScan[store_sales] apply RFs: RF4
--------------------------------PhysicalProject
----------------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183))
------------------------------------PhysicalOlapScan[date_dim]
------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[cs_sold_date_sk]
--------------------------PhysicalProject
----------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=()
------------------------------PhysicalProject
--------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF3
------------------------------PhysicalProject
--------------------------------PhysicalOlapScan[customer] RFV2: RF6
--------------------------PhysicalProject
----------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183))
------------------------------PhysicalOlapScan[date_dim]
----------------hashAgg[GLOBAL]
------------------PhysicalDistribute[DistributionSpecHash]
--------------------hashAgg[LOCAL]
----------------------PhysicalProject
------------------------PhysicalOlapScan[customer] RFV2: RF7
------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF5 d_date_sk->[ss_sold_date_sk]
--------------------------PhysicalProject
----------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=()
------------------------------PhysicalProject
--------------------------------PhysicalOlapScan[store_sales] apply RFs: RF5
------------------------------PhysicalProject
--------------------------------PhysicalOlapScan[customer] RFV2: RF7
--------------------------PhysicalProject
----------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183))
------------------------------PhysicalOlapScan[date_dim]

Loading