Commit 95f9a16 (parent: 0ce55c3)

[FLINK-36067][runtime] Support optimize stream graph based on input info.

23 files changed: +724 -101 lines

flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java (+25)

@@ -63,6 +63,7 @@ public class IntermediateResult {
     private final int numParallelProducers;
 
     private final ExecutionPlanSchedulingContext executionPlanSchedulingContext;
+    private final boolean singleSubpartitionContainsAllData;
 
     private int partitionsAssigned;
 
@@ -102,6 +103,8 @@ public IntermediateResult(
         this.shuffleDescriptorCache = new HashMap<>();
 
         this.executionPlanSchedulingContext = checkNotNull(executionPlanSchedulingContext);
+
+        this.singleSubpartitionContainsAllData = intermediateDataSet.isBroadcast();
     }
 
     public boolean areAllConsumerVerticesCreated() {
@@ -199,6 +202,16 @@ public DistributionPattern getConsumingDistributionPattern() {
         return intermediateDataSet.getDistributionPattern();
     }
 
+    /**
+     * Determines whether the associated intermediate data set uses a broadcast distribution
+     * pattern.
+     *
+     * <p>A broadcast distribution pattern indicates that all data produced by this intermediate
+     * data set should be broadcast to every downstream consumer.
+     *
+     * @return true if the intermediate data set uses a broadcast distribution pattern; false
+     *     otherwise.
+     */
     public boolean isBroadcast() {
         return intermediateDataSet.isBroadcast();
     }
@@ -207,6 +220,18 @@ public boolean isForward() {
         return intermediateDataSet.isForward();
     }
 
+    /**
+     * Checks if a single subpartition contains all the produced data. This condition indicates
+     * that the data was intended to be broadcast to all consumers. If the decision to broadcast
+     * was made before data production, this flag is set accordingly; conversely, if the
+     * broadcasting decision was made after production, this flag will be false.
+     *
+     * @return true if a single subpartition contains all the data; false otherwise.
+     */
+    public boolean isSingleSubpartitionContainsAllData() {
+        return singleSubpartitionContainsAllData;
+    }
+
     public int getConnectionIndex() {
         return connectionIndex;
     }
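
The key subtlety in this file is that the new flag is latched at construction time from the data set's broadcast property, so a later change to the consumption pattern does not retroactively claim the producer wrote everything into one subpartition. A minimal standalone sketch of that behavior (the ToyResult class and its methods are illustrative, not part of this commit):

// Minimal sketch (hypothetical ToyResult class, not Flink code) of why the flag
// is captured at construction time rather than derived from isBroadcast().
public class ToyResult {
    private boolean broadcast;                               // may change after optimization
    private final boolean singleSubpartitionContainsAllData; // latched at creation

    public ToyResult(boolean broadcastAtCreation) {
        this.broadcast = broadcastAtCreation;
        // Mirrors `this.singleSubpartitionContainsAllData = intermediateDataSet.isBroadcast();`
        this.singleSubpartitionContainsAllData = broadcastAtCreation;
    }

    // Models a post-production switch to broadcast consumption (broadcast join optimization).
    public void switchToBroadcastConsumption() {
        this.broadcast = true;
    }

    public static void main(String[] args) {
        ToyResult hashThenBroadcast = new ToyResult(false);
        hashThenBroadcast.switchToBroadcastConsumption();
        // The result is now read as broadcast, but its data is still spread
        // across many subpartitions, so the latched flag stays false.
        System.out.println(hashThenBroadcast.broadcast);                          // true
        System.out.println(hashThenBroadcast.singleSubpartitionContainsAllData);  // false
    }
}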

flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java (+14 -2)

@@ -29,9 +29,21 @@ public interface IntermediateResultInfo {
     IntermediateDataSetID getResultId();
 
     /**
-     * Whether it is a broadcast result.
+     * Checks whether there is a single subpartition that contains all the produced data.
      *
-     * @return whether it is a broadcast result
+     * @return true if one subpartition contains all the data; false otherwise.
+     */
+    boolean isSingleSubpartitionContainsAllData();
+
+    /**
+     * Determines whether the associated intermediate data set uses a broadcast distribution
+     * pattern.
+     *
+     * <p>A broadcast distribution pattern indicates that all data produced by this intermediate
+     * data set should be broadcast to every downstream consumer.
+     *
+     * @return true if the intermediate data set uses a broadcast distribution pattern; false
+     *     otherwise.
      */
     boolean isBroadcast();
 

flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java (+1 -1)

@@ -151,7 +151,7 @@ public int getNumberOfSubpartitions() {
     }
 
     private int computeNumberOfSubpartitionsForDynamicGraph() {
-        if (totalResult.isBroadcast() || totalResult.isForward()) {
+        if (totalResult.isSingleSubpartitionContainsAllData() || totalResult.isForward()) {
             // for dynamic graph and broadcast result, and forward result, we only produced one
             // subpartition, and all the downstream vertices should consume this subpartition.
             return 1;

flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java (+25 -7)

@@ -84,7 +84,8 @@ public static Map<IntermediateDataSetID, JobVertexInputInfo> computeVertexInputI
                                 parallelism,
                                 input::getNumSubpartitions,
                                 isDynamicGraph,
-                                input.isBroadcast()));
+                                input.isBroadcast(),
+                                input.isSingleSubpartitionContainsAllData()));
         }
     }
 
@@ -124,6 +125,7 @@ static JobVertexInputInfo computeVertexInputInfoForPointwise(
                             1,
                             () -> numOfSubpartitionsRetriever.apply(start),
                             isDynamicGraph,
+                            false,
                             false);
             executionVertexInputInfos.add(
                     new ExecutionVertexInputInfo(index, partitionRange, subpartitionRange));
@@ -145,6 +147,7 @@ static JobVertexInputInfo computeVertexInputInfoForPointwise(
                             numConsumers,
                             () -> numOfSubpartitionsRetriever.apply(finalPartitionNum),
                             isDynamicGraph,
+                            false,
                             false);
             executionVertexInputInfos.add(
                     new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange));
@@ -165,14 +168,16 @@
      * @param numOfSubpartitionsRetriever a retriever to get the number of subpartitions
      * @param isDynamicGraph whether is dynamic graph
      * @param isBroadcast whether the edge is broadcast
+     * @param isSingleSubpartitionContainsAllData whether a single subpartition contains all data
      * @return the computed {@link JobVertexInputInfo}
      */
     static JobVertexInputInfo computeVertexInputInfoForAllToAll(
             int sourceCount,
             int targetCount,
             Function<Integer, Integer> numOfSubpartitionsRetriever,
             boolean isDynamicGraph,
-            boolean isBroadcast) {
+            boolean isBroadcast,
+            boolean isSingleSubpartitionContainsAllData) {
         final List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
         IndexRange partitionRange = new IndexRange(0, sourceCount - 1);
         for (int i = 0; i < targetCount; ++i) {
@@ -182,7 +187,8 @@ static JobVertexInputInfo computeVertexInputInfoForAllToAll(
                             targetCount,
                             () -> numOfSubpartitionsRetriever.apply(0),
                             isDynamicGraph,
-                            isBroadcast);
+                            isBroadcast,
+                            isSingleSubpartitionContainsAllData);
             executionVertexInputInfos.add(
                     new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange));
         }
@@ -199,6 +205,7 @@ static JobVertexInputInfo computeVertexInputInfoForAllToAll(
      * @param numOfSubpartitionsSupplier a supplier to get the number of subpartitions
      * @param isDynamicGraph whether is dynamic graph
      * @param isBroadcast whether the edge is broadcast
+     * @param isSingleSubpartitionContainsAllData whether a single subpartition contains all data
      * @return the computed subpartition range
      */
     @VisibleForTesting
@@ -207,16 +214,22 @@ static IndexRange computeConsumedSubpartitionRange(
             int numConsumers,
             Supplier<Integer> numOfSubpartitionsSupplier,
             boolean isDynamicGraph,
-            boolean isBroadcast) {
+            boolean isBroadcast,
+            boolean isSingleSubpartitionContainsAllData) {
         int consumerIndex = consumerSubtaskIndex % numConsumers;
         if (!isDynamicGraph) {
             return new IndexRange(consumerIndex, consumerIndex);
         } else {
             int numSubpartitions = numOfSubpartitionsSupplier.get();
             if (isBroadcast) {
-                // broadcast results have only one subpartition, and be consumed multiple times.
-                checkArgument(numSubpartitions == 1);
-                return new IndexRange(0, 0);
+                if (isSingleSubpartitionContainsAllData) {
+                    // early-decided broadcast results have only one subpartition, which is
+                    // consumed multiple times.
+                    checkArgument(numSubpartitions == 1);
+                    return new IndexRange(0, 0);
+                } else {
+                    return new IndexRange(0, numSubpartitions - 1);
+                }
             } else {
                 checkArgument(consumerIndex < numConsumers);
                 checkArgument(numConsumers <= numSubpartitions);
@@ -246,6 +259,11 @@ public boolean isBroadcast() {
             return intermediateResult.isBroadcast();
         }
 
+        @Override
+        public boolean isSingleSubpartitionContainsAllData() {
+            return intermediateResult.isSingleSubpartitionContainsAllData();
+        }
+
         @Override
         public boolean isPointwise() {
             return intermediateResult.getConsumingDistributionPattern()
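
The new branch in computeConsumedSubpartitionRange is the behavioral core of this file: an early-decided broadcast consumer reads the single subpartition [0, 0], while a late-decided broadcast consumer must read every subpartition the hash-style producer wrote. A standalone re-derivation of just that logic (IndexRange modeled as int[2], the non-broadcast dynamic branch elided; this is a sketch, not the Flink class):

import java.util.function.Supplier;

// Standalone sketch of the patched range logic; not the Flink implementation.
public class SubpartitionRangeSketch {

    static int[] computeConsumedSubpartitionRange(
            int consumerSubtaskIndex,
            int numConsumers,
            Supplier<Integer> numOfSubpartitionsSupplier,
            boolean isDynamicGraph,
            boolean isBroadcast,
            boolean isSingleSubpartitionContainsAllData) {
        int consumerIndex = consumerSubtaskIndex % numConsumers;
        if (!isDynamicGraph) {
            return new int[] {consumerIndex, consumerIndex};
        }
        int numSubpartitions = numOfSubpartitionsSupplier.get();
        if (isBroadcast) {
            if (isSingleSubpartitionContainsAllData) {
                // early-decided broadcast: the producer wrote exactly one subpartition
                return new int[] {0, 0};
            }
            // late-decided broadcast: every consumer must read all subpartitions
            return new int[] {0, numSubpartitions - 1};
        }
        // non-broadcast dynamic case elided here; see the hunk above
        return new int[] {consumerIndex, consumerIndex};
    }

    public static void main(String[] args) {
        // early-decided broadcast, 1 subpartition -> [0, 0]
        int[] early = computeConsumedSubpartitionRange(1, 3, () -> 1, true, true, true);
        // late-decided broadcast, 4 subpartitions -> [0, 3]
        int[] late = computeConsumedSubpartitionRange(1, 3, () -> 4, true, true, false);
        System.out.println(early[0] + ".." + early[1]); // 0..0
        System.out.println(late[0] + ".." + late[1]);   // 0..3
    }
}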

flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java (+12)

@@ -134,6 +134,18 @@ public void configure(
         }
     }
 
+    public void updateOutputPattern(
+            DistributionPattern distributionPattern, boolean isBroadcast, boolean isForward) {
+        checkState(consumers.isEmpty(), "The output job edges have already been added.");
+        checkState(
+                numJobEdgesToCreate == 1,
+                "Modification is not allowed when the subscribing output is reused.");
+
+        this.distributionPattern = distributionPattern;
+        this.isBroadcast = isBroadcast;
+        this.isForward = isForward;
+    }
+
     public void increaseNumJobEdgesToCreate() {
         this.numJobEdgesToCreate++;
     }
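
The two checkState guards restrict the mutation to the window before JobGraph wiring is finalized: no consumers attached yet, and exactly one pending edge. A toy model of the guard behavior (hypothetical class, using plain IllegalStateException in place of Flink's Preconditions.checkState):

import java.util.ArrayList;
import java.util.List;

// Toy model (not the Flink class) of the updateOutputPattern() guard conditions.
public class OutputPatternGuardSketch {
    private final List<String> consumers = new ArrayList<>();
    private int numJobEdgesToCreate = 1;
    private String distributionPattern = "POINTWISE";

    void updateOutputPattern(String newPattern) {
        if (!consumers.isEmpty()) {
            throw new IllegalStateException("The output job edges have already been added.");
        }
        if (numJobEdgesToCreate != 1) {
            throw new IllegalStateException(
                    "Modification is not allowed when the subscribing output is reused.");
        }
        this.distributionPattern = newPattern;
    }

    public static void main(String[] args) {
        OutputPatternGuardSketch ds = new OutputPatternGuardSketch();
        ds.updateOutputPattern("ALL_TO_ALL"); // fine: no consumers yet, single pending edge

        ds.consumers.add("downstream-edge");
        try {
            ds.updateOutputPattern("POINTWISE");
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage()); // guard rejects the late mutation
        }
    }
}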

flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java (+11 -2)

@@ -22,6 +22,7 @@
 import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -44,11 +45,14 @@ abstract class AbstractBlockingResultInfo implements BlockingResultInfo {
     protected final Map<Integer, long[]> subpartitionBytesByPartitionIndex;
 
     AbstractBlockingResultInfo(
-            IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions) {
+            IntermediateDataSetID resultId,
+            int numOfPartitions,
+            int numOfSubpartitions,
+            Map<Integer, long[]> subpartitionBytesByPartitionIndex) {
         this.resultId = checkNotNull(resultId);
         this.numOfPartitions = numOfPartitions;
         this.numOfSubpartitions = numOfSubpartitions;
-        this.subpartitionBytesByPartitionIndex = new HashMap<>();
+        this.subpartitionBytesByPartitionIndex = new HashMap<>(subpartitionBytesByPartitionIndex);
     }
 
     @Override
@@ -72,4 +76,9 @@ public void resetPartitionInfo(int partitionIndex) {
     int getNumOfRecordedPartitions() {
         return subpartitionBytesByPartitionIndex.size();
     }
+
+    @Override
+    public Map<Integer, long[]> getSubpartitionBytesByPartitionIndex() {
+        return Collections.unmodifiableMap(subpartitionBytesByPartitionIndex);
+    }
 }
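
The constructor now copies the incoming map and the getter exposes only a read-only view. A short standalone sketch (hypothetical class, not Flink code) of why both sides matter when a result info is rebuilt and its recorded bytes are handed over:

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

// Standalone illustration of the copy-in / read-only-out pattern used above.
public class DefensiveMapSketch {
    private final Map<Integer, long[]> bytesByPartition;

    DefensiveMapSketch(Map<Integer, long[]> initial) {
        // Copy, so later mutations of the caller's map (or of the old result
        // info it came from) cannot leak into this instance.
        this.bytesByPartition = new HashMap<>(initial);
    }

    Map<Integer, long[]> getBytesByPartition() {
        // Read-only view: callers can read the recorded statistics but
        // cannot mutate this instance's internal state.
        return Collections.unmodifiableMap(bytesByPartition);
    }

    public static void main(String[] args) {
        Map<Integer, long[]> recorded = new HashMap<>();
        recorded.put(0, new long[] {128L, 256L});

        DefensiveMapSketch oldInfo = new DefensiveMapSketch(recorded);
        // Rebuilding a result info from the old one's statistics:
        DefensiveMapSketch newInfo = new DefensiveMapSketch(oldInfo.getBytesByPartition());

        try {
            newInfo.getBytesByPartition().put(1, new long[] {512L});
        } catch (UnsupportedOperationException e) {
            System.out.println("view is read-only"); // mutation is rejected
        }
    }
}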

flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java (+63 -5)

@@ -274,9 +274,14 @@ public void onNewJobVerticesAdded(List<JobVertex> newVertices, int pendingOperat
             // 4. update json plan
             getExecutionGraph().setJsonPlan(JsonPlanGenerator.generatePlan(getJobGraph()));
 
-            // 5. try aggregate subpartition bytes
+            // 5. In broadcast join optimization, results might be written first with a hash
+            // method and then read with a broadcast method. Therefore, we need to update the
+            // result info:
+            // 1. Update the DistributionPattern to reflect the optimized data distribution.
+            // 2. Aggregate subpartition bytes when possible for efficiency.
             for (JobVertex newVertex : newVertices) {
                 for (JobEdge input : newVertex.getInputs()) {
+                    tryUpdateResultInfo(input.getSourceId(), input.getDistributionPattern());
                     Optional.ofNullable(blockingResultInfos.get(input.getSourceId()))
                             .ifPresent(this::maybeAggregateSubpartitionBytes);
                 }
@@ -490,7 +495,8 @@ private void updateResultPartitionBytesMetrics(
                     result.getId(),
                     (ignored, resultInfo) -> {
                         if (resultInfo == null) {
-                            resultInfo = createFromIntermediateResult(result);
+                            resultInfo =
+                                    createFromIntermediateResult(result, new HashMap<>());
                         }
                         resultInfo.recordPartitionInfo(
                                 partitionId.getPartitionNumber(), partitionBytes);
@@ -500,6 +506,16 @@
         });
     }
 
+    /**
+     * Aggregates subpartition bytes if all conditions are met. This method checks whether the
+     * result info instance is of type {@link AllToAllBlockingResultInfo}, whether all consumer
+     * vertices are created, and whether all consumer vertices are initialized. If these conditions
+     * are satisfied, the fine-grained statistics will no longer be required by the consumer
+     * vertices, and the subpartition bytes can then be aggregated.
+     *
+     * @param resultInfo the BlockingResultInfo instance to potentially aggregate subpartition
+     *     bytes for.
+     */
     private void maybeAggregateSubpartitionBytes(BlockingResultInfo resultInfo) {
         IntermediateResult intermediateResult =
                 getExecutionGraph().getAllIntermediateResults().get(resultInfo.getResultId());
@@ -937,21 +953,24 @@ private static void resetDynamicParallelism(Iterable<JobVertex> vertices) {
         }
     }
 
-    private static BlockingResultInfo createFromIntermediateResult(IntermediateResult result) {
+    private static BlockingResultInfo createFromIntermediateResult(
+            IntermediateResult result, Map<Integer, long[]> subpartitionBytesByPartitionIndex) {
         checkArgument(result != null);
         // Note that for dynamic graph, different partitions in the same result have the same number
         // of subpartitions.
         if (result.getConsumingDistributionPattern() == DistributionPattern.POINTWISE) {
             return new PointwiseBlockingResultInfo(
                     result.getId(),
                     result.getNumberOfAssignedPartitions(),
-                    result.getPartitions()[0].getNumberOfSubpartitions());
+                    result.getPartitions()[0].getNumberOfSubpartitions(),
+                    subpartitionBytesByPartitionIndex);
         } else {
             return new AllToAllBlockingResultInfo(
                     result.getId(),
                     result.getNumberOfAssignedPartitions(),
                     result.getPartitions()[0].getNumberOfSubpartitions(),
-                    result.isBroadcast());
+                    result.isSingleSubpartitionContainsAllData(),
+                    subpartitionBytesByPartitionIndex);
         }
     }
 
@@ -965,6 +984,45 @@ SpeculativeExecutionHandler getSpeculativeExecutionHandler() {
         return speculativeExecutionHandler;
     }
 
+    /**
+     * Tries to update the result information for a given IntermediateDataSetID according to the
+     * specified DistributionPattern. This ensures consistency between the distribution pattern and
+     * the stored result information.
+     *
+     * <p>The result information is updated under the following conditions:
+     *
+     * <ul>
+     *   <li>If the target pattern is ALL_TO_ALL and the current result info is POINTWISE, a new
+     *       BlockingResultInfo is created and stored.
+     *   <li>If the target pattern is POINTWISE and the current result info is ALL_TO_ALL, a
+     *       conversion is similarly triggered.
+     *   <li>Additionally, for ALL_TO_ALL patterns, the broadcast status of the result info is
+     *       updated.
+     * </ul>
+     *
+     * @param id The ID of the intermediate dataset to update.
+     * @param targetPattern The target distribution pattern to apply.
+     */
+    private void tryUpdateResultInfo(IntermediateDataSetID id, DistributionPattern targetPattern) {
+        if (blockingResultInfos.containsKey(id)) {
+            BlockingResultInfo resultInfo = blockingResultInfos.get(id);
+            IntermediateResult result = getExecutionGraph().getAllIntermediateResults().get(id);
+
+            if ((targetPattern == DistributionPattern.ALL_TO_ALL && resultInfo.isPointwise())
+                    || (targetPattern == DistributionPattern.POINTWISE
+                            && !resultInfo.isPointwise())) {
+
+                BlockingResultInfo newInfo =
+                        createFromIntermediateResult(
+                                result, resultInfo.getSubpartitionBytesByPartitionIndex());
+
+                blockingResultInfos.put(id, newInfo);
+            } else if (resultInfo instanceof AllToAllBlockingResultInfo) {
+                ((AllToAllBlockingResultInfo) resultInfo).setBroadcast(result.isBroadcast());
+            }
+        }
+    }
+
     private class DefaultBatchJobRecoveryContext implements BatchJobRecoveryContext {
 
         private final FailoverStrategy restartStrategyOnResultConsumable =
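
Putting the scheduler-side pieces together: when an optimized edge arrives with a different distribution pattern, tryUpdateResultInfo rebuilds the result info under the new pattern while carrying over the already-recorded subpartition bytes rather than discarding them. A compact standalone model of that conversion (all class and method names hypothetical):

import java.util.HashMap;
import java.util.Map;

// Toy model (hypothetical classes) of the result-info conversion performed by
// tryUpdateResultInfo: pattern changes rebuild the info, statistics survive.
public class ResultInfoConversionSketch {

    enum Pattern { POINTWISE, ALL_TO_ALL }

    static class Info {
        final Pattern pattern;
        final Map<Integer, long[]> bytesByPartition;

        Info(Pattern pattern, Map<Integer, long[]> bytes) {
            this.pattern = pattern;
            this.bytesByPartition = new HashMap<>(bytes);
        }
    }

    static Info tryUpdate(Info current, Pattern targetPattern) {
        if (current.pattern != targetPattern) {
            // Pattern changed by the optimization: rebuild, carrying statistics over.
            return new Info(targetPattern, current.bytesByPartition);
        }
        return current; // pattern unchanged: keep the existing info
    }

    public static void main(String[] args) {
        Map<Integer, long[]> bytes = new HashMap<>();
        bytes.put(0, new long[] {1024L});

        Info pointwise = new Info(Pattern.POINTWISE, bytes);
        Info rebuilt = tryUpdate(pointwise, Pattern.ALL_TO_ALL);

        System.out.println(rebuilt.pattern);                         // ALL_TO_ALL
        System.out.println(rebuilt.bytesByPartition.containsKey(0)); // true
    }
}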

flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionHandlerFactory.java (+3 -1)

@@ -21,6 +21,7 @@
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.streaming.api.graph.ExecutionPlan;
 import org.apache.flink.streaming.api.graph.StreamGraph;
+import org.apache.flink.util.DynamicCodeLoadingException;
 
 import java.util.concurrent.Executor;
 
@@ -46,7 +47,8 @@ public class AdaptiveExecutionHandlerFactory {
     public static AdaptiveExecutionHandler create(
             ExecutionPlan executionPlan,
             ClassLoader userClassLoader,
-            Executor serializationExecutor) {
+            Executor serializationExecutor)
+            throws DynamicCodeLoadingException {
         if (executionPlan instanceof JobGraph) {
             return new NonAdaptiveExecutionHandler((JobGraph) executionPlan);
         } else {
