Commit e2a20f0

jeremyjordan authored and Borda committed
extract training teardown into method, catch KeyboardInterrupt (Lightning-AI#856)
Co-authored-by: Jirka Borovec <[email protected]>
1 parent fb0013b commit e2a20f0

2 files changed, +98 -89 lines
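In short: the epoch loop now runs inside a try/except, and every exit path (normal completion, early stop, Ctrl-C) funnels into a new run_training_teardown() method, which also absorbs the profiler summary previously living in run_pretrain_routine. A minimal, self-contained sketch of that control flow; the class name, attributes and logging setup here are illustrative placeholders, not the actual Trainer code:

import logging as log

log.basicConfig(level=log.INFO)


class LoopSketch:
    """Illustrative stand-in for the trainer's train() loop; not the Lightning code."""

    def __init__(self, max_epochs=3):
        self.current_epoch = 0
        self.max_epochs = max_epochs

    def run_training_epoch(self):
        # placeholder for the real per-epoch work
        pass

    def run_training_teardown(self):
        # single exit point: close the progress bar, call model.on_train_end(),
        # finalize the logger, summarize the profiler
        log.info('running training teardown')

    def train(self):
        try:
            for epoch in range(self.current_epoch, self.max_epochs):
                self.current_epoch = epoch
                self.run_training_epoch()
            self.run_training_teardown()  # normal completion
        except KeyboardInterrupt:         # Ctrl-C during training
            log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
            self.run_training_teardown()


LoopSketch().train()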

pytorch_lightning/trainer/trainer.py (-3 lines)

@@ -1060,9 +1060,6 @@ def run_pretrain_routine(self, model):
         # CORE TRAINING LOOP
         self.train()
 
-        # summarize profile results
-        self.profiler.describe()
-
     def test(self, model=None):
         r"""
 

pytorch_lightning/trainer/training_loop.py (+98 -86 lines)

@@ -155,6 +155,7 @@ def training_step(self, batch, batch_idx):
 import copy
 import warnings
 from abc import ABC, abstractmethod
+import logging as log
 
 import numpy as np
 
@@ -307,98 +308,95 @@ def process_output(self, output, train):
     def train(self):
         warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
                       ' but will start from "0" in v0.8.0.', DeprecationWarning)
+        # get model
         model = self.get_model()
-        # run all epochs
-        for epoch in range(self.current_epoch, self.max_epochs):
-            # set seed for distributed sampler (enables shuffling for each epoch)
-            if (self.use_ddp or self.use_tpu) \
-                    and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
-                self.get_train_dataloader().sampler.set_epoch(epoch)
-
-            # get model
-            model = self.get_model()
-
-            # update training progress in trainer and model
-            model.current_epoch = epoch
-            self.current_epoch = epoch
-
-            total_val_batches = 0
-            is_val_epoch = False
-            if not self.disable_validation:
-                # val can be checked multiple times in epoch
-                is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-                val_checks_per_epoch = self.num_training_batches // self.val_check_batch
-                val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
-                total_val_batches = self.num_val_batches * val_checks_per_epoch
-
-            # total batches includes multiple val checks
-            self.total_batches = self.num_training_batches + total_val_batches
-            self.batch_loss_value = 0  # accumulated grads
-
-            if self.fast_dev_run:
-                # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
-                num_iterations = 2
-            elif self.is_iterable_train_dataloader:
-                # for iterable train loader, the progress bar never ends
-                num_iterations = None
-            else:
-                num_iterations = self.total_batches
-
-            # reset progress bar
-            # .reset() doesn't work on disabled progress bar so we should check
-            if not self.main_progress_bar.disable:
-                self.main_progress_bar.reset(num_iterations)
-            desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
-            self.main_progress_bar.set_description(desc)
-
-            # changing gradient according accumulation_scheduler
-            self.accumulation_scheduler.on_epoch_begin()
-
-            # -----------------
-            # RUN TNG EPOCH
-            # -----------------
-            self.run_training_epoch()
-
-            # update LR schedulers
-            if self.lr_schedulers is not None:
-                for lr_scheduler in self.lr_schedulers:
-                    lr_scheduler.step()
-            if self.reduce_lr_on_plateau_scheduler is not None:
-                val_loss = self.callback_metrics.get('val_loss')
-                if val_loss is None:
-                    avail_metrics = ','.join(list(self.callback_metrics.keys()))
-                    m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
-                        f'which is not available. Available metrics are: {avail_metrics}'
-                    raise MisconfigurationException(m)
-                self.reduce_lr_on_plateau_scheduler.step(val_loss)
-
-            if self.max_steps and self.max_steps == self.global_step:
-                self.main_progress_bar.close()
-                model.on_train_end()
-                return
-
-            # early stopping
-            met_min_epochs = epoch >= self.min_epochs - 1
-            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
-
-            if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
-                    ((met_min_epochs and met_min_steps) or self.fast_dev_run)):
-                should_stop = self.early_stop_callback.on_epoch_end()
-                # stop training
-                stop = should_stop and met_min_epochs
-                if stop:
+        try:
+            # run all epochs
+            for epoch in range(self.current_epoch, self.max_epochs):
+                # set seed for distributed sampler (enables shuffling for each epoch)
+                if (self.use_ddp or self.use_tpu) \
+                        and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
+                    self.get_train_dataloader().sampler.set_epoch(epoch)
+
+                # get model
+                model = self.get_model()
+
+                # update training progress in trainer and model
+                model.current_epoch = epoch
+                self.current_epoch = epoch
+
+                total_val_batches = 0
+                is_val_epoch = False
+                if not self.disable_validation:
+                    # val can be checked multiple times in epoch
+                    is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
+                    val_checks_per_epoch = self.num_training_batches // self.val_check_batch
+                    val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
+                    total_val_batches = self.num_val_batches * val_checks_per_epoch
+
+                # total batches includes multiple val checks
+                self.total_batches = self.num_training_batches + total_val_batches
+                self.batch_loss_value = 0  # accumulated grads
+
+                if self.fast_dev_run:
+                    # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
+                    num_iterations = 2
+                elif self.is_iterable_train_dataloader:
+                    # for iterable train loader, the progress bar never ends
+                    num_iterations = None
+                else:
+                    num_iterations = self.total_batches
+
+                # reset progress bar
+                # .reset() doesn't work on disabled progress bar so we should check
+                if not self.main_progress_bar.disable:
+                    self.main_progress_bar.reset(num_iterations)
+                desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
+                self.main_progress_bar.set_description(desc)
+
+                # changing gradient according accumulation_scheduler
+                self.accumulation_scheduler.on_epoch_begin()
+
+                # -----------------
+                # RUN TNG EPOCH
+                # -----------------
+                self.run_training_epoch()
+
+                # update LR schedulers
+                if self.lr_schedulers is not None:
+                    for lr_scheduler in self.lr_schedulers:
+                        lr_scheduler.step()
+                if self.reduce_lr_on_plateau_scheduler is not None:
+                    val_loss = self.callback_metrics.get('val_loss')
+                    if val_loss is None:
+                        avail_metrics = ','.join(list(self.callback_metrics.keys()))
+                        m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
+                            f'which is not available. Available metrics are: {avail_metrics}'
+                        raise MisconfigurationException(m)
+                    self.reduce_lr_on_plateau_scheduler.step(val_loss)
+
+                if self.max_steps and self.max_steps == self.global_step:
                     self.main_progress_bar.close()
-                    with self.profiler.profile('on_train_end'):
-                        model.on_train_end()
+                    model.on_train_end()
                     return
 
-        self.main_progress_bar.close()
+                # early stopping
+                met_min_epochs = epoch >= self.min_epochs - 1
+                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
 
-        with self.profiler.profile('on_train_end'):
-            model.on_train_end()
+                if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
+                        ((met_min_epochs and met_min_steps) or self.fast_dev_run)):
+                    should_stop = self.early_stop_callback.on_epoch_end()
+                    # stop training
+                    stop = should_stop and met_min_epochs
+                    if stop:
+                        self.run_training_teardown()
+                        return
 
-        if self.logger is not None:
-            self.logger.finalize("success")
+            self.run_training_teardown()
+        except KeyboardInterrupt:
+            log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
+            self.run_training_teardown()
 
     def run_training_epoch(self):
         # before epoch hook
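Worth noting in the hunk above: the early-stopping exit, which previously closed the progress bar and called model.on_train_end() under the profiler, now routes through run_training_teardown(), while the max_steps exit keeps its direct close()/on_train_end() and does not go through the new teardown method.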
@@ -622,6 +620,20 @@ def optimizer_closure():
 
         return 0, grad_norm_dic, all_log_metrics
 
+    def run_training_teardown(self):
+        model = self.get_model()
+
+        self.main_progress_bar.close()
+
+        with self.profiler.profile('on_train_end'):
+            model.on_train_end()
+
+        if self.logger is not None:
+            self.logger.finalize("success")
+
+        # summarize profile results
+        self.profiler.describe()
+
     def training_forward(self, batch, batch_idx, opt_idx, hiddens):
         """
         Handle forward for each training case (distributed, single gpu, etc...)
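
From the user's side, the change shows up when a run is interrupted. A hedged usage sketch; MyModel stands for any LightningModule defined elsewhere and is not part of this commit:

from pytorch_lightning import Trainer

model = MyModel()    # hypothetical LightningModule, defined elsewhere
trainer = Trainer()  # configured as usual

# Pressing Ctrl-C while fit() is training is now caught inside the training loop:
# the trainer logs 'Detected KeyboardInterrupt, attempting graceful shutdown...',
# then calls run_training_teardown(), so the progress bar is closed, on_train_end()
# runs, the logger is finalized and the profiler summary is still printed.
trainer.fit(model)

Catching KeyboardInterrupt, rather than wrapping the loop in try/finally, lets a clean early-stop return and an interrupt share the same teardown path while any other exception still propagates unchanged.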

Comments (0)