Skip to content

Commit 0ce55c3

Browse files
committed
[FLINK-36067][runtime] Manually trigger aggregate all-to-all result partition info when all consumers created and initialized.
1 parent 39ac9e5 commit 0ce55c3

File tree

4 files changed

+100
-26
lines changed

4 files changed

+100
-26
lines changed

flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java

+23
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,14 @@ public void onNewJobVerticesAdded(List<JobVertex> newVertices, int pendingOperat
273273

274274
// 4. update json plan
275275
getExecutionGraph().setJsonPlan(JsonPlanGenerator.generatePlan(getJobGraph()));
276+
277+
// 5. try aggregate subpartition bytes
278+
for (JobVertex newVertex : newVertices) {
279+
for (JobEdge input : newVertex.getInputs()) {
280+
Optional.ofNullable(blockingResultInfos.get(input.getSourceId()))
281+
.ifPresent(this::maybeAggregateSubpartitionBytes);
282+
}
283+
}
276284
}
277285

278286
@Override
@@ -486,11 +494,25 @@ private void updateResultPartitionBytesMetrics(
486494
}
487495
resultInfo.recordPartitionInfo(
488496
partitionId.getPartitionNumber(), partitionBytes);
497+
maybeAggregateSubpartitionBytes(resultInfo);
489498
return resultInfo;
490499
});
491500
});
492501
}
493502

503+
private void maybeAggregateSubpartitionBytes(BlockingResultInfo resultInfo) {
504+
IntermediateResult intermediateResult =
505+
getExecutionGraph().getAllIntermediateResults().get(resultInfo.getResultId());
506+
507+
if (resultInfo instanceof AllToAllBlockingResultInfo
508+
&& intermediateResult.areAllConsumerVerticesCreated()
509+
&& intermediateResult.getConsumerVertices().stream()
510+
.map(this::getExecutionJobVertex)
511+
.allMatch(ExecutionJobVertex::isInitialized)) {
512+
((AllToAllBlockingResultInfo) resultInfo).aggregateSubpartitionBytes();
513+
}
514+
}
515+
494516
@Override
495517
public void allocateSlotsAndDeploy(final List<ExecutionVertexID> verticesToDeploy) {
496518
List<ExecutionVertex> executionVertices =
@@ -657,6 +679,7 @@ public void initializeVerticesIfPossible() {
657679
parallelismAndInputInfos.getJobVertexInputInfos(),
658680
createTimestamp);
659681
newlyInitializedJobVertices.add(jobVertex);
682+
consumedResultsInfo.get().forEach(this::maybeAggregateSubpartitionBytes);
660683
}
661684
}
662685
}

flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java

+69-26
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.util.Arrays;
2828
import java.util.Collections;
2929
import java.util.List;
30+
import java.util.Optional;
3031
import java.util.stream.Collectors;
3132

