@@ -141,25 +141,21 @@ def _run_early_stopping_check(self, trainer, pl_module):
         if not isinstance(current, torch.Tensor):
             current = torch.tensor(current)
 
+        # in ddp, reduce the stopping metric so every process conditions the same
+        if trainer.use_ddp or trainer.use_ddp2:
+            print(f'RANK: {trainer.global_rank}, BEFORE: {current}')
+            current = current.to(pl_module.device)
+            dist.all_reduce(current, op=dist.reduce_op.MAX)
+            print(f'RANK: {trainer.global_rank}, AFTER: {current}')
+
         if self.monitor_op(current - self.min_delta, self.best_score):
             self.best_score = current
             self.wait_count = 0
         else:
             self.wait_count += 1
             should_stop = self.wait_count >= self.patience
 
-            # check flag across all GPUs
-            should_stop = torch.tensor(int(should_stop), device=pl_module.device)
-            if trainer.use_ddp or trainer.use_ddp2:
-                print(f'RANK: {trainer.global_rank} REDUCING...')
-                dist.all_reduce(should_stop, op=dist.reduce_op.MAX)
-                print(f'RANK: {trainer.global_rank} REDUCED...')
-
-            print(f'RANK: {trainer.global_rank} SHOULD STOP: {should_stop} BEST: {self.best_score}')
-
-            # do actual stop
-            print(f'RANK: {trainer.global_rank}, SHOULD STOP: {should_stop}, EPOCH: {trainer.current_epoch}')
-            if bool(should_stop.item()):
+            if bool(should_stop):
                 print(f'RANK: {trainer.global_rank}, STOPPING...')
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True
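
For context, a minimal standalone sketch of the synchronization this hunk introduces: all-reduce the monitored value with MAX so every DDP rank evaluates the stopping condition on the same number. The helper name sync_monitored_metric is illustrative and not part of the commit; it assumes a process group has already been created with torch.distributed.init_process_group, and it uses the current dist.ReduceOp spelling rather than the older dist.reduce_op alias seen above.

    import torch
    import torch.distributed as dist

    def sync_monitored_metric(current, device):
        # illustrative helper, not part of this commit
        current = torch.as_tensor(current, dtype=torch.float32, device=device)
        if dist.is_available() and dist.is_initialized():
            # after the all-reduce every rank holds the same (max) value,
            # so the monitor_op comparison takes the same branch on every process
            dist.all_reduce(current, op=dist.ReduceOp.MAX)
        return current

Because every rank then compares the same reduced value against best_score, the wait_count / patience bookkeeping stays in lockstep across processes, which is why the separate per-rank reduction of should_stop removed by this commit is no longer needed.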