Commit 598f514

refactor training loop (#2336)
* refactoring training epoch
* refactored training epoch
* refactored training epoch
* refactored training epoch
* refactored training epoch
* refactored training epoch
* fixes slurm weights saving
* fixes slurm weights saving
1 parent c09b2ff commit 598f514
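
At a glance: this commit splits the body of run_training_epoch into small named helpers and changes run_training_batch to return a single attribute-accessible result object instead of a 4-tuple. A condensed, non-authoritative sketch of the new per-epoch control flow, trimmed from the diff below (iteration, profiling and edge cases simplified):

def run_training_epoch(self):
    model = self.get_model()
    self.run_on_epoch_start_hook(model)
    train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader)

    epoch_output = []
    for batch_idx, (batch, is_last_batch) in enumerate(train_dataloader):  # simplified iteration
        self.batch_idx = batch_idx
        batch_output = self.run_training_batch(batch, batch_idx)           # TRAINING_STEP + TRAINING_STEP_END
        self.update_train_loop_lr_schedulers()
        early_stop_epoch = batch_output.signal == -1
        self.check_validation_in_train_loop(batch_idx, early_stop_epoch, is_last_batch)
        self.save_loggers_in_training_loop(batch_idx, early_stop_epoch)
        self.save_train_loop_metrics_to_loggers(batch_idx, early_stop_epoch, batch_output)
        self.increment_accumulated_grad_global_step()

    self.sync_horovod()
    self.run_training_epoch_end(epoch_output)
    self.run_on_epoch_end_hook(model)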

File tree: 2 files changed (+131, -93 lines)

pytorch_lightning/trainer/training_loop.py (+115, -73)

@@ -415,11 +415,16 @@ def train(self):
 
         self.run_training_teardown()
 
-    def run_training_epoch(self):
+    def prepare_train_loop_dataloader(self, train_dataloader):
+        # on TPU we have to wrap it under the ParallelLoader
+        if self.use_tpu:
+            device = xm.xla_device(self.tpu_id)
+            train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
+            train_dataloader = train_dataloader.per_device_loader(device)
 
-        # get model
-        model = self.get_model()
+        return train_dataloader
 
+    def run_on_epoch_start_hook(self, model):
         # Epoch start events
         with self.profiler.profile('on_epoch_start'):
             # callbacks
@@ -429,17 +434,19 @@ def run_training_epoch(self):
             if self.is_function_implemented('on_epoch_start'):
                 model.on_epoch_start()
 
-        # track local dataloader so TPU can wrap each epoch
-        train_dataloader = self.train_dataloader
+    def run_training_epoch(self):
 
-        # on TPU we have to wrap it under the ParallelLoader
-        if self.use_tpu:
-            device = xm.xla_device(self.tpu_id)
-            train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
-            train_dataloader = train_dataloader.per_device_loader(device)
+        # get model
+        model = self.get_model()
+
+        # Epoch start events
+        self.run_on_epoch_start_hook(model)
+
+        # modify dataloader if needed (ddp, etc...)
+        train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader)
 
         # bookkeeping
-        outputs = []
+        epoch_output = []
 
         # run epoch
         for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
@@ -450,63 +457,41 @@ def run_training_epoch(self):
                 break
 
             self.batch_idx = batch_idx
-
             model.global_step = self.global_step
 
-            # ---------------
-            # RUN TRAIN STEP
-            # ---------------
-            _outputs = self.run_training_batch(batch, batch_idx)
-            batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs
+            # ------------------------------------
+            # TRAINING_STEP + TRAINING_STEP_END
+            # ------------------------------------
+            batch_output = self.run_training_batch(batch, batch_idx)
 
             # only track outputs when user implements training_epoch_end
             # otherwise we will build up unnecessary memory
             if self.is_overridden('training_epoch_end', model=self.get_model()):
-                outputs.append(batch_output)
+                epoch_output.append(batch_output.training_step_output_for_epoch_end)
+
+            # update LR schedulers
+            self.update_train_loop_lr_schedulers()
 
             # when returning -1 from train_step, we end epoch early
-            early_stop_epoch = batch_result == -1
-
-            # TODO: consolidate all actions that need to take place only after
-            # self.accumulate_grad_batches steps (optimizer step, lr update, global step increment)
-            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
-                # update lr
-                self.update_learning_rates(interval='step')
-
-            # ---------------
-            # RUN VAL STEP
-            # ---------------
-            is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
-            can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-            can_check_val = not self.disable_validation and can_check_epoch
-            should_check_val = is_val_check_batch or early_stop_epoch
-            should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
-            should_check_val = can_check_val and should_check_val
-
-            # ---------------
-            # CHECKPOINTING, EARLY STOPPING
-            # ---------------
-            # fast_dev_run always forces val checking after train batch
-            if self.fast_dev_run or should_check_val:
-                self.run_evaluation(test_mode=self.testing)
-                self.call_checkpoint_callback()
-
-            # when logs should be saved
-            should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
-            if should_save_log or self.fast_dev_run:
-                if self.is_global_zero and self.logger is not None:
-                    self.logger.save()
-
-            # when metrics should be logged
-            should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
-            if should_log_metrics or self.fast_dev_run:
-                # logs user requested information to logger
-                self.log_metrics(batch_step_metrics, grad_norm_dic)
+            early_stop_epoch = batch_output.signal == -1
+
+            # -----------------------------------------
+            # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
+            # -----------------------------------------
+            should_check_val = self.check_validation_in_train_loop(batch_idx, early_stop_epoch, is_last_batch)
+
+            # -----------------------------------------
+            # SAVE LOGGERS (ie: Tensorboard, etc...)
+            # -----------------------------------------
+            self.save_loggers_in_training_loop(batch_idx, early_stop_epoch)
+
+            # -----------------------------------------
+            # SAVE METRICS TO LOGGERS
+            # -----------------------------------------
+            self.save_train_loop_metrics_to_loggers(batch_idx, early_stop_epoch, batch_output)
 
             # progress global step according to grads progress
-            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
-                self.global_step += 1
-            self.total_batch_idx += 1
+            self.increment_accumulated_grad_global_step()
 
             # max steps reached, end training
             if self.max_steps is not None and self.max_steps == self.global_step:
@@ -518,13 +503,36 @@ def run_training_epoch(self):
             if early_stop_epoch or self.fast_dev_run:
                 break
 
-            if self.use_horovod:
-                hvd.join(hvd.local_rank() if self.on_gpu else -1)
+            # let ddp devices catch up when using horovod
+            self.sync_horovod()
 
         # process epoch outputs
+        self.run_training_epoch_end(epoch_output)
+
+        # when no val loop is present or fast-dev-run still need to call checkpoints
+        if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
+            self.call_checkpoint_callback()
+
+        # epoch end hook
+        self.run_on_epoch_end_hook(model)
+
+    def update_train_loop_lr_schedulers(self):
+        if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
+            # update lr
+            self.update_learning_rates(interval='step')
+
+    def run_on_epoch_end_hook(self, model):
+        with self.profiler.profile('on_epoch_end'):
+            # callbacks
+            self.on_epoch_end()
+            # model hooks
+            if self.is_function_implemented('on_epoch_end'):
+                model.on_epoch_end()
+
+    def run_training_epoch_end(self, epoch_output):
         model = self.get_model()
         if self.is_overridden('training_epoch_end', model=model):
-            epoch_output = model.training_epoch_end(outputs)
+            epoch_output = model.training_epoch_end(epoch_output)
         _processed_outputs = self.process_output(epoch_output)
         log_epoch_metrics = _processed_outputs[2]
         callback_epoch_metrics = _processed_outputs[3]
@@ -538,17 +546,45 @@ def run_training_epoch(self):
             # add metrics to progress_bar
             self.add_progress_bar_metrics(_processed_outputs[1])
 
-        # when no val loop is present or fast-dev-run still need to call checkpoints
-        if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
+    def sync_horovod(self):
+        if self.use_horovod:
+            hvd.join(hvd.local_rank() if self.on_gpu else -1)
+
+    def increment_accumulated_grad_global_step(self):
+        # progress global step according to grads progress
+        if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
+            self.global_step += 1
+        self.total_batch_idx += 1
+
+    def save_train_loop_metrics_to_loggers(self, batch_idx, early_stop_epoch, batch_output):
+        # when metrics should be logged
+        should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
+        if should_log_metrics or self.fast_dev_run:
+            # logs user requested information to logger
+            self.log_metrics(batch_output.batch_log_metrics, batch_output.grad_norm_dic)
+
+    def save_loggers_in_training_loop(self, batch_idx, early_stop_epoch):
+        # when loggers should save to disk
+        should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
+        if should_save_log or self.fast_dev_run:
+            if self.is_global_zero and self.logger is not None:
+                self.logger.save()
+
+    def check_validation_in_train_loop(self, batch_idx, early_stop_epoch, is_last_batch):
+        # decide if we should run validation
+        is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
+        can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
+        can_check_val = not self.disable_validation and can_check_epoch
+        should_check_val = is_val_check_batch or early_stop_epoch
+        should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
+        should_check_val = can_check_val and should_check_val
+
+        # if we need to run validation, then also call the checkpoint callback
+        if self.fast_dev_run or should_check_val:
+            self.run_evaluation(test_mode=self.testing)
             self.call_checkpoint_callback()
 
-        # Epoch end events
-        with self.profiler.profile('on_epoch_end'):
-            # callbacks
-            self.on_epoch_end()
-            # model hooks
-            if self.is_function_implemented('on_epoch_end'):
-                model.on_epoch_end()
+        return should_check_val
 
     def run_training_batch(self, batch, batch_idx):
         # track grad norms
@@ -561,7 +597,7 @@ def run_training_batch(self, batch, batch_idx):
         batch_log_metrics = []
 
         if batch is None:
-            return 0, grad_norm_dic, {}, {}
+            return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)
 
         # Batch start events
         with self.profiler.profile('on_batch_start'):
@@ -571,7 +607,7 @@ def run_training_batch(self, batch, batch_idx):
             if self.is_function_implemented('on_batch_start'):
                 response = self.get_model().on_batch_start(batch)
                 if response == -1:
-                    return -1, grad_norm_dic, {}, {}
+                    return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)
 
         splits = [batch]
         if self.truncated_bptt_steps is not None:
@@ -650,7 +686,13 @@ def run_training_batch(self, batch, batch_idx):
         # track all metrics for callbacks
         self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()})
 
-        return 0, grad_norm_dic, batch_log_metrics, opt_closure_result.training_step_output_for_epoch_end
+        result = AttributeDict(
+            signal=0,
+            grad_norm_dic=grad_norm_dic,
+            batch_log_metrics=batch_log_metrics,
+            training_step_output_for_epoch_end=opt_closure_result.training_step_output_for_epoch_end
+        )
+        return result
 
     def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
         # ------------------
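
With this change run_training_batch returns an AttributeDict instead of a positional 4-tuple, so callers read named fields such as out.signal and out.batch_log_metrics (see the updated tests below). A minimal, self-contained sketch of what such a container provides, assuming the usual dict-subclass-with-attribute-access pattern; the real AttributeDict ships with pytorch_lightning and may differ in detail:

class AttributeDict(dict):
    # a plain dict whose keys are also readable/writable as attributes
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def __setattr__(self, key, value):
        self[key] = value


out = AttributeDict(signal=0, grad_norm_dic={}, batch_log_metrics={'loss': 0.5})
assert out.signal == 0                            # attribute access
assert out['batch_log_metrics']['loss'] == 0.5    # plain dict access still works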

tests/trainer/test_trainer_steps.py (+16, -20)

@@ -23,12 +23,11 @@ def test_trainingstep_dict(tmpdir):
         break
 
     out = trainer.run_training_batch(batch, batch_idx)
-    signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
-    assert signal == 0
-    assert all_log_metrics['log_acc1'] == 12.0
-    assert all_log_metrics['log_acc2'] == 7.0
+    assert out.signal == 0
+    assert out.batch_log_metrics['log_acc1'] == 12.0
+    assert out.batch_log_metrics['log_acc2'] == 7.0
 
-    pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
+    pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
     assert pbar_metrics['pbar_acc1'] == 17.0
     assert pbar_metrics['pbar_acc2'] == 19.0
 
@@ -55,12 +54,11 @@ def training_step_with_step_end(tmpdir):
         break
 
     out = trainer.run_training_batch(batch, batch_idx)
-    signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
-    assert signal == 0
-    assert all_log_metrics['log_acc1'] == 12.0
-    assert all_log_metrics['log_acc2'] == 7.0
+    assert out.signal == 0
+    assert out.batch_log_metrics['log_acc1'] == 12.0
+    assert out.batch_log_metrics['log_acc2'] == 7.0
 
-    pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
+    pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
     assert pbar_metrics['pbar_acc1'] == 17.0
     assert pbar_metrics['pbar_acc2'] == 19.0
 
@@ -92,12 +90,11 @@ def test_full_training_loop_dict(tmpdir):
         break
 
     out = trainer.run_training_batch(batch, batch_idx)
-    signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
-    assert signal == 0
-    assert all_log_metrics['log_acc1'] == 12.0
-    assert all_log_metrics['log_acc2'] == 7.0
+    assert out.signal == 0
+    assert out.batch_log_metrics['log_acc1'] == 12.0
+    assert out.batch_log_metrics['log_acc2'] == 7.0
 
-    pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
+    pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
     assert pbar_metrics['pbar_acc1'] == 17.0
     assert pbar_metrics['pbar_acc2'] == 19.0
 
@@ -129,11 +126,10 @@ def test_train_step_epoch_end(tmpdir):
         break
 
     out = trainer.run_training_batch(batch, batch_idx)
-    signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
-    assert signal == 0
-    assert all_log_metrics['log_acc1'] == 12.0
-    assert all_log_metrics['log_acc2'] == 7.0
+    assert out.signal == 0
+    assert out.batch_log_metrics['log_acc1'] == 12.0
+    assert out.batch_log_metrics['log_acc2'] == 7.0
 
-    pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
+    pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
     assert pbar_metrics['pbar_acc1'] == 17.0
     assert pbar_metrics['pbar_acc2'] == 19.0
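
For reference, a toy illustration (not part of this commit) of the bookkeeping the new increment_accumulated_grad_global_step helper performs: with accumulate_grad_batches=2, global_step advances once per accumulation window while total_batch_idx advances on every batch.

class StepCounter:
    # hypothetical stand-in for the trainer attributes used by the helper
    def __init__(self, accumulate_grad_batches):
        self.accumulate_grad_batches = accumulate_grad_batches
        self.batch_idx = 0
        self.global_step = 0
        self.total_batch_idx = 0

    def increment_accumulated_grad_global_step(self):
        # progress global step according to grads progress
        if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
            self.global_step += 1
        self.total_batch_idx += 1


counter = StepCounter(accumulate_grad_batches=2)
for idx in range(4):
    counter.batch_idx = idx
    counter.increment_accumulated_grad_global_step()

assert counter.total_batch_idx == 4   # counted every batch
assert counter.global_step == 2       # advanced only after each accumulation window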
