
Commit 68e3b3f

Borda authored and akarnachev committed
Profiler summary (Lightning-AI#1259)
* refactor and add types
* add Prorfiler summary
* fix imports
* Revert "refactor and add types"
  This reverts commit b4c552f
* changelog
* revert rename
* fix test
* mute verbose
1 parent b05a209 commit 68e3b3f

20 files changed: +113 -59 lines changed

CHANGELOG.md (+2 -1)

@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
 - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
 - Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
+- Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259))
 - Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))
 - Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))

@@ -74,7 +75,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
 - Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
 - Added support for step-based learning rate scheduling ([#941](https://github.com/PyTorchLightning/pytorch-lightning/pull/941))
-- Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
+- Added support for logging `hparams` as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
 - Checkpoint and early stopping now work without val. step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041))
 - Support graceful training cleanup after Keyboard Interrupt ([#856](https://github.com/PyTorchLightning/pytorch-lightning/pull/856), [#1019](https://github.com/PyTorchLightning/pytorch-lightning/pull/1019))
 - Added type hints for function arguments ([#912](https://github.com/PyTorchLightning/pytorch-lightning/pull/912), )

pytorch_lightning/core/lightning.py (+1 -1)

@@ -20,7 +20,7 @@
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 try:
     import torch_xla.core.xla_model as xm

pytorch_lightning/loggers/comet.py (+1 -1)

@@ -28,7 +28,7 @@

 from pytorch_lightning import _logger as log
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_only
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class CometLogger(LightningLoggerBase):

pytorch_lightning/profiler/__init__.py (+5 -4)

@@ -3,7 +3,7 @@


 Built-in checks
----------------
+---------------

 PyTorch Lightning supports profiling standard actions in the training loop out of the box, including:

@@ -20,7 +20,7 @@
 - on_training_end

 Enable simple profiling
--------------------------
+-----------------------

 If you only wish to profile the standard actions, you can set `profiler=True` when constructing
 your `Trainer` object.
@@ -113,10 +113,11 @@ def custom_processing_step(self, data):

 """

-from pytorch_lightning.profiler.profiler import Profiler, AdvancedProfiler, PassThroughProfiler
+from pytorch_lightning.profiler.profilers import SimpleProfiler, AdvancedProfiler, PassThroughProfiler, BaseProfiler

 __all__ = [
-    'Profiler',
+    'BaseProfiler',
+    'SimpleProfiler',
     'AdvancedProfiler',
     'PassThroughProfiler',
 ]
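
As the docstring above says, simple profiling is enabled by setting `profiler=True` when constructing the `Trainer`. A minimal usage sketch (not part of this diff; `MyModel` is a placeholder for any `LightningModule`):

    from pytorch_lightning import Trainer

    # profiler=True now resolves to the renamed SimpleProfiler (see trainer.py below)
    trainer = Trainer(profiler=True)
    trainer.fit(MyModel())  # a profiler report is emitted when training finishes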

pytorch_lightning/profiler/profiler.py → pytorch_lightning/profiler/profilers.py (+79 -25)

@@ -1,5 +1,6 @@
 import cProfile
 import io
+import os
 import pstats
 import time
 from abc import ABC, abstractmethod
@@ -16,6 +17,18 @@ class BaseProfiler(ABC):
     If you wish to write a custom profiler, you should inhereit from this class.
     """

+    def __init__(self, output_streams: list = None):
+        """
+        Params:
+            stream_out: callable
+        """
+        if output_streams:
+            if not isinstance(output_streams, (list, tuple)):
+                output_streams = [output_streams]
+        else:
+            output_streams = []
+        self.write_streams = output_streams
+
     @abstractmethod
     def start(self, action_name: str) -> None:
         """Defines how to start recording an action."""
@@ -57,7 +70,12 @@ def profile_iterable(self, iterable, action_name: str) -> None:

     def describe(self) -> None:
         """Logs a profile report after the conclusion of the training run."""
-        pass
+        for write in self.write_streams:
+            write(self.summary())
+
+    @abstractmethod
+    def summary(self) -> str:
+        """Create profiler summary in text format."""


 class PassThroughProfiler(BaseProfiler):
@@ -67,25 +85,39 @@ class PassThroughProfiler(BaseProfiler):
     """

     def __init__(self):
-        pass
+        super().__init__(output_streams=None)

     def start(self, action_name: str) -> None:
         pass

     def stop(self, action_name: str) -> None:
         pass

+    def summary(self) -> str:
+        return ""
+

-class Profiler(BaseProfiler):
+class SimpleProfiler(BaseProfiler):
     """
     This profiler simply records the duration of actions (in seconds) and reports
     the mean duration of each action and the total time spent over the entire training run.
     """

-    def __init__(self):
+    def __init__(self, output_filename: str = None):
+        """
+        Params:
+            output_filename (str): optionally save profile results to file instead of printing
+                to std out when training is finished.
+        """
         self.current_actions = {}
         self.recorded_durations = defaultdict(list)

+        self.output_fname = output_filename
+        self.output_file = open(self.output_fname, 'w') if self.output_fname else None
+
+        streaming_out = [self.output_file.write] if self.output_file else [log.info]
+        super().__init__(output_streams=streaming_out)
+
     def start(self, action_name: str) -> None:
         if action_name in self.current_actions:
             raise ValueError(
@@ -103,20 +135,31 @@ def stop(self, action_name: str) -> None:
         duration = end_time - start_time
         self.recorded_durations[action_name].append(duration)

-    def describe(self) -> None:
+    def summary(self) -> str:
         output_string = "\n\nProfiler Report\n"

         def log_row(action, mean, total):
-            return f"\n{action:<20s}\t| {mean:<15}\t| {total:<15}"
+            return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}"

         output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
-        output_string += f"\n{'-' * 65}"
+        output_string += f"{os.linesep}{'-' * 65}"
         for action, durations in self.recorded_durations.items():
             output_string += log_row(
                 action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}",
             )
-        output_string += "\n"
-        log.info(output_string)
+        output_string += os.linesep
+        return output_string
+
+    def describe(self):
+        """Logs a profile report after the conclusion of the training run."""
+        super().describe()
+        if self.output_file:
+            self.output_file.flush()
+
+    def __del__(self):
+        """Close profiler's stream."""
+        if self.output_file:
+            self.output_file.close()


 class AdvancedProfiler(BaseProfiler):
@@ -136,9 +179,14 @@ def __init__(self, output_filename: str = None, line_count_restriction: float =
         or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
         """
         self.profiled_actions = {}
-        self.output_filename = output_filename
         self.line_count_restriction = line_count_restriction

+        self.output_fname = output_filename
+        self.output_file = open(self.output_fname, 'w') if self.output_fname else None
+
+        streaming_out = [self.output_file.write] if self.output_file else [log.info]
+        super().__init__(output_streams=streaming_out)
+
     def start(self, action_name: str) -> None:
         if action_name not in self.profiled_actions:
             self.profiled_actions[action_name] = cProfile.Profile()
@@ -152,22 +200,28 @@ def stop(self, action_name: str) -> None:
         )
         pr.disable()

-    def describe(self) -> None:
-        self.recorded_stats = {}
+    def summary(self) -> str:
+        recorded_stats = {}
         for action_name, pr in self.profiled_actions.items():
             s = io.StringIO()
             ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative')
             ps.print_stats(self.line_count_restriction)
-            self.recorded_stats[action_name] = s.getvalue()
-        if self.output_filename is not None:
-            # save to file
-            with open(self.output_filename, "w") as f:
-                for action, stats in self.recorded_stats.items():
-                    f.write(f"Profile stats for: {action}")
-                    f.write(stats)
-        else:
-            # log to standard out
-            output_string = "\nProfiler Report\n"
-            for action, stats in self.recorded_stats.items():
-                output_string += f"\nProfile stats for: {action}\n{stats}"
-            log.info(output_string)
+            recorded_stats[action_name] = s.getvalue()
+
+        # log to standard out
+        output_string = f"{os.linesep}Profiler Report{os.linesep}"
+        for action, stats in recorded_stats.items():
+            output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}"
+
+        return output_string
+
+    def describe(self):
+        """Logs a profile report after the conclusion of the training run."""
+        super().describe()
+        if self.output_file:
+            self.output_file.flush()
+
+    def __del__(self):
+        """Close profiler's stream."""
+        if self.output_file:
+            self.output_file.close()
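
With this refactor, `BaseProfiler.describe()` writes the text returned by `summary()` to every configured output stream, so a custom profiler only needs to implement `start`, `stop`, and `summary`. A minimal sketch under that contract (`LastDurationProfiler` is hypothetical, not part of this commit):

    import time

    from pytorch_lightning.profiler import BaseProfiler

    class LastDurationProfiler(BaseProfiler):
        """Toy profiler reporting only the most recent duration of each action."""

        def __init__(self):
            self._start_times = {}
            self._last_durations = {}
            # the inherited describe() will call print(self.summary())
            super().__init__(output_streams=[print])

        def start(self, action_name: str) -> None:
            self._start_times[action_name] = time.monotonic()

        def stop(self, action_name: str) -> None:
            started = self._start_times.pop(action_name)
            self._last_durations[action_name] = time.monotonic() - started

        def summary(self) -> str:
            rows = [f"{name}: {duration:.5f} s" for name, duration in self._last_durations.items()]
            return "\n".join(["Profiler Report"] + rows)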

pytorch_lightning/trainer/data_loading.py (+1 -1)

@@ -6,7 +6,7 @@
 from torch.utils.data.distributed import DistributedSampler

 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 try:
     from apex import amp

pytorch_lightning/trainer/distrib_data_parallel.py (+1 -1)

@@ -122,7 +122,7 @@ def train_fx(trial_hparams, cluster_manager, _):
 import torch
 from pytorch_lightning import _logger as log
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 try:
     from apex import amp

pytorch_lightning/trainer/distrib_parts.py (+1 -1)

@@ -344,7 +344,7 @@
     LightningDistributedDataParallel,
     LightningDataParallel,
 )
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 try:
     from apex import amp

pytorch_lightning/trainer/evaluation_loop.py (+1 -1)

@@ -135,7 +135,7 @@

 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 try:
     import torch_xla.distributed.parallel_loader as xla_pl

pytorch_lightning/trainer/trainer.py (+5 -6)

@@ -18,8 +18,7 @@
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.profiler import Profiler, PassThroughProfiler
-from pytorch_lightning.profiler.profiler import BaseProfiler
+from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler
 from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
 from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
@@ -33,7 +32,7 @@
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.supporters import TensorRunningMean

 try:
@@ -364,7 +363,7 @@ def __init__(

         # configure profiler
         if profiler is True:
-            profiler = Profiler()
+            profiler = SimpleProfiler()
         self.profiler = profiler or PassThroughProfiler()

         # configure early stop callback
@@ -490,10 +489,10 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
         ('print_nan_grads', (<class 'bool'>,), False),
         ('process_position', (<class 'int'>,), 0),
         ('profiler',
-         (<class 'pytorch_lightning.profiler.profiler.BaseProfiler'>,
+         (<class 'pytorch_lightning.profiler.profilers.BaseProfiler'>,
          <class 'NoneType'>),
          None),
-        ...
+         ...
         """
         trainer_default_params = inspect.signature(cls).parameters
         name_type_default = []
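
Concretely, `profiler=True` now constructs a `SimpleProfiler`, and any `BaseProfiler` instance can still be passed explicitly. A brief sketch of both forms (the output file name is illustrative):

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler

    trainer = Trainer(profiler=True)  # equivalent to profiler=SimpleProfiler()

    # route the summary() report to a file instead of log.info
    profiler = AdvancedProfiler(output_filename="profiler_report.txt")
    trainer = Trainer(profiler=profiler)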

pytorch_lightning/trainer/training_loop.py (+1 -1)

@@ -145,7 +145,7 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.supporters import TensorRunningMean

 try:

tests/loggers/test_comet.py (+1 -1)

@@ -8,7 +8,7 @@
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import CometLogger
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import LightningTestModel

tests/models/test_amp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import tests.base.utils as tutils
77
from pytorch_lightning import Trainer
8-
from pytorch_lightning.utilities.debugging import MisconfigurationException
8+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
99
from tests.base import (
1010
LightningTestModel,
1111
)

tests/models/test_gpu.py (+1 -1)

@@ -11,7 +11,7 @@
     parse_gpu_ids,
     determine_root_gpu_device,
 )
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import LightningTestModel

 PRETEND_N_OF_GPUS = 16

tests/models/test_restore.py (+1 -1)

@@ -8,7 +8,7 @@
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import (
     LightningTestModel,
     LightningTestModelWithoutHyperparametersArg,

tests/test_deprecated.py (+2)

@@ -57,6 +57,8 @@ def test_tbd_remove_in_v0_9_0_module_imports():
     from pytorch_lightning.logging.test_tube import TestTubeLogger  # noqa: F402
     from pytorch_lightning.logging.wandb import WandbLogger  # noqa: F402

+    from pytorch_lightning.profiler import SimpleProfiler, AdvancedProfiler  # noqa: F402
+

 class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase):
