Skip to content

Commit be80603

Browse files
author
gbkh2015
committed
move show_progress_bar to deprecated 0.9 api
1 parent cc48912 commit be80603

File tree

4 files changed

+30
-13
lines changed

4 files changed

+30
-13
lines changed

pytorch_lightning/trainer/deprecated_api.py

+20
Original file line number | Diff line number | Diff line change
@@ -87,3 +87,23 @@ def nb_sanity_val_steps(self, nb):
8787
"`num_sanity_val_steps` since v0.5.0"
8888
" and this method will be removed in v0.8.0", DeprecationWarning)
8989
self.num_sanity_val_steps = nb
90+
91+
92+
class TrainerDeprecatedAPITillVer0_9(ABC):
    """Mixin holding Trainer APIs deprecated until v0.9.0.

    Exposes the legacy ``show_progress_bar`` flag as a property that is
    backed by its replacement, ``progress_bar_refresh_rate``.
    """

    def __init__(self):
        super().__init__()  # mixin calls super too

    @property
    def show_progress_bar(self):
        """Back compatibility, will be removed in v0.9.0"""
        warnings.warn("Argument `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.2"
                      " and this method will be removed in v0.9.0", DeprecationWarning)
        # The bar is considered "shown" whenever the refresh rate is positive.
        return self.progress_bar_refresh_rate >= 1

    @show_progress_bar.setter
    def show_progress_bar(self, tf):
        """Back compatibility, will be removed in v0.9.0"""
        warnings.warn("Argument `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.2"
                      " and this method will be removed in v0.9.0", DeprecationWarning)
        # BUG FIX: the original body did `self.show_progress_bar = tf`, which
        # re-invokes this very setter and recurses until RecursionError the
        # first time anything assigns the attribute. Map the boolean onto the
        # replacement attribute instead, mirroring the getter above.
        # NOTE(review): a truthy value forces the refresh rate to 1; if a
        # caller may have set a higher rate beforehand, consider preserving
        # it — confirm against Trainer.__init__ ordering.
        self.progress_bar_refresh_rate = 1 if tf else 0

pytorch_lightning/trainer/evaluation_loop.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -278,7 +278,7 @@ def evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_m
278278
dl_outputs.append(output)
279279

280280
# batch done
281-
if self.show_progress_bar and batch_idx % self.progress_bar_refresh_rate == 0:
281+
if self.progress_bar_refresh_rate >= 1 and batch_idx % self.progress_bar_refresh_rate == 0:
282282
if test_mode:
283283
self.test_progress_bar.update(self.progress_bar_refresh_rate)
284284
else:

pytorch_lightning/trainer/trainer.py

+8-11
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,8 @@
2424
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
2525
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
2626
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
27-
from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8
27+
from pytorch_lightning.trainer.deprecated_api import (TrainerDeprecatedAPITillVer0_8,
28+
TrainerDeprecatedAPITillVer0_9)
2829
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
2930
from pytorch_lightning.trainer.distrib_parts import TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device
3031
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
@@ -66,12 +67,13 @@ class Trainer(
6667
TrainerCallbackConfigMixin,
6768
TrainerCallbackHookMixin,
6869
TrainerDeprecatedAPITillVer0_8,
70+
TrainerDeprecatedAPITillVer0_9,
6971
):
7072
DEPRECATED_IN_0_8 = (
7173
'gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs', 'min_nb_epochs',
7274
'add_row_log_interval', 'nb_sanity_val_steps'
7375
)
74-
DEPRECATED_IN_0_9 = ('use_amp',)
76+
DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar')
7577

7678
def __init__(
7779
self,
@@ -88,7 +90,7 @@ def __init__(
8890
gpus: Optional[Union[List[int], str, int]] = None,
8991
num_tpu_cores: Optional[int] = None,
9092
log_gpu_memory: Optional[str] = None,
91-
show_progress_bar=None, # backward compatible, todo: remove in v0.8.0
93+
show_progress_bar=None, # backward compatible, todo: remove in v0.9.0
9294
progress_bar_refresh_rate: int = 1,
9395
overfit_pct: float = 0.0,
9496
track_grad_norm: int = -1,
@@ -416,12 +418,11 @@ def __init__(
416418
# nvidia setup
417419
self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)
418420

419-
# Backward compatibility, TODO: remove in v0.8.0
420-
if show_progress_bar is not None:
421-
warnings.warn("Argument `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.1"
422-
" and this method will be removed in v0.8.0", DeprecationWarning)
423421
# can't init progress bar here because starting a new process
424422
# means the progress_bar won't survive pickling
423+
# backward compatibility
424+
if show_progress_bar is not None:
425+
self.show_progress_bar = show_progress_bar
425426

426427
# logging
427428
self.log_save_interval = log_save_interval
@@ -567,10 +568,6 @@ def from_argparse_args(cls, args):
567568
params = vars(args)
568569
return cls(**params)
569570

570-
@property
571-
def show_progress_bar(self) -> bool:
572-
return self.progress_bar_refresh_rate >= 1
573-
574571
@property
575572
def num_gpus(self) -> int:
576573
gpus = self.data_parallel_device_ids

pytorch_lightning/trainer/training_loop.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -606,7 +606,7 @@ def optimizer_closure():
606606
self.get_model().on_batch_end()
607607

608608
# update progress bar
609-
if self.show_progress_bar and batch_idx % self.progress_bar_refresh_rate == 0:
609+
if self.progress_bar_refresh_rate >= 1 and batch_idx % self.progress_bar_refresh_rate == 0:
610610
self.main_progress_bar.update(self.progress_bar_refresh_rate)
611611
self.main_progress_bar.set_postfix(**self.training_tqdm_dict)
612612

0 commit comments

Comments
 (0)