Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit cb75a58

Browse files
authoredDec 27, 2024
[FLINK-36951][table] Migrate ProjectSemiAntiJoinTransposeRule to java
1 parent e26d721 commit cb75a58

File tree

2 files changed

+219
-170
lines changed

2 files changed

+219
-170
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.planner.plan.rules.logical;
20+
21+
import org.apache.calcite.plan.RelOptRuleCall;
22+
import org.apache.calcite.plan.RelOptUtil;
23+
import org.apache.calcite.plan.RelRule;
24+
import org.apache.calcite.rel.RelNode;
25+
import org.apache.calcite.rel.core.Join;
26+
import org.apache.calcite.rel.core.JoinRelType;
27+
import org.apache.calcite.rel.core.Project;
28+
import org.apache.calcite.rel.logical.LogicalJoin;
29+
import org.apache.calcite.rel.logical.LogicalProject;
30+
import org.apache.calcite.rel.type.RelDataType;
31+
import org.apache.calcite.rel.type.RelDataTypeFactory;
32+
import org.apache.calcite.rex.RexBuilder;
33+
import org.apache.calcite.rex.RexInputRef;
34+
import org.apache.calcite.rex.RexNode;
35+
import org.apache.calcite.rex.RexShuttle;
36+
import org.apache.calcite.tools.RelBuilder;
37+
import org.apache.calcite.util.ImmutableBitSet;
38+
import org.apache.calcite.util.mapping.Mappings;
39+
import org.immutables.value.Value;
40+
41+
import java.util.ArrayList;
42+
import java.util.Collections;
43+
import java.util.List;
44+
import java.util.stream.Collectors;
45+
46+
/**
47+
* Planner rule that pushes a {@link Project} down in a tree past a semi/anti {@link Join} by
48+
* splitting the projection into a projection on top of left child of the Join.
49+
*/
50+
@Value.Enclosing
51+
public class ProjectSemiAntiJoinTransposeRule
52+
extends RelRule<ProjectSemiAntiJoinTransposeRule.ProjectSemiAntiJoinTransposeRuleConfig> {
53+
54+
public static final ProjectSemiAntiJoinTransposeRule INSTANCE =
55+
ProjectSemiAntiJoinTransposeRule.ProjectSemiAntiJoinTransposeRuleConfig.DEFAULT
56+
.toRule();
57+
58+
private ProjectSemiAntiJoinTransposeRule(ProjectSemiAntiJoinTransposeRuleConfig config) {
59+
super(config);
60+
}
61+
62+
@Override
63+
public boolean matches(RelOptRuleCall call) {
64+
LogicalJoin join = call.rel(1);
65+
JoinRelType joinType = join.getJoinType();
66+
return joinType == JoinRelType.SEMI || joinType == JoinRelType.ANTI;
67+
}
68+
69+
@Override
70+
public void onMatch(RelOptRuleCall call) {
71+
LogicalProject project = call.rel(0);
72+
LogicalJoin join = call.rel(1);
73+
74+
// 1. calculate every inputs reference fields
75+
ImmutableBitSet joinCondFields = RelOptUtil.InputFinder.bits(join.getCondition());
76+
ImmutableBitSet projectFields = RelOptUtil.InputFinder.bits(project.getProjects(), null);
77+
ImmutableBitSet allNeededFields =
78+
projectFields.isEmpty()
79+
? joinCondFields.union(ImmutableBitSet.of(0))
80+
: joinCondFields.union(projectFields);
81+
82+
int leftFieldCount = join.getLeft().getRowType().getFieldCount();
83+
int allInputFieldCount = leftFieldCount + join.getRight().getRowType().getFieldCount();
84+
if (allNeededFields.equals(ImmutableBitSet.range(0, allInputFieldCount))) {
85+
return;
86+
}
87+
88+
ImmutableBitSet leftNeededFields =
89+
ImmutableBitSet.range(0, leftFieldCount).intersect(allNeededFields);
90+
ImmutableBitSet rightNeededFields =
91+
ImmutableBitSet.range(leftFieldCount, allInputFieldCount)
92+
.intersect(allNeededFields);
93+
94+
// 2. new join inputs
95+
RelNode newLeftInput =
96+
createNewJoinInput(call.builder(), join.getLeft(), leftNeededFields, 0);
97+
RelNode newRightInput =
98+
createNewJoinInput(
99+
call.builder(), join.getRight(), rightNeededFields, leftFieldCount);
100+
101+
// mapping origin field index to new field index,
102+
// used to rewrite join condition and top project
103+
Mappings.TargetMapping mapping =
104+
Mappings.target(
105+
i -> allNeededFields.indexOf(i),
106+
allInputFieldCount,
107+
allNeededFields.cardinality());
108+
109+
// 3. create new join
110+
RelNode newJoin = createNewJoin(join, mapping, newLeftInput, newRightInput);
111+
112+
// 4. create top project
113+
List<RexNode> newProjects = createNewProjects(project, newJoin, mapping);
114+
RelNode topProject =
115+
call.builder()
116+
.push(newJoin)
117+
.project(newProjects, project.getRowType().getFieldNames())
118+
.build();
119+
120+
call.transformTo(topProject);
121+
}
122+
123+
private RelNode createNewJoinInput(
124+
RelBuilder relBuilder,
125+
RelNode originInput,
126+
ImmutableBitSet inputNeededFields,
127+
int offset) {
128+
RexBuilder rexBuilder = originInput.getCluster().getRexBuilder();
129+
RelDataTypeFactory.Builder typeBuilder = relBuilder.getTypeFactory().builder();
130+
List<RexNode> newProjects = new ArrayList<>();
131+
List<String> newFieldNames = new ArrayList<>();
132+
for (int i : inputNeededFields.toList()) {
133+
newProjects.add(rexBuilder.makeInputRef(originInput, i - offset));
134+
newFieldNames.add(originInput.getRowType().getFieldNames().get(i - offset));
135+
typeBuilder.add(originInput.getRowType().getFieldList().get(i - offset));
136+
}
137+
return relBuilder.push(originInput).project(newProjects, newFieldNames).build();
138+
}
139+
140+
private Join createNewJoin(
141+
Join originJoin,
142+
Mappings.TargetMapping mapping,
143+
RelNode newLeftInput,
144+
RelNode newRightInput) {
145+
RexNode newCondition = rewriteJoinCondition(originJoin, mapping);
146+
return LogicalJoin.create(
147+
newLeftInput,
148+
newRightInput,
149+
Collections.emptyList(),
150+
newCondition,
151+
originJoin.getVariablesSet(),
152+
originJoin.getJoinType());
153+
}
154+
155+
private RexNode rewriteJoinCondition(Join originJoin, Mappings.TargetMapping mapping) {
156+
RexBuilder rexBuilder = originJoin.getCluster().getRexBuilder();
157+
RexShuttle rexShuttle =
158+
new RexShuttle() {
159+
@Override
160+
public RexNode visitInputRef(RexInputRef ref) {
161+
int leftFieldCount = originJoin.getLeft().getRowType().getFieldCount();
162+
RelDataType fieldType =
163+
ref.getIndex() < leftFieldCount
164+
? originJoin
165+
.getLeft()
166+
.getRowType()
167+
.getFieldList()
168+
.get(ref.getIndex())
169+
.getType()
170+
: originJoin
171+
.getRight()
172+
.getRowType()
173+
.getFieldList()
174+
.get(ref.getIndex() - leftFieldCount)
175+
.getType();
176+
return rexBuilder.makeInputRef(
177+
fieldType, mapping.getTarget(ref.getIndex()));
178+
}
179+
};
180+
return originJoin.getCondition().accept(rexShuttle);
181+
}
182+
183+
private List<RexNode> createNewProjects(
184+
Project originProject, RelNode newInput, Mappings.TargetMapping mapping) {
185+
RexBuilder rexBuilder = originProject.getCluster().getRexBuilder();
186+
RexShuttle projectShuffle =
187+
new RexShuttle() {
188+
@Override
189+
public RexNode visitInputRef(RexInputRef ref) {
190+
return rexBuilder.makeInputRef(newInput, mapping.getTarget(ref.getIndex()));
191+
}
192+
};
193+
return originProject.getProjects().stream()
194+
.map(p -> p.accept(projectShuffle))
195+
.collect(Collectors.toList());
196+
}
197+
198+
/** Rule configuration. */
199+
@Value.Immutable(singleton = false)
200+
public interface ProjectSemiAntiJoinTransposeRuleConfig extends RelRule.Config {
201+
ProjectSemiAntiJoinTransposeRule.ProjectSemiAntiJoinTransposeRuleConfig DEFAULT =
202+
ImmutableProjectSemiAntiJoinTransposeRule.ProjectSemiAntiJoinTransposeRuleConfig
203+
.builder()
204+
.build()
205+
.withOperandSupplier(
206+
b0 ->
207+
b0.operand(LogicalProject.class)
208+
.inputs(
209+
b1 ->
210+
b1.operand(LogicalJoin.class)
211+
.anyInputs()))
212+
.withDescription("ProjectSemiAntiJoinTransposeRule");
213+
214+
@Override
215+
default ProjectSemiAntiJoinTransposeRule toRule() {
216+
return new ProjectSemiAntiJoinTransposeRule(this);
217+
}
218+
}
219+
}

‎flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/ProjectSemiAntiJoinTransposeRule.scala

-170
This file was deleted.

0 commit comments

Comments
 (0)
Please sign in to comment.