3233
import static org.apache.flink.util.Preconditions.checkState;
@@ -74,18 +75,28 @@ public int getNumSubpartitions(int partitionIndex) {
7475

7576
@Override
7677
public long getNumBytesProduced() {
77-
checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready");
78+
checkState(
79+
aggregatedSubpartitionBytes != null
80+
|| subpartitionBytesByPartitionIndex.size() == numOfPartitions,
81+
"Not all partition infos are ready");
82+
83+
List<Long> bytes =
84+
Optional.ofNullable(aggregatedSubpartitionBytes)
85+
.orElse(getAggregatedSubpartitionBytesInternal());
7886
if (isBroadcast) {
79-
return aggregatedSubpartitionBytes.get(0);
87+
return bytes.get(0);
8088
} else {
81-
return aggregatedSubpartitionBytes.stream().reduce(0L, Long::sum);
89+
return bytes.stream().reduce(0L, Long::sum);
8290
}
8391
}
8492

8593
@Override
8694
public long getNumBytesProduced(
8795
IndexRange partitionIndexRange, IndexRange subpartitionIndexRange) {
88-
checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready");
96+
List<Long> bytes =
97+
Optional.ofNullable(aggregatedSubpartitionBytes)
98+
.orElse(getAggregatedSubpartitionBytesInternal());
99+
89100
checkState(
90101
partitionIndexRange.getStartIndex() == 0
91102
&& partitionIndexRange.getEndIndex() == numOfPartitions - 1,
@@ -96,7 +107,7 @@ public long getNumBytesProduced(
96107
"Subpartition index %s is out of range.",
97108
subpartitionIndexRange.getEndIndex());
98109

99-
return aggregatedSubpartitionBytes
110+
return bytes
100111
.subList(
101112
subpartitionIndexRange.getStartIndex(),
102113
subpartitionIndexRange.getEndIndex() + 1)
@@ -106,31 +117,56 @@ public long getNumBytesProduced(
106117

107118
@Override
108119
public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) {
109-
// Once all partitions are finished, we can convert the subpartition bytes to aggregated
110-
// value to reduce the space usage, because the distribution of source splits does not
111-
// affect the distribution of data consumed by downstream tasks of ALL_TO_ALL edges(Hashing
112-
// or Rebalancing, we do not consider rare cases such as custom partitions here).
113120
if (aggregatedSubpartitionBytes == null) {
114121
super.recordPartitionInfo(partitionIndex, partitionBytes);
122+
}
123+
}
115124

116-
if (subpartitionBytesByPartitionIndex.size() == numOfPartitions) {
117-
long[] aggregatedBytes = new long[numOfSubpartitions];
118-
subpartitionBytesByPartitionIndex
119-
.values()
120-
.forEach(
121-
subpartitionBytes -> {
122-
checkState(subpartitionBytes.length == numOfSubpartitions);
123-
for (int i = 0; i < subpartitionBytes.length; ++i) {
124-
aggregatedBytes[i] += subpartitionBytes[i];
125-
}
126-
});
127-
this.aggregatedSubpartitionBytes =
128-
Arrays.stream(aggregatedBytes).boxed().collect(Collectors.toList());
129-
this.subpartitionBytesByPartitionIndex.clear();
130-
}
125+
/**
126+
* Aggregates the subpartition bytes to reduce space usage.
127+
*
128+
* <p>Once all partitions are finished and all consumer jobVertices are initialized, we can
129+
* convert the subpartition bytes to aggregated value to reduce the space usage, because the
130+
* distribution of source splits does not affect the distribution of data consumed by downstream
131+
* tasks of ALL_TO_ALL edges(Hashing or Rebalancing, we do not consider rare cases such as
132+
* custom partitions here).
133+
*/
134+
protected void aggregateSubpartitionBytes() {
135+
if (subpartitionBytesByPartitionIndex.size() == numOfPartitions) {
136+
this.aggregatedSubpartitionBytes = getAggregatedSubpartitionBytesInternal();
137+
this.subpartitionBytesByPartitionIndex.clear();
131138
}
132139
}
133140

141+
/**
142+
* Aggregates the bytes of subpartitions across all partition indices without modifying the
143+
* existing state. This method is intended for querying purposes only.
144+
*
145+
* <p>The method computes the sum of the bytes for each subpartition across all partitions and
146+
* returns a list containing these summed values.
147+
*
148+
* <p>This method is needed in scenarios where aggregated results are required, but fine-grained
149+
* statistics should remain not aggregated. Specifically, when not all consumer vertices of this
150+
* result info are created or initialized, this result info could not be aggregated. And the
151+
* existing consumer vertices of this info still require these aggregated result for scheduling.
152+
*
153+
* @return a list of aggregated byte counts for each subpartition.
154+
*/
155+
private List<Long> getAggregatedSubpartitionBytesInternal() {
156+
long[] aggregatedBytes = new long[numOfSubpartitions];
157+
subpartitionBytesByPartitionIndex
158+
.values()
159+
.forEach(
160+
subpartitionBytes -> {
161+
checkState(subpartitionBytes.length == numOfSubpartitions);
162+
for (int i = 0; i < subpartitionBytes.length; ++i) {
163+
aggregatedBytes[i] += subpartitionBytes[i];
164+
}
165+
});
166+
167+
return Arrays.stream(aggregatedBytes).boxed().collect(Collectors.toList());
168+
}
169+
134170
@Override
135171
public void resetPartitionInfo(int partitionIndex) {
136172
if (aggregatedSubpartitionBytes == null) {
@@ -139,7 +175,14 @@ public void resetPartitionInfo(int partitionIndex) {
139175
}
140176

141177
public List<Long> getAggregatedSubpartitionBytes() {
142-
checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready");
143-
return Collections.unmodifiableList(aggregatedSubpartitionBytes);
178+
checkState(
179+
aggregatedSubpartitionBytes != null
180+
|| subpartitionBytesByPartitionIndex.size() == numOfPartitions,
181+
"Not all partition infos are ready");
182+
if (aggregatedSubpartitionBytes == null) {
183+
return getAggregatedSubpartitionBytesInternal();
184+
} else {
185+
return Collections.unmodifiableList(aggregatedSubpartitionBytes);
186+
}
144187
}
145188
}

flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java

+3
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ void testRecordPartitionInfoMultiTimes() {
9999
// The result info should be (partitionBytes2 + partitionBytes3)
100100
assertThat(resultInfo.getNumBytesProduced()).isEqualTo(576L);
101101
assertThat(resultInfo.getAggregatedSubpartitionBytes()).containsExactly(192L, 384L);
102+
// The raw info should not be clear
103+
assertThat(resultInfo.getNumOfRecordedPartitions()).isGreaterThan(0);
104+
resultInfo.aggregateSubpartitionBytes();
102105
// The raw info should be clear
103106
assertThat(resultInfo.getNumOfRecordedPartitions()).isZero();
104107

flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java

+5
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,11 @@ public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partiti
609609

610610
@Override
611611
public void resetPartitionInfo(int partitionIndex) {}
612+
613+
@Override
614+
public Map<Integer, long[]> getSubpartitionBytesByPartitionIndex() {
615+
return Map.of();
616+
}
612617
}
613618

614619
private static BlockingResultInfo createFromBroadcastResult(long producedBytes) {

0 commit comments

Comments
 (0)