@@ -415,11 +415,16 @@ def train(self):
 
         self.run_training_teardown()
 
-    def run_training_epoch(self):
+    def prepare_train_loop_dataloader(self, train_dataloader):
+        # on TPU we have to wrap it under the ParallelLoader
+        if self.use_tpu:
+            device = xm.xla_device(self.tpu_id)
+            train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
+            train_dataloader = train_dataloader.per_device_loader(device)
 
-        # get model
-        model = self.get_model()
+        return train_dataloader
 
+    def run_on_epoch_start_hook(self, model):
         # Epoch start events
         with self.profiler.profile('on_epoch_start'):
             # callbacks
@@ -429,17 +434,19 @@ def run_training_epoch(self):
             if self.is_function_implemented('on_epoch_start'):
                 model.on_epoch_start()
 
-        # track local dataloader so TPU can wrap each epoch
-        train_dataloader = self.train_dataloader
+    def run_training_epoch(self):
 
-        # on TPU we have to wrap it under the ParallelLoader
-        if self.use_tpu:
-            device = xm.xla_device(self.tpu_id)
-            train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
-            train_dataloader = train_dataloader.per_device_loader(device)
+        # get model
+        model = self.get_model()
+
+        # Epoch start events
+        self.run_on_epoch_start_hook(model)
+
+        # modify dataloader if needed (ddp, etc...)
+        train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader)
 
         # bookkeeping
-        outputs = []
+        epoch_output = []
 
         # run epoch
         for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
@@ -450,63 +457,41 @@ def run_training_epoch(self):
                 break
 
             self.batch_idx = batch_idx
-
             model.global_step = self.global_step
 
-            # ---------------
-            # RUN TRAIN STEP
-            # ---------------
-            _outputs = self.run_training_batch(batch, batch_idx)
-            batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs
+            # ------------------------------------
+            # TRAINING_STEP + TRAINING_STEP_END
+            # ------------------------------------
+            batch_output = self.run_training_batch(batch, batch_idx)
 
             # only track outputs when user implements training_epoch_end
             # otherwise we will build up unnecessary memory
             if self.is_overridden('training_epoch_end', model=self.get_model()):
-                outputs.append(batch_output)
+                epoch_output.append(batch_output.training_step_output_for_epoch_end)
+
+            # update LR schedulers
+            self.update_train_loop_lr_schedulers()
 
             # when returning -1 from train_step, we end epoch early
-            early_stop_epoch = batch_result == -1
-
-            # TODO: consolidate all actions that need to take place only after
-            # self.accumulate_grad_batches steps (optimizer step, lr update, global step increment)
-            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
-                # update lr
-                self.update_learning_rates(interval='step')
-
-            # ---------------
-            # RUN VAL STEP
-            # ---------------
-            is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
-            can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-            can_check_val = not self.disable_validation and can_check_epoch
-            should_check_val = is_val_check_batch or early_stop_epoch
-            should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
-            should_check_val = can_check_val and should_check_val
-
-            # ---------------
-            # CHECKPOINTING, EARLY STOPPING
-            # ---------------
-            # fast_dev_run always forces val checking after train batch
-            if self.fast_dev_run or should_check_val:
-                self.run_evaluation(test_mode=self.testing)
-                self.call_checkpoint_callback()
-
-            # when logs should be saved
-            should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
-            if should_save_log or self.fast_dev_run:
-                if self.is_global_zero and self.logger is not None:
-                    self.logger.save()
-
-            # when metrics should be logged
-            should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
-            if should_log_metrics or self.fast_dev_run:
-                # logs user requested information to logger
-                self.log_metrics(batch_step_metrics, grad_norm_dic)
+            early_stop_epoch = batch_output.signal == -1
+
+            # -----------------------------------------
+            # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
+            # -----------------------------------------
+            should_check_val = self.check_validation_in_train_loop(batch_idx, early_stop_epoch, is_last_batch)
+
+            # -----------------------------------------
+            # SAVE LOGGERS (ie: Tensorboard, etc...)
+            # -----------------------------------------
+            self.save_loggers_in_training_loop(batch_idx, early_stop_epoch)
+
+            # -----------------------------------------
+            # SAVE METRICS TO LOGGERS
+            # -----------------------------------------
+            self.save_train_loop_metrics_to_loggers(batch_idx, early_stop_epoch, batch_output)
 
             # progress global step according to grads progress
-            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
-                self.global_step += 1
-                self.total_batch_idx += 1
+            self.increment_accumulated_grad_global_step()
 
             # max steps reached, end training
             if self.max_steps is not None and self.max_steps == self.global_step:
@@ -518,13 +503,36 @@ def run_training_epoch(self):
             if early_stop_epoch or self.fast_dev_run:
                 break
 
-        if self.use_horovod:
-            hvd.join(hvd.local_rank() if self.on_gpu else -1)
+        # let ddp devices catch up when using horovod
+        self.sync_horovod()
 
         # process epoch outputs
+        self.run_training_epoch_end(epoch_output)
+
+        # when no val loop is present or fast-dev-run still need to call checkpoints
+        if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
+            self.call_checkpoint_callback()
+
+        # epoch end hook
+        self.run_on_epoch_end_hook(model)
+
+    def update_train_loop_lr_schedulers(self):
+        if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
+            # update lr
+            self.update_learning_rates(interval='step')
+
+    def run_on_epoch_end_hook(self, model):
+        with self.profiler.profile('on_epoch_end'):
+            # callbacks
+            self.on_epoch_end()
+            # model hooks
+            if self.is_function_implemented('on_epoch_end'):
+                model.on_epoch_end()
+
+    def run_training_epoch_end(self, epoch_output):
         model = self.get_model()
         if self.is_overridden('training_epoch_end', model=model):
-            epoch_output = model.training_epoch_end(outputs)
+            epoch_output = model.training_epoch_end(epoch_output)
         _processed_outputs = self.process_output(epoch_output)
         log_epoch_metrics = _processed_outputs[2]
         callback_epoch_metrics = _processed_outputs[3]
@@ -538,17 +546,45 @@ def run_training_epoch(self):
         # add metrics to progress_bar
         self.add_progress_bar_metrics(_processed_outputs[1])
 
-        # when no val loop is present or fast-dev-run still need to call checkpoints
-        if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
+    def sync_horovod(self):
+        if self.use_horovod:
+            hvd.join(hvd.local_rank() if self.on_gpu else -1)
+
+    def increment_accumulated_grad_global_step(self):
+        # progress global step according to grads progress
+        if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
+            self.global_step += 1
+            self.total_batch_idx += 1
+
+    def save_train_loop_metrics_to_loggers(self, batch_idx, early_stop_epoch, batch_output):
+        # when metrics should be logged
+        should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
+        if should_log_metrics or self.fast_dev_run:
+            # logs user requested information to logger
+            self.log_metrics(batch_output.batch_log_metrics, batch_output.grad_norm_dic)
+
+    def save_loggers_in_training_loop(self, batch_idx, early_stop_epoch):
+        # when loggers should save to disk
+        should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
+        if should_save_log or self.fast_dev_run:
+            if self.is_global_zero and self.logger is not None:
+                self.logger.save()
+
+    def check_validation_in_train_loop(self, batch_idx, early_stop_epoch, is_last_batch):
+        # decide if we should run validation
+        is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
+        can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
+        can_check_val = not self.disable_validation and can_check_epoch
+        should_check_val = is_val_check_batch or early_stop_epoch
+        should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
+        should_check_val = can_check_val and should_check_val
+
+        # if we need to run validation, then also call the checkpoint callback
+        if self.fast_dev_run or should_check_val:
+            self.run_evaluation(test_mode=self.testing)
             self.call_checkpoint_callback()
 
-        # Epoch end events
-        with self.profiler.profile('on_epoch_end'):
-            # callbacks
-            self.on_epoch_end()
-            # model hooks
-            if self.is_function_implemented('on_epoch_end'):
-                model.on_epoch_end()
+        return should_check_val
 
     def run_training_batch(self, batch, batch_idx):
         # track grad norms
@@ -561,7 +597,7 @@ def run_training_batch(self, batch, batch_idx):
         batch_log_metrics = []
 
         if batch is None:
-            return 0, grad_norm_dic, {}, {}
+            return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)
 
         # Batch start events
         with self.profiler.profile('on_batch_start'):
@@ -571,7 +607,7 @@ def run_training_batch(self, batch, batch_idx):
             if self.is_function_implemented('on_batch_start'):
                 response = self.get_model().on_batch_start(batch)
                 if response == -1:
-                    return -1, grad_norm_dic, {}, {}
+                    return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)
 
         splits = [batch]
         if self.truncated_bptt_steps is not None:
@@ -650,7 +686,13 @@ def run_training_batch(self, batch, batch_idx):
         # track all metrics for callbacks
         self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()})
 
-        return 0, grad_norm_dic, batch_log_metrics, opt_closure_result.training_step_output_for_epoch_end
+        result = AttributeDict(
+            signal=0,
+            grad_norm_dic=grad_norm_dic,
+            batch_log_metrics=batch_log_metrics,
+            training_step_output_for_epoch_end=opt_closure_result.training_step_output_for_epoch_end
+        )
+        return result
 
     def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
         # ------------------
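Note on the new return type: run_training_batch now returns an AttributeDict instead of a positional tuple, so callers read named fields (batch_output.signal, batch_output.grad_norm_dic, batch_output.batch_log_metrics, ...) rather than unpacking by position. Below is a minimal, illustrative sketch of such a dict-with-attribute-access; the real AttributeDict ships with pytorch_lightning.utilities and may differ in detail.

# Illustrative only: a dict whose keys are also readable/writable as attributes,
# mirroring how the refactored run_training_batch result is consumed.
class AttributeDict(dict):
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def __setattr__(self, key, value):
        self[key] = value


result = AttributeDict(signal=0, grad_norm_dic={}, batch_log_metrics=[])
assert result.signal == 0                 # attribute-style access
assert result['grad_norm_dic'] == {}      # plain dict access still works
result.training_step_output_for_epoch_end = None  # new keys can be set as attributes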