Commit 1cf430f
new feature for profiling training runs (#782)
* initial implementation
* formatting, pass through profiler, docstring
* call profiler during training
* add initial tests
* report stats when training is done
* fix formatting
* error handling, bugfix in passthroughprofiler
* finish documenting profiler arg in Trainer
* relax required precision for profiling tests
* option to dump cProfiler results to text file
* use logging, format with black
* include profiler in docs
* improved logging and better docs
* appease the linter
* better summaries, wrapper for iterables
* fix typo
* allow profiler=True creation
* more documentation
* add tests for advanced profiler
* Update trainer.py
* make profilers accessible in pl.utilities
* reorg profiler files

Co-authored-by: William Falcon <[email protected]>
1 parent: 57074b3 · commit 1cf430f

File tree: 8 files changed, +423 −19 lines

docs/source/common-cases.rst (+9 −3)

```diff
@@ -13,9 +13,15 @@ gradient clipping
 modifying training via hooks
 =============================
 
-
-
 .. toctree::
    :maxdepth: 3
 
-   pl_examples
+   pl_examples
+
+
+profiling a training run
+========================
+.. toctree::
+   :maxdepth: 1
+
+   profiler
```

docs/source/profiler.rst (new file, +10)

```rst
.. role:: hidden
    :class: hidden-section


Profiling performance during training
=====================================
.. automodule:: pytorch_lightning.profiler
   :exclude-members:
       _abc_impl,
       summarize,
```
pytorch_lightning/profiler/__init__.py (new file, +112)

```python
"""
Profiling your training run can help you understand if there are any bottlenecks in your code.

PyTorch Lightning supports profiling standard actions in the training loop out of the box, including:

- on_epoch_start
- on_epoch_end
- on_batch_start
- tbptt_split_batch
- model_forward
- model_backward
- on_after_backward
- optimizer_step
- on_batch_end
- training_end
- on_training_end

If you only wish to profile the standard actions, you can set `profiler=True` when constructing
your `Trainer` object.

.. code-block:: python

    trainer = Trainer(..., profiler=True)

The profiler's results will be printed on completion of a training `fit()`.

.. code-block:: python

    Profiler Report

    Action               |  Mean duration (s)  |  Total time (s)
    -----------------------------------------------------------------
    on_epoch_start       |  5.993e-06          |  5.993e-06
    get_train_batch      |  0.0087412          |  16.398
    on_batch_start       |  5.0865e-06         |  0.0095372
    model_forward        |  0.0017818          |  3.3408
    model_backward       |  0.0018283          |  3.4282
    on_after_backward    |  4.2862e-06         |  0.0080366
    optimizer_step       |  0.0011072          |  2.0759
    on_batch_end         |  4.5202e-06         |  0.0084753
    on_epoch_end         |  3.919e-06          |  3.919e-06
    on_train_end         |  5.449e-06          |  5.449e-06

If you want more information on the functions called during each event, you can use the
`AdvancedProfiler`. This option uses Python's cProfile_ module to provide a report of the
time spent on *each* function called within your code.

.. _cProfile: https://docs.python.org/3/library/profile.html#module-cProfile

.. code-block:: python

    profiler = AdvancedProfiler()
    trainer = Trainer(..., profiler=profiler)

The profiler's results will be printed on completion of a training `fit()`. This profiler
report can be quite long, so you can also specify an `output_filename` to save the report
instead of logging it to the output in your terminal. The output below shows the profiling
for the action `get_train_batch`.

.. code-block:: python

    Profiler Report

    Profile stats for: get_train_batch
            4869394 function calls (4863767 primitive calls) in 18.893 seconds
    Ordered by: cumulative time
    List reduced from 76 to 10 due to restriction <10>
    ncalls     tottime  percall  cumtime  percall  filename:lineno(function)
    3752/1876    0.011    0.000   18.887    0.010  {built-in method builtins.next}
    1876         0.008    0.000   18.877    0.010  dataloader.py:344(__next__)
    1876         0.074    0.000   18.869    0.010  dataloader.py:383(_next_data)
    1875         0.012    0.000   18.721    0.010  fetch.py:42(fetch)
    1875         0.084    0.000   18.290    0.010  fetch.py:44(<listcomp>)
    60000        1.759    0.000   18.206    0.000  mnist.py:80(__getitem__)
    60000        0.267    0.000   13.022    0.000  transforms.py:68(__call__)
    60000        0.182    0.000    7.020    0.000  transforms.py:93(__call__)
    60000        1.651    0.000    6.839    0.000  functional.py:42(to_tensor)
    60000        0.260    0.000    5.734    0.000  transforms.py:167(__call__)

You can also reference this profiler in your LightningModule to profile specific actions of
interest. If you don't want the profiler to always be turned on, you can alternatively pass a
`PassThroughProfiler`, which allows you to skip profiling without making any code changes.
Each profiler has a method `profile()` which returns a context manager. Simply pass in the
name of the action you want to track, and the profiler will record performance for code
executed within this context.

.. code-block:: python

    from pytorch_lightning.profiler import Profiler, PassThroughProfiler

    class MyModel(LightningModule):
        def __init__(self, hparams, profiler=None):
            super().__init__()
            self.hparams = hparams
            self.profiler = profiler or PassThroughProfiler()

        def custom_processing_step(self, data):
            with self.profiler.profile('my_custom_action'):
                # custom processing step
                return data

    profiler = Profiler()
    model = MyModel(hparams, profiler)
    trainer = Trainer(profiler=profiler, max_epochs=1)

"""

from .profiler import Profiler, AdvancedProfiler, PassThroughProfiler

__all__ = [
    'Profiler',
    'AdvancedProfiler',
    'PassThroughProfiler',
]
```
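The docstring above mentions `output_filename` without showing it in code. Here is a minimal sketch combining the two `AdvancedProfiler` options; the file name `profile_report.txt` and the surrounding training setup are illustrative assumptions, not part of the commit:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import AdvancedProfiler

# Keep only the 20 most expensive functions per action, and write the
# report to disk when training finishes instead of logging it.
profiler = AdvancedProfiler(output_filename="profile_report.txt",  # hypothetical path
                            line_count_restriction=20)
trainer = Trainer(profiler=profiler, max_epochs=1)
trainer.fit(model)  # assumes `model` is an existing LightningModule instance
```

Per the parameter docstring, an integer `line_count_restriction` selects a count of lines, while a float between 0.0 and 1.0 selects a fraction of them, mirroring `pstats.Stats.print_stats`.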
pytorch_lightning/profiler/profiler.py (new file, +181)

```python
from contextlib import contextmanager
from collections import defaultdict
import time
import numpy as np
import cProfile
import pstats
import io
from abc import ABC, abstractmethod
import logging

logger = logging.getLogger(__name__)


class BaseProfiler(ABC):
    """
    If you wish to write a custom profiler, you should inherit from this class.
    """

    @abstractmethod
    def start(self, action_name):
        """
        Defines how to start recording an action.
        """
        pass

    @abstractmethod
    def stop(self, action_name):
        """
        Defines how to record the duration once an action is complete.
        """
        pass

    @contextmanager
    def profile(self, action_name):
        """
        Yields a context manager to encapsulate the scope of a profiled action.

        Example::

            with self.profile('load training data'):
                # load training data code

        The profiler will start once you've entered the context and will automatically
        stop once you exit the code block.
        """
        try:
            self.start(action_name)
            yield action_name
        finally:
            self.stop(action_name)

    def profile_iterable(self, iterable, action_name):
        iterator = iter(iterable)
        while True:
            try:
                self.start(action_name)
                value = next(iterator)
                self.stop(action_name)
                yield value
            except StopIteration:
                self.stop(action_name)
                break

    def describe(self):
        """
        Logs a profile report after the conclusion of the training run.
        """
        pass


class PassThroughProfiler(BaseProfiler):
    """
    This class should be used when you don't want the (small) overhead of profiling.
    The Trainer uses this class by default.
    """

    def __init__(self):
        pass

    def start(self, action_name):
        pass

    def stop(self, action_name):
        pass


class Profiler(BaseProfiler):
    """
    This profiler simply records the duration of actions (in seconds) and reports
    the mean duration of each action and the total time spent over the entire training run.
    """

    def __init__(self):
        self.current_actions = {}
        self.recorded_durations = defaultdict(list)

    def start(self, action_name):
        if action_name in self.current_actions:
            raise ValueError(
                f"Attempted to start {action_name} which has already started."
            )
        self.current_actions[action_name] = time.monotonic()

    def stop(self, action_name):
        end_time = time.monotonic()
        if action_name not in self.current_actions:
            raise ValueError(
                f"Attempting to stop recording an action ({action_name}) which was never started."
            )
        start_time = self.current_actions.pop(action_name)
        duration = end_time - start_time
        self.recorded_durations[action_name].append(duration)

    def describe(self):
        output_string = "\n\nProfiler Report\n"

        def log_row(action, mean, total):
            return f"\n{action:<20s}\t| {mean:<15}\t| {total:<15}"

        output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
        output_string += f"\n{'-' * 65}"
        for action, durations in self.recorded_durations.items():
            output_string += log_row(
                action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}",
            )
        output_string += "\n"
        logger.info(output_string)


class AdvancedProfiler(BaseProfiler):
    """
    This profiler uses Python's cProfile to record more detailed information about
    time spent in each function call recorded during a given action. The output is quite
    verbose and you should only use this if you want very detailed reports.
    """

    def __init__(self, output_filename=None, line_count_restriction=1.0):
        """
        :param output_filename (str): optionally save profile results to a file instead of
            printing to stdout when training is finished.
        :param line_count_restriction (int|float): this can be used to limit the number of
            functions reported for each action. either an integer (to select a count of lines),
            or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
        """
        self.profiled_actions = {}
        self.output_filename = output_filename
        self.line_count_restriction = line_count_restriction

    def start(self, action_name):
        if action_name not in self.profiled_actions:
            self.profiled_actions[action_name] = cProfile.Profile()
        self.profiled_actions[action_name].enable()

    def stop(self, action_name):
        pr = self.profiled_actions.get(action_name)
        if pr is None:
            raise ValueError(
                f"Attempting to stop recording an action ({action_name}) which was never started."
            )
        pr.disable()

    def describe(self):
        self.recorded_stats = {}
        for action_name, pr in self.profiled_actions.items():
            s = io.StringIO()
            sortby = pstats.SortKey.CUMULATIVE
            ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats(sortby)
            ps.print_stats(self.line_count_restriction)
            self.recorded_stats[action_name] = s.getvalue()
        if self.output_filename is not None:
            # save to file
            with open(self.output_filename, "w") as f:
                for action, stats in self.recorded_stats.items():
                    f.write(f"Profile stats for: {action}")
                    f.write(stats)
        else:
            # log to standard out
            output_string = "\nProfiler Report\n"
            for action, stats in self.recorded_stats.items():
                output_string += f"\nProfile stats for: {action}\n{stats}"
            logger.info(output_string)
```

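Since `BaseProfiler` is explicitly designed for subclassing, a short sketch of a custom profiler may help. `SimpleLoggingProfiler` is a hypothetical name, and the import path assumes the file lives at `pytorch_lightning/profiler/profiler.py` as inferred above; the snippet also exercises `profile_iterable`, which the diff adds without a usage example:

```python
import time

from pytorch_lightning.profiler.profiler import BaseProfiler  # path assumed from the diff


class SimpleLoggingProfiler(BaseProfiler):
    """Hypothetical custom profiler that prints every action's duration."""

    def __init__(self):
        self.start_times = {}

    def start(self, action_name):
        self.start_times[action_name] = time.monotonic()

    def stop(self, action_name):
        start_time = self.start_times.pop(action_name)
        print(f"{action_name}: {time.monotonic() - start_time:.4f}s")


profiler = SimpleLoggingProfiler()

# profile() wraps a single block of code ...
with profiler.profile('expensive_step'):
    total = sum(i * i for i in range(1_000_000))

# ... while profile_iterable() times each item drawn from an iterable,
# which is how the Trainer can time fetching batches from a dataloader.
for batch in profiler.profile_iterable(range(3), 'get_batch'):
    pass
```

A subclass only has to implement `start()` and `stop()`; `profile()` and `profile_iterable()` are inherited from the base class.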
pytorch_lightning/trainer/evaluation_loop.py (+1 −1)

```diff
@@ -212,7 +212,7 @@ def evaluate(self, model, dataloaders, max_batches, test=False):
         # bookkeeping
         outputs = []
 
-        # run training
+        # run validation
         for dataloader_idx, dataloader in enumerate(dataloaders):
             dl_outputs = []
             for batch_idx, batch in enumerate(dataloader):
```