Skip to content

Commit 565367b

Browse files
committed
[FLINK-39340][table] Add BinaryMultiJoinToJoinRule for transform multi join back to regular join
1 parent 73d71d9 commit 565367b

6 files changed

Lines changed: 227 additions & 0 deletions

File tree

flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,16 @@ public class OptimizerConfigOptions {
368368
+ "These might break savepoint compatibility across Flink versions and the goal is to have a stable version in the next release.")
369369
.build());
370370

371+
@Documentation.TableOption(execMode = Documentation.ExecMode.STREAMING)
372+
public static final ConfigOption<Boolean> TABLE_OPTIMIZER_USE_MULTI_JOIN_FOR_BINARY_JOIN =
373+
key("table.optimizer.multi-join.use-for-binary-join")
374+
.booleanType()
375+
.defaultValue(false)
376+
.withDescription(
377+
Description.builder()
378+
.text("Allows binary multi join (multi join with 2 inputs).")
379+
.build());
380+
371381
@Documentation.TableOption(execMode = Documentation.ExecMode.STREAMING)
372382
public static final ConfigOption<Boolean> TABLE_OPTIMIZER_INCREMENTAL_AGG_ENABLED =
373383
key("table.optimizer.incremental-agg-enabled")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.planner.plan.rules.logical;
20+
21+
import org.apache.flink.table.api.TableConfig;
22+
import org.apache.flink.table.api.config.OptimizerConfigOptions;
23+
import org.apache.flink.table.planner.utils.ShortcutUtils;
24+
import org.apache.flink.util.Preconditions;
25+
26+
import org.apache.calcite.plan.RelOptRuleCall;
27+
import org.apache.calcite.plan.RelRule;
28+
import org.apache.calcite.rel.RelNode;
29+
import org.apache.calcite.rel.core.Join;
30+
import org.apache.calcite.rel.logical.LogicalJoin;
31+
import org.apache.calcite.rel.rules.MultiJoin;
32+
import org.apache.calcite.rel.rules.TransformationRule;
33+
import org.apache.calcite.rex.RexNode;
34+
import org.apache.calcite.tools.RelBuilderFactory;
35+
import org.immutables.value.Value;
36+
37+
/** Rule for transform {@link MultiJoin} with 2 inputs back to {@link Join}. */
38+
@Value.Enclosing
39+
public class BinaryMultiJoinToJoinRule extends RelRule<BinaryMultiJoinToJoinRule.Config>
40+
implements TransformationRule {
41+
42+
public static final BinaryMultiJoinToJoinRule INSTANCE =
43+
BinaryMultiJoinToJoinRule.Config.DEFAULT.toRule();
44+
45+
/** Creates a JoinToMultiJoinRule. */
46+
public BinaryMultiJoinToJoinRule(BinaryMultiJoinToJoinRule.Config config) {
47+
super(config);
48+
}
49+
50+
@Deprecated // to be removed before 2.0
51+
public BinaryMultiJoinToJoinRule(Class<? extends MultiJoin> clazz) {
52+
this(BinaryMultiJoinToJoinRule.Config.DEFAULT.withOperandFor(clazz));
53+
}
54+
55+
@Deprecated // to be removed before 2.0
56+
public BinaryMultiJoinToJoinRule(
57+
Class<? extends MultiJoin> joinClass, RelBuilderFactory relBuilderFactory) {
58+
this(
59+
BinaryMultiJoinToJoinRule.Config.DEFAULT
60+
.withRelBuilderFactory(relBuilderFactory)
61+
.as(BinaryMultiJoinToJoinRule.Config.class)
62+
.withOperandFor(joinClass));
63+
}
64+
65+
/** This rule matches binary multi joins. */
66+
@Override
67+
public boolean matches(RelOptRuleCall call) {
68+
MultiJoin multiJoin = call.rel(0);
69+
return isEnabledViaConfig(multiJoin) && multiJoin.getInputs().size() < 3;
70+
}
71+
72+
/** This rule transform binary multi joins to regular joins. */
73+
@Override
74+
public void onMatch(RelOptRuleCall call) {
75+
MultiJoin multiJoin = call.rel(0);
76+
Preconditions.checkArgument(
77+
multiJoin.getInputs().size() == 2,
78+
"Only binary multi-join can be transformed into regular join.");
79+
80+
RexNode condition = multiJoin.getOuterJoinConditions().get(1);
81+
Join join =
82+
LogicalJoin.create(
83+
multiJoin.getInputs().get(0),
84+
multiJoin.getInputs().get(1),
85+
multiJoin.getHints(),
86+
Preconditions.checkNotNull(condition),
87+
multiJoin.getVariablesSet(),
88+
multiJoin.getJoinTypes().get(1));
89+
call.transformTo(join);
90+
}
91+
92+
/**
93+
* Checks if multi-join optimization and not use binary multi join option are enabled via
94+
* configuration.
95+
*
96+
* @param multiJoin the multi join node
97+
* @return true if TABLE_OPTIMIZER_MULTI_JOIN_ENABLED is set to true
98+
*/
99+
private boolean isEnabledViaConfig(MultiJoin multiJoin) {
100+
final TableConfig tableConfig = ShortcutUtils.unwrapTableConfig(multiJoin);
101+
return tableConfig.get(OptimizerConfigOptions.TABLE_OPTIMIZER_MULTI_JOIN_ENABLED)
102+
&& tableConfig.get(
103+
OptimizerConfigOptions.TABLE_OPTIMIZER_USE_MULTI_JOIN_FOR_BINARY_JOIN);
104+
}
105+
106+
/** Rule configuration. */
107+
@Value.Immutable(singleton = false)
108+
public interface Config extends RelRule.Config {
109+
BinaryMultiJoinToJoinRule.Config DEFAULT =
110+
ImmutableBinaryMultiJoinToJoinRule.Config.builder()
111+
.build()
112+
.as(BinaryMultiJoinToJoinRule.Config.class)
113+
.withOperandFor(MultiJoin.class);
114+
115+
@Override
116+
default BinaryMultiJoinToJoinRule toRule() {
117+
return new BinaryMultiJoinToJoinRule(this);
118+
}
119+
120+
/** Defines an operand tree for the given classes. */
121+
default BinaryMultiJoinToJoinRule.Config withOperandFor(
122+
Class<? extends MultiJoin> joinClass) {
123+
return withOperandSupplier(
124+
b0 ->
125+
b0.operand(joinClass)
126+
.inputs(
127+
b1 -> b1.operand(RelNode.class).anyInputs(),
128+
b2 -> b2.operand(RelNode.class).anyInputs()))
129+
.as(BinaryMultiJoinToJoinRule.Config.class);
130+
}
131+
}
132+
}

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkStreamProgram.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ object FlinkStreamProgram {
3636
val PREDICATE_PUSHDOWN = "predicate_pushdown"
3737
val JOIN_REORDER = "join_reorder"
3838
val MULTI_JOIN = "multi_join"
39+
val BINARY_MULTI_JOIN = "binary_multi_join"
3940
val PROJECT_REWRITE = "project_rewrite"
4041
val LOGICAL = "logical"
4142
val LOGICAL_REWRITE = "logical_rewrite"
@@ -248,6 +249,21 @@ object FlinkStreamProgram {
248249
.build()
249250
)
250251

252+
chainedProgram.addLast(
253+
BINARY_MULTI_JOIN,
254+
FlinkGroupProgramBuilder
255+
.newBuilder[StreamOptimizeContext]
256+
.addProgram(
257+
FlinkHepRuleSetProgramBuilder.newBuilder
258+
.setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE)
259+
.setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
260+
.add(FlinkStreamRuleSets.BINARY_MULTI_JOIN_RULES)
261+
.build(),
262+
"transform binary multi joins back into regular join"
263+
)
264+
.build()
265+
)
266+
251267
// project rewrite
252268
chainedProgram.addLast(
253269
PROJECT_REWRITE,

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,11 @@ object FlinkStreamRuleSets {
248248
JoinToMultiJoinRule.INSTANCE
249249
)
250250

251+
val BINARY_MULTI_JOIN_RULES: RuleSet = RuleSets.ofList(
252+
// transform binary MultiJoin back into regular join
253+
BinaryMultiJoinToJoinRule.INSTANCE
254+
)
255+
251256
/** RuleSet to do logical optimize. This RuleSet is a sub-set of [[LOGICAL_OPT_RULES]]. */
252257
private val LOGICAL_RULES: RuleSet = RuleSets.ofList(
253258
// scan optimization

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MultiJoinTest.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,28 @@ void testThreeWayInnerJoinRelPlanNoCommonJoinKey() {
209209
+ " ON u.cash = p.price");
210210
}
211211

212+
@Test
213+
@Tag("no-common-join-key")
214+
void testThreeWayInnerJoinRelPlanNoCommonJoinKeyAllowedBinaryMultiJoin() {
215+
util.getTableEnv()
216+
.getConfig()
217+
.set(OptimizerConfigOptions.TABLE_OPTIMIZER_USE_MULTI_JOIN_FOR_BINARY_JOIN, true);
218+
util.verifyRelPlan(
219+
"\nSELECT\n"
220+
+ " u.user_id,\n"
221+
+ " u.name,\n"
222+
+ " o.order_id,\n"
223+
+ " p.payment_id\n"
224+
+ "FROM Users u\n"
225+
+ "INNER JOIN Orders o\n"
226+
+ " ON u.user_id = o.user_id\n"
227+
+ "INNER JOIN Payments p\n"
228+
+ " ON u.cash = p.price");
229+
util.getTableEnv()
230+
.getConfig()
231+
.set(OptimizerConfigOptions.TABLE_OPTIMIZER_USE_MULTI_JOIN_FOR_BINARY_JOIN, false);
232+
}
233+
212234
@Test
213235
void testThreeWayInnerJoinExecPlan() {
214236
util.verifyExecPlan(

flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MultiJoinTest.xml

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2097,6 +2097,48 @@ Calc(select=[user_id, name, order_id, payment_id])
20972097
: +- TableSourceScan(table=[[default_catalog, default_database, Orders, project=[order_id, user_id], metadata=[]]], fields=[order_id, user_id])
20982098
+- Exchange(distribution=[hash[price]])
20992099
+- TableSourceScan(table=[[default_catalog, default_database, Payments, project=[payment_id, price], metadata=[]]], fields=[payment_id, price])
2100+
]]>
2101+
</Resource>
2102+
</TestCase>
2103+
<TestCase name="testThreeWayInnerJoinRelPlanNoCommonJoinKeyAllowedBinaryMultiJoin">
2104+
<Resource name="sql">
2105+
<![CDATA[
2106+
SELECT
2107+
u.user_id,
2108+
u.name,
2109+
o.order_id,
2110+
p.payment_id
2111+
FROM Users u
2112+
INNER JOIN Orders o
2113+
ON u.user_id = o.user_id
2114+
INNER JOIN Payments p
2115+
ON u.cash = p.price]]>
2116+
</Resource>
2117+
<Resource name="ast">
2118+
<![CDATA[
2119+
LogicalProject(user_id=[$0], name=[$1], order_id=[$3], payment_id=[$6])
2120+
+- LogicalJoin(condition=[=($2, $7)], joinType=[inner])
2121+
:- LogicalJoin(condition=[=($0, $4)], joinType=[inner])
2122+
: :- LogicalTableScan(table=[[default_catalog, default_database, Users]])
2123+
: +- LogicalTableScan(table=[[default_catalog, default_database, Orders]])
2124+
+- LogicalTableScan(table=[[default_catalog, default_database, Payments]])
2125+
]]>
2126+
</Resource>
2127+
<Resource name="optimized rel plan">
2128+
<![CDATA[
2129+
Calc(select=[user_id, name, order_id, payment_id])
2130+
+- Join(joinType=[InnerJoin], where=[=(cash, price)], select=[user_id, name, cash, order_id, payment_id, price], leftInputSpec=[NoUniqueKey], rightInputSpec=[HasUniqueKey], stateTtlHints=[[[STATE_TTL options:[0s, 0s]]]])
2131+
:- Exchange(distribution=[hash[cash]])
2132+
: +- Calc(select=[user_id, name, cash, order_id])
2133+
: +- Join(joinType=[InnerJoin], where=[=(user_id, user_id0)], select=[user_id, name, cash, order_id, user_id0], leftInputSpec=[JoinKeyContainsUniqueKey], rightInputSpec=[HasUniqueKey], stateTtlHints=[[[STATE_TTL options:[0s, 0s]]]])
2134+
: :- Exchange(distribution=[hash[user_id]])
2135+
: : +- ChangelogNormalize(key=[user_id])
2136+
: : +- Exchange(distribution=[hash[user_id]])
2137+
: : +- TableSourceScan(table=[[default_catalog, default_database, Users]], fields=[user_id, name, cash])
2138+
: +- Exchange(distribution=[hash[user_id]])
2139+
: +- TableSourceScan(table=[[default_catalog, default_database, Orders, project=[order_id, user_id], metadata=[]]], fields=[order_id, user_id])
2140+
+- Exchange(distribution=[hash[price]])
2141+
+- TableSourceScan(table=[[default_catalog, default_database, Payments, project=[payment_id, price], metadata=[]]], fields=[payment_id, price])
21002142
]]>
21012143
</Resource>
21022144
</TestCase>

0 commit comments

Comments
 (0)