27
27
import java .util .Arrays ;
28
28
import java .util .Collections ;
29
29
import java .util .List ;
30
+ import java .util .Optional ;
30
31
import java .util .stream .Collectors ;
31
32
32
33
import static org .apache .flink .util .Preconditions .checkState ;
@@ -74,18 +75,28 @@ public int getNumSubpartitions(int partitionIndex) {
74
75
75
76
@ Override
76
77
public long getNumBytesProduced () {
77
- checkState (aggregatedSubpartitionBytes != null , "Not all partition infos are ready" );
78
+ checkState (
79
+ aggregatedSubpartitionBytes != null
80
+ || subpartitionBytesByPartitionIndex .size () == numOfPartitions ,
81
+ "Not all partition infos are ready" );
82
+
83
+ List <Long > bytes =
84
+ Optional .ofNullable (aggregatedSubpartitionBytes )
85
+ .orElse (getAggregatedSubpartitionBytesInternal ());
78
86
if (isBroadcast ) {
79
- return aggregatedSubpartitionBytes .get (0 );
87
+ return bytes .get (0 );
80
88
} else {
81
- return aggregatedSubpartitionBytes .stream ().reduce (0L , Long ::sum );
89
+ return bytes .stream ().reduce (0L , Long ::sum );
82
90
}
83
91
}
84
92
85
93
@ Override
86
94
public long getNumBytesProduced (
87
95
IndexRange partitionIndexRange , IndexRange subpartitionIndexRange ) {
88
- checkState (aggregatedSubpartitionBytes != null , "Not all partition infos are ready" );
96
+ List <Long > bytes =
97
+ Optional .ofNullable (aggregatedSubpartitionBytes )
98
+ .orElse (getAggregatedSubpartitionBytesInternal ());
99
+
89
100
checkState (
90
101
partitionIndexRange .getStartIndex () == 0
91
102
&& partitionIndexRange .getEndIndex () == numOfPartitions - 1 ,
@@ -96,7 +107,7 @@ public long getNumBytesProduced(
96
107
"Subpartition index %s is out of range." ,
97
108
subpartitionIndexRange .getEndIndex ());
98
109
99
- return aggregatedSubpartitionBytes
110
+ return bytes
100
111
.subList (
101
112
subpartitionIndexRange .getStartIndex (),
102
113
subpartitionIndexRange .getEndIndex () + 1 )
@@ -106,31 +117,56 @@ public long getNumBytesProduced(
106
117
107
118
@ Override
108
119
public void recordPartitionInfo (int partitionIndex , ResultPartitionBytes partitionBytes ) {
109
- // Once all partitions are finished, we can convert the subpartition bytes to aggregated
110
- // value to reduce the space usage, because the distribution of source splits does not
111
- // affect the distribution of data consumed by downstream tasks of ALL_TO_ALL edges(Hashing
112
- // or Rebalancing, we do not consider rare cases such as custom partitions here).
113
120
if (aggregatedSubpartitionBytes == null ) {
114
121
super .recordPartitionInfo (partitionIndex , partitionBytes );
122
+ }
123
+ }
115
124
116
- if (subpartitionBytesByPartitionIndex .size () == numOfPartitions ) {
117
- long [] aggregatedBytes = new long [numOfSubpartitions ];
118
- subpartitionBytesByPartitionIndex
119
- .values ()
120
- .forEach (
121
- subpartitionBytes -> {
122
- checkState (subpartitionBytes .length == numOfSubpartitions );
123
- for (int i = 0 ; i < subpartitionBytes .length ; ++i ) {
124
- aggregatedBytes [i ] += subpartitionBytes [i ];
125
- }
126
- });
127
- this .aggregatedSubpartitionBytes =
128
- Arrays .stream (aggregatedBytes ).boxed ().collect (Collectors .toList ());
129
- this .subpartitionBytesByPartitionIndex .clear ();
130
- }
125
+ /**
126
+ * Aggregates the subpartition bytes to reduce space usage.
127
+ *
128
+ * <p>Once all partitions are finished and all consumer jobVertices are initialized, we can
129
+ * convert the subpartition bytes to aggregated value to reduce the space usage, because the
130
+ * distribution of source splits does not affect the distribution of data consumed by downstream
131
+ * tasks of ALL_TO_ALL edges(Hashing or Rebalancing, we do not consider rare cases such as
132
+ * custom partitions here).
133
+ */
134
+ protected void aggregateSubpartitionBytes () {
135
+ if (subpartitionBytesByPartitionIndex .size () == numOfPartitions ) {
136
+ this .aggregatedSubpartitionBytes = getAggregatedSubpartitionBytesInternal ();
137
+ this .subpartitionBytesByPartitionIndex .clear ();
131
138
}
132
139
}
133
140
141
+ /**
142
+ * Aggregates the bytes of subpartitions across all partition indices without modifying the
143
+ * existing state. This method is intended for querying purposes only.
144
+ *
145
+ * <p>The method computes the sum of the bytes for each subpartition across all partitions and
146
+ * returns a list containing these summed values.
147
+ *
148
+ * <p>This method is needed in scenarios where aggregated results are required, but fine-grained
149
+ * statistics should remain not aggregated. Specifically, when not all consumer vertices of this
150
+ * result info are created or initialized, this result info could not be aggregated. And the
151
+ * existing consumer vertices of this info still require these aggregated result for scheduling.
152
+ *
153
+ * @return a list of aggregated byte counts for each subpartition.
154
+ */
155
+ private List <Long > getAggregatedSubpartitionBytesInternal () {
156
+ long [] aggregatedBytes = new long [numOfSubpartitions ];
157
+ subpartitionBytesByPartitionIndex
158
+ .values ()
159
+ .forEach (
160
+ subpartitionBytes -> {
161
+ checkState (subpartitionBytes .length == numOfSubpartitions );
162
+ for (int i = 0 ; i < subpartitionBytes .length ; ++i ) {
163
+ aggregatedBytes [i ] += subpartitionBytes [i ];
164
+ }
165
+ });
166
+
167
+ return Arrays .stream (aggregatedBytes ).boxed ().collect (Collectors .toList ());
168
+ }
169
+
134
170
@ Override
135
171
public void resetPartitionInfo (int partitionIndex ) {
136
172
if (aggregatedSubpartitionBytes == null ) {
@@ -139,7 +175,14 @@ public void resetPartitionInfo(int partitionIndex) {
139
175
}
140
176
141
177
public List <Long > getAggregatedSubpartitionBytes () {
142
- checkState (aggregatedSubpartitionBytes != null , "Not all partition infos are ready" );
143
- return Collections .unmodifiableList (aggregatedSubpartitionBytes );
178
+ checkState (
179
+ aggregatedSubpartitionBytes != null
180
+ || subpartitionBytesByPartitionIndex .size () == numOfPartitions ,
181
+ "Not all partition infos are ready" );
182
+ if (aggregatedSubpartitionBytes == null ) {
183
+ return getAggregatedSubpartitionBytesInternal ();
184
+ } else {
185
+ return Collections .unmodifiableList (aggregatedSubpartitionBytes );
186
+ }
144
187
}
145
188
}
0 commit comments