"""
Progress Bars
=============
Use or override one of the progress bar callbacks.
"""
import importlib
import sys
# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
if importlib.util.find_spec('ipywidgets') is not None:
from tqdm.auto import tqdm
else:
from tqdm import tqdm
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.data import has_len
class ProgressBarBase(Callback):
    r"""
    The base class for progress bars in Lightning. It is a :class:`~pytorch_lightning.callbacks.Callback`
    that keeps track of the batch progress in the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
    Use this as the base class when implementing a highly customized progress bar.

    Example::

        class LitProgressBar(ProgressBarBase):

            def __init__(self):
                super().__init__()  # don't forget this :)
                self._enabled = True

            def enable(self):
                self._enabled = True

            def disable(self):
                self._enabled = False

            def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
                super().on_train_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)  # don't forget this :)
                percent = (self.train_batch_idx / self.total_train_batches) * 100
                sys.stdout.flush()
                sys.stdout.write(f'{percent:.01f} percent complete \r')

        bar = LitProgressBar()
        trainer = Trainer(callbacks=[bar])

    """
    def __init__(self):
        self._trainer = None
        self._train_batch_idx = 0
        self._val_batch_idx = 0
        self._test_batch_idx = 0

    @property
    def trainer(self):
        return self._trainer

    @property
    def train_batch_idx(self) -> int:
        """
        The current batch index being processed during training.
        Use this to update your progress bar.
        """
        return self._train_batch_idx

    @property
    def val_batch_idx(self) -> int:
        """
        The current batch index being processed during validation.
        Use this to update your progress bar.
        """
        return self._val_batch_idx

    @property
    def test_batch_idx(self) -> int:
        """
        The current batch index being processed during testing.
        Use this to update your progress bar.
        """
        return self._test_batch_idx
    @property
    def total_train_batches(self) -> int:
        """
        The total number of training batches, which may change from epoch to epoch.
        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
        training dataloader is of infinite size.
        """
        return self.trainer.num_training_batches

    @property
    def total_val_batches(self) -> int:
        """
        The total number of validation batches, which may change from epoch to epoch.
        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
        validation dataloader is of infinite size.
        """
        total_val_batches = 0
        if not self.trainer.disable_validation:
            is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
            total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0
        return total_val_batches

    @property
    def total_test_batches(self) -> int:
        """
        The total number of test batches, which may change from epoch to epoch.
        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
        test dataloader is of infinite size.
        """
        return sum(self.trainer.num_test_batches)
    def disable(self):
        """
        You should provide a way to disable the progress bar.
        The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this to disable the
        output on processes that have a rank different from 0, e.g., in multi-node training.
        """
        raise NotImplementedError

    def enable(self):
        """
        You should provide a way to enable the progress bar.
        The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training
        routines like the `learning rate finder <lr_finder.rst>`_ to temporarily enable and
        disable the main progress bar.
        """
        raise NotImplementedError

    def on_init_end(self, trainer):
        self._trainer = trainer

    def on_train_start(self, trainer, pl_module):
        self._train_batch_idx = trainer.batch_idx

    def on_epoch_start(self, trainer, pl_module):
        self._train_batch_idx = 0

    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._train_batch_idx += 1

    def on_validation_start(self, trainer, pl_module):
        self._val_batch_idx = 0

    def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._val_batch_idx += 1

    def on_test_start(self, trainer, pl_module):
        self._test_batch_idx = 0

    def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._test_batch_idx += 1
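
# The counters and totals maintained by ProgressBarBase are all a custom bar
# needs for its own rendering. A minimal sketch (illustrative only, extending
# the docstring example above) that guards against infinite loaders, for which
# ``total_train_batches`` is ``float('inf')``:
#
#     def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
#         super().on_train_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
#         total = self.total_train_batches
#         if total != float('inf'):
#             sys.stdout.write(f'{self.train_batch_idx / total:.1%} complete\r')
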
class ProgressBar(ProgressBarBase):
    r"""
    This is the default progress bar used by Lightning. It prints to `stdout` using the
    :mod:`tqdm` package and shows up to four different bars:

    - **sanity check progress:** the progress during the sanity check run
    - **main progress:** shows training + validation progress combined. It also accounts for
      multiple validation runs during training when
      :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used.
    - **validation progress:** only visible during validation;
      shows total progress over all validation datasets.
    - **test progress:** only active when testing; shows total progress over all test datasets.

    For infinite datasets, the progress bar never ends.

    If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override
    specific methods of the callback class and pass your custom implementation to the
    :class:`~pytorch_lightning.trainer.trainer.Trainer`:

    Example::

        class LitProgressBar(ProgressBar):

            def init_validation_tqdm(self):
                bar = super().init_validation_tqdm()
                bar.set_description('running validation ...')
                return bar

        bar = LitProgressBar()
        trainer = Trainer(callbacks=[bar])

    Args:
        refresh_rate:
            Determines at which rate (in number of batches) the progress bars get updated.
            Set it to ``0`` to disable the display. By default, the
            :class:`~pytorch_lightning.trainer.trainer.Trainer` uses this implementation of the progress
            bar and sets the refresh rate to the value provided to the
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the
            :class:`~pytorch_lightning.trainer.trainer.Trainer`.
        process_position:
            Set this to a value greater than ``0`` to offset the progress bars by this many lines.
            This is useful when you have progress bars defined elsewhere and want to show all of them
            together. This corresponds to
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the
            :class:`~pytorch_lightning.trainer.trainer.Trainer`.

    """
    def __init__(self, refresh_rate: int = 1, process_position: int = 0):
        super().__init__()
        self._refresh_rate = refresh_rate
        self._process_position = process_position
        self._enabled = True
        self.main_progress_bar = None
        self.val_progress_bar = None
        self.test_progress_bar = None

    def __getstate__(self):
        # can't pickle the tqdm objects
        state = self.__dict__.copy()
        state['main_progress_bar'] = None
        state['val_progress_bar'] = None
        state['test_progress_bar'] = None
        return state
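
    # Note (assumed rationale): ``tqdm`` instances hold references to streams
    # and locks that cannot be pickled, so ``__getstate__`` drops them. A
    # restored copy therefore starts with no open bars; they are re-created in
    # the ``on_*_start`` hooks below. Sketch:
    #
    #     import pickle
    #     bar = ProgressBar(refresh_rate=5)
    #     restored = pickle.loads(pickle.dumps(bar))
    #     assert restored.refresh_rate == 5
    #     assert restored.main_progress_bar is None
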
    @property
    def refresh_rate(self) -> int:
        return self._refresh_rate

    @property
    def process_position(self) -> int:
        return self._process_position

    @property
    def is_enabled(self) -> bool:
        return self._enabled and self.refresh_rate > 0

    @property
    def is_disabled(self) -> bool:
        return not self.is_enabled

    def disable(self) -> None:
        self._enabled = False

    def enable(self) -> None:
        self._enabled = True
    def init_sanity_tqdm(self) -> tqdm:
        """ Override this to customize the tqdm bar for the validation sanity run. """
        bar = tqdm(
            desc='Validation sanity check',
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def init_train_tqdm(self) -> tqdm:
        """ Override this to customize the tqdm bar for training. """
        bar = tqdm(
            desc='Training',
            initial=self.train_batch_idx,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
        )
        return bar

    def init_validation_tqdm(self) -> tqdm:
        """ Override this to customize the tqdm bar for validation. """
        bar = tqdm(
            desc='Validating',
            position=(2 * self.process_position + 1),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def init_test_tqdm(self) -> tqdm:
        """ Override this to customize the tqdm bar for testing. """
        bar = tqdm(
            desc='Testing',
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar
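
    # Layout note: with ``process_position=p`` the main, sanity, and test bars
    # sit on terminal row ``2 * p`` and the validation bar on row ``2 * p + 1``,
    # so the two bars that can be visible at the same time never overwrite each
    # other. E.g. (illustrative): two scripts sharing one terminal could pass
    # ``process_position=0`` and ``process_position=1`` to stack their bars.
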
    def on_sanity_check_start(self, trainer, pl_module):
        super().on_sanity_check_start(trainer, pl_module)
        self.val_progress_bar = self.init_sanity_tqdm()
        self.val_progress_bar.total = sum(
            min(trainer.num_sanity_val_steps, len(d) if has_len(d) else float('inf'))
            for d in trainer.val_dataloaders
        )
        self.main_progress_bar = tqdm(disable=True)  # dummy progress bar

    def on_sanity_check_end(self, trainer, pl_module):
        super().on_sanity_check_end(trainer, pl_module)
        self.main_progress_bar.close()
        self.val_progress_bar.close()

    def on_train_start(self, trainer, pl_module):
        super().on_train_start(trainer, pl_module)
        self.main_progress_bar = self.init_train_tqdm()

    def on_epoch_start(self, trainer, pl_module):
        super().on_epoch_start(trainer, pl_module)
        total_train_batches = self.total_train_batches
        total_val_batches = self.total_val_batches
        if total_train_batches != float('inf') and not trainer.fast_dev_run:
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            total_val_batches = total_val_batches * val_checks_per_epoch
        total_batches = total_train_batches + total_val_batches
        if not self.main_progress_bar.disable:
            self.main_progress_bar.reset(convert_inf(total_batches))
        self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}')

    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        super().on_train_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
        if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0:
            self.main_progress_bar.update(self.refresh_rate)
            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

    def on_validation_start(self, trainer, pl_module):
        super().on_validation_start(trainer, pl_module)
        self.val_progress_bar = self.init_validation_tqdm()
        self.val_progress_bar.total = convert_inf(self.total_val_batches)

    def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        super().on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
        if self.is_enabled and self.val_batch_idx % self.refresh_rate == 0:
            self.val_progress_bar.update(self.refresh_rate)
            self.main_progress_bar.update(self.refresh_rate)

    def on_validation_end(self, trainer, pl_module):
        super().on_validation_end(trainer, pl_module)
        self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
        self.val_progress_bar.close()

    def on_train_end(self, trainer, pl_module):
        super().on_train_end(trainer, pl_module)
        self.main_progress_bar.close()

    def on_test_start(self, trainer, pl_module):
        super().on_test_start(trainer, pl_module)
        self.test_progress_bar = self.init_test_tqdm()
        self.test_progress_bar.total = convert_inf(self.total_test_batches)

    def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        super().on_test_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
        if self.is_enabled and self.test_batch_idx % self.refresh_rate == 0:
            self.test_progress_bar.update(self.refresh_rate)

    def on_test_end(self, trainer, pl_module):
        super().on_test_end(trainer, pl_module)
        self.test_progress_bar.close()
def convert_inf(x):
    """ ``tqdm`` doesn't support ``inf`` values, so convert them to ``None``. """
    if x == float('inf'):
        return None
    return x
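
# A minimal sketch of plugging a customized bar into the Trainer (``LitModel``
# is a hypothetical LightningModule defined elsewhere; the override shown is
# illustrative, not part of this module):
#
#     from pytorch_lightning import Trainer
#
#     class SlowRefreshBar(ProgressBar):
#         def init_train_tqdm(self):
#             bar = super().init_train_tqdm()
#             bar.set_description('fitting')
#             return bar
#
#     trainer = Trainer(callbacks=[SlowRefreshBar(refresh_rate=20)])
#     trainer.fit(LitModel())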