@@ -155,6 +155,7 @@ def training_step(self, batch, batch_idx):
 import copy
 import warnings
 from abc import ABC, abstractmethod
+import logging as log

 import numpy as np

@@ -307,98 +308,95 @@ def process_output(self, output, train):
     def train(self):
         warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
                       ' but will start from "0" in v0.8.0.', DeprecationWarning)
+        # get model
         model = self.get_model()
-        # run all epochs
-        for epoch in range(self.current_epoch, self.max_epochs):
-            # set seed for distributed sampler (enables shuffling for each epoch)
-            if (self.use_ddp or self.use_tpu) \
-                    and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
-                self.get_train_dataloader().sampler.set_epoch(epoch)
-
-            # get model
-            model = self.get_model()
-
-            # update training progress in trainer and model
-            model.current_epoch = epoch
-            self.current_epoch = epoch
-
-            total_val_batches = 0
-            is_val_epoch = False
-            if not self.disable_validation:
-                # val can be checked multiple times in epoch
-                is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-                val_checks_per_epoch = self.num_training_batches // self.val_check_batch
-                val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
-                total_val_batches = self.num_val_batches * val_checks_per_epoch
-
-            # total batches includes multiple val checks
-            self.total_batches = self.num_training_batches + total_val_batches
-            self.batch_loss_value = 0  # accumulated grads
-
-            if self.fast_dev_run:
-                # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
-                num_iterations = 2
-            elif self.is_iterable_train_dataloader:
-                # for iterable train loader, the progress bar never ends
-                num_iterations = None
-            else:
-                num_iterations = self.total_batches
-
-            # reset progress bar
-            # .reset() doesn't work on disabled progress bar so we should check
-            if not self.main_progress_bar.disable:
-                self.main_progress_bar.reset(num_iterations)
-            desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
-            self.main_progress_bar.set_description(desc)
-
-            # changing gradient according accumulation_scheduler
-            self.accumulation_scheduler.on_epoch_begin()
-
-            # -----------------
-            # RUN TNG EPOCH
-            # -----------------
-            self.run_training_epoch()
-
-            # update LR schedulers
-            if self.lr_schedulers is not None:
-                for lr_scheduler in self.lr_schedulers:
-                    lr_scheduler.step()
-            if self.reduce_lr_on_plateau_scheduler is not None:
-                val_loss = self.callback_metrics.get('val_loss')
-                if val_loss is None:
-                    avail_metrics = ','.join(list(self.callback_metrics.keys()))
-                    m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
-                        f'which is not available. Available metrics are: {avail_metrics}'
-                    raise MisconfigurationException(m)
-                self.reduce_lr_on_plateau_scheduler.step(val_loss)
-
-            if self.max_steps and self.max_steps == self.global_step:
-                self.main_progress_bar.close()
-                model.on_train_end()
-                return
-
-            # early stopping
-            met_min_epochs = epoch >= self.min_epochs - 1
-            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
-
-            if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
-                    ((met_min_epochs and met_min_steps) or self.fast_dev_run)):
-                should_stop = self.early_stop_callback.on_epoch_end()
-                # stop training
-                stop = should_stop and met_min_epochs
-                if stop:
+        try:
+            # run all epochs
+            for epoch in range(self.current_epoch, self.max_epochs):
+                # set seed for distributed sampler (enables shuffling for each epoch)
+                if (self.use_ddp or self.use_tpu) \
+                        and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
+                    self.get_train_dataloader().sampler.set_epoch(epoch)
+
+                # get model
+                model = self.get_model()
+
+                # update training progress in trainer and model
+                model.current_epoch = epoch
+                self.current_epoch = epoch
+
+                total_val_batches = 0
+                is_val_epoch = False
+                if not self.disable_validation:
+                    # val can be checked multiple times in epoch
+                    is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
+                    val_checks_per_epoch = self.num_training_batches // self.val_check_batch
+                    val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
+                    total_val_batches = self.num_val_batches * val_checks_per_epoch
+
+                # total batches includes multiple val checks
+                self.total_batches = self.num_training_batches + total_val_batches
+                self.batch_loss_value = 0  # accumulated grads
+
+                if self.fast_dev_run:
+                    # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
+                    num_iterations = 2
+                elif self.is_iterable_train_dataloader:
+                    # for iterable train loader, the progress bar never ends
+                    num_iterations = None
+                else:
+                    num_iterations = self.total_batches
+
+                # reset progress bar
+                # .reset() doesn't work on disabled progress bar so we should check
+                if not self.main_progress_bar.disable:
+                    self.main_progress_bar.reset(num_iterations)
+                desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
+                self.main_progress_bar.set_description(desc)
+
+                # changing gradient according accumulation_scheduler
+                self.accumulation_scheduler.on_epoch_begin()
+
+                # -----------------
+                # RUN TNG EPOCH
+                # -----------------
+                self.run_training_epoch()
+
+                # update LR schedulers
+                if self.lr_schedulers is not None:
+                    for lr_scheduler in self.lr_schedulers:
+                        lr_scheduler.step()
+                if self.reduce_lr_on_plateau_scheduler is not None:
+                    val_loss = self.callback_metrics.get('val_loss')
+                    if val_loss is None:
+                        avail_metrics = ','.join(list(self.callback_metrics.keys()))
+                        m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
+                            f'which is not available. Available metrics are: {avail_metrics}'
+                        raise MisconfigurationException(m)
+                    self.reduce_lr_on_plateau_scheduler.step(val_loss)
+
+                if self.max_steps and self.max_steps == self.global_step:
                     self.main_progress_bar.close()
-                    with self.profiler.profile('on_train_end'):
-                        model.on_train_end()
+                    model.on_train_end()
                     return

-        self.main_progress_bar.close()
+                # early stopping
+                met_min_epochs = epoch >= self.min_epochs - 1
+                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

-        with self.profiler.profile('on_train_end'):
-            model.on_train_end()
+                if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
+                        ((met_min_epochs and met_min_steps) or self.fast_dev_run)):
+                    should_stop = self.early_stop_callback.on_epoch_end()
+                    # stop training
+                    stop = should_stop and met_min_epochs
+                    if stop:
+                        self.run_training_teardown()
+                        return

-        if self.logger is not None:
-            self.logger.finalize("success")
+            self.run_training_teardown()
+        except KeyboardInterrupt:
+            log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
+            self.run_training_teardown()

     def run_training_epoch(self):
         # before epoch hook
@@ -622,6 +620,20 @@ def optimizer_closure():

         return 0, grad_norm_dic, all_log_metrics

+    def run_training_teardown(self):
+        model = self.get_model()
+
+        self.main_progress_bar.close()
+
+        with self.profiler.profile('on_train_end'):
+            model.on_train_end()
+
+        if self.logger is not None:
+            self.logger.finalize("success")
+
+        # summarize profile results
+        self.profiler.describe()
+
     def training_forward(self, batch, batch_idx, opt_idx, hiddens):
         """
         Handle forward for each training case (distributed, single gpu, etc...)
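
Taken together, the hunks above wrap the epoch loop in a try/except KeyboardInterrupt and route every exit path (early stopping, max_steps, normal completion, Ctrl-C) through the new run_training_teardown() helper, which closes the progress bar, runs model.on_train_end() under the profiler, finalizes the logger, and summarizes the profiler results. The sketch below is a minimal, self-contained illustration of that shutdown pattern; the MiniTrainer and _should_stop names are hypothetical stand-ins, not Lightning code.

import logging as log


class MiniTrainer:
    """Toy trainer illustrating the graceful-shutdown pattern (assumed names, not Lightning's API)."""

    def __init__(self, max_epochs=3):
        self.max_epochs = max_epochs
        self.current_epoch = 0

    def run_training_teardown(self):
        # single place for all cleanup: progress bar, on_train_end hook, logger, profiler summary
        log.info('running teardown (close progress bar, on_train_end, finalize logger)')

    def _should_stop(self, epoch):
        # placeholder for early stopping / max_steps checks
        return False

    def train(self):
        try:
            for epoch in range(self.current_epoch, self.max_epochs):
                self.current_epoch = epoch
                # ... run one training epoch, step LR schedulers, check early stopping ...
                if self._should_stop(epoch):
                    self.run_training_teardown()
                    return
            # normal completion funnels through the same teardown
            self.run_training_teardown()
        except KeyboardInterrupt:
            # Ctrl-C no longer leaves the progress bar and logger dangling
            log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
            self.run_training_teardown()


if __name__ == '__main__':
    log.basicConfig(level=log.INFO)
    MiniTrainer().train()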