Skip to content

Commit 5035ce5

Browse files
authored
Make default tqdm dict overridable (#749)
* overridable tqdm_dict * Slim down default tqdm_metrics * gpu fix
1 parent 734b28e commit 5035ce5

File tree

3 files changed

+25
-21
lines changed

3 files changed

+25
-21
lines changed

pytorch_lightning/core/lightning.py

+19
Original file line numberDiff line numberDiff line change
@@ -1210,6 +1210,25 @@ def on_save_checkpoint(self, checkpoint):
12101210
12111211
"""
12121212

1213+
def get_tqdm_dict(self):
1214+
r"""
1215+
Additional items to be displayed in the progress bar.
1216+
1217+
Return:
1218+
Dictionary with the items to be displayed in the progress bar.
1219+
"""
1220+
tqdm_dict = {
1221+
'loss': '{:.3f}'.format(self.trainer.avg_loss)
1222+
}
1223+
1224+
if self.trainer.truncated_bptt_steps is not None:
1225+
tqdm_dict['split_idx'] = self.trainer.split_idx
1226+
1227+
if self.trainer.logger is not None and self.trainer.logger.version is not None:
1228+
tqdm_dict['v_num'] = self.trainer.logger.version
1229+
1230+
return tqdm_dict
1231+
12131232

12141233
def load_hparams_from_tags_csv(tags_csv):
12151234
if not os.path.isfile(tags_csv):

pytorch_lightning/trainer/evaluation_loop.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def run_evaluation(self, test=False):
295295
desc = 'Testing' if test else 'Validating'
296296
pbar = tqdm(desc=desc, total=max_batches, leave=test, position=position,
297297
disable=not self.show_progress_bar, dynamic_ncols=True,
298-
unit='batch', file=sys.stdout)
298+
file=sys.stdout)
299299
setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
300300

301301
# run evaluation
@@ -319,9 +319,8 @@ def run_evaluation(self, test=False):
319319
model.on_post_performance_check()
320320

321321
# add model specific metrics
322-
tqdm_metrics = self.training_tqdm_dict
323322
if not test:
324-
self.main_progress_bar.set_postfix(**tqdm_metrics)
323+
self.main_progress_bar.set_postfix(**self.training_tqdm_dict)
325324

326325
# close progress bar
327326
if test:

pytorch_lightning/trainer/trainer.py

+4-18
Original file line numberDiff line numberDiff line change
@@ -681,23 +681,9 @@ def training_tqdm_dict(self):
681681
"""Read-only for tqdm metrics.
682682
:return:
683683
"""
684-
tqdm_dict = {
685-
'loss': '{0:.3f}'.format(self.avg_loss),
686-
'batch_idx': '{}'.format(self.batch_idx),
687-
}
684+
ref_model = self.model if not self.data_parallel else self.model.module
688685

689-
if self.truncated_bptt_steps is not None:
690-
tqdm_dict['split_idx'] = self.split_idx
691-
692-
if self.logger is not None and self.logger.version is not None:
693-
tqdm_dict['v_num'] = self.logger.version
694-
695-
tqdm_dict.update(self.tqdm_metrics)
696-
697-
if self.on_gpu:
698-
tqdm_dict['gpu'] = '{}'.format(torch.cuda.current_device())
699-
700-
return tqdm_dict
686+
return dict(**ref_model.get_tqdm_dict(), **self.tqdm_metrics)
701687

702688
@property
703689
def tng_tqdm_dic(self):
@@ -855,7 +841,7 @@ def run_pretrain_routine(self, model):
855841
pbar = tqdm(desc='Validation sanity check',
856842
total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
857843
leave=False, position=2 * self.process_position,
858-
disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
844+
disable=not self.show_progress_bar, dynamic_ncols=True)
859845
self.main_progress_bar = pbar
860846
# dummy validation progress bar
861847
self.val_progress_bar = tqdm(disable=True)
@@ -873,7 +859,7 @@ def run_pretrain_routine(self, model):
873859

874860
# init progress bar
875861
pbar = tqdm(leave=True, position=2 * self.process_position,
876-
disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
862+
disable=not self.show_progress_bar, dynamic_ncols=True,
877863
file=sys.stdout)
878864
self.main_progress_bar = pbar
879865

0 commit comments

Comments
 (0)