Skip to content

Commit 4755ded

Browse files
williamFalconBorda
andauthored
Clean up Argparse interface with trainer (#1606)
* fixed distutil parsing * fixed distutil parsing * Apply suggestions from code review * log * fixed distutil parsing * fixed distutil parsing * fixed distutil parsing * fixed distutil parsing * doctest * fixed hparams section * fixed hparams section * fixed hparams section * formatting Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: J. Borovec <[email protected]>
1 parent 13bf772 commit 4755ded

File tree

7 files changed

+141
-83
lines changed

7 files changed

+141
-83
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8787

8888
- Fixes automatic parser bug ([#1585](https://github.com/PyTorchLightning/pytorch-lightning/issues/1585))
8989

90+
- Fixed bool conversion from string ([#1606](https://github.com/PyTorchLightning/pytorch-lightning/issues/1606))
91+
9092
## [0.7.3] - 2020-04-09
9193

9294
### Added

docs/source/hyperparameters.rst

+109-74
Original file line numberDiff line numberDiff line change
@@ -3,36 +3,111 @@ Hyperparameters
33
Lightning has utilities to interact seamlessly with the command line ArgumentParser
44
and plays well with the hyperparameter optimization framework of your choice.
55

6-
LightiningModule hparams
6+
ArgumentParser
7+
^^^^^^^^^^^^^^
8+
Lightning is designed to augment a lot of the functionality of the built-in Python ArgumentParser
9+
10+
.. code-block:: python
11+
12+
from argparse import ArgumentParser
13+
14+
parser = ArgumentParser()
15+
parser.add_argument('--layer_1_dim', type=int, default=128)
16+
17+
args = parser.parse_args()
18+
19+
This allows you to call your program like so:
20+
21+
.. code-block:: bash
22+
23+
python trainer.py --layer_1_dim 64
24+
25+
26+
Argparser Best Practices
727
^^^^^^^^^^^^^^^^^^^^^^^^
28+
It is best practice to layer your arguments in three sections.
829

9-
Normally, we don't hard-code the values to a model. We usually use the command line to
10-
modify the network. The `Trainer` can add all the available options to an ArgumentParser.
30+
1. Trainer args (gpus, num_nodes, etc...)
31+
2. Model specific arguments (layer_dim, num_layers, learning_rate, etc...)
32+
3. Program arguments (data_path, cluster_email, etc...)
33+
34+
We can do this as follows. First, in your LightningModule, define the arguments
35+
specific to that module. Remember that data splits or data paths may also be specific to
36+
a module (ie: if your project has a model that trains on Imagenet and another on CIFAR-10).
1137

1238
.. code-block:: python
1339
40+
class LitModel(LightningModule):
41+
42+
@staticmethod
43+
def add_model_specific_args(parent_parser):
44+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
45+
parser.add_argument('--encoder_layers', type=int, default=12)
46+
parser.add_argument('--data_path', type=str, default='/some/path')
47+
return parser
48+
49+
Now in your main trainer file, add the Trainer args, the program args, and add the model args
50+
51+
.. code-block:: python
52+
53+
# ----------------
54+
# trainer_main.py
55+
# ----------------
1456
from argparse import ArgumentParser
1557
1658
parser = ArgumentParser()
1759
18-
# parametrize the network
19-
parser.add_argument('--layer_1_dim', type=int, default=128)
20-
parser.add_argument('--layer_2_dim', type=int, default=256)
21-
parser.add_argument('--batch_size', type=int, default=64)
60+
# add PROGRAM level args
61+
parser.add_argument('--conda_env', type=str, default='some_name')
62+
parser.add_argument('--notification_email', type=str, default='[email protected]')
63+
64+
# add model specific args
65+
parser = LitModel.add_model_specific_args(parser)
2266
23-
# add all the available options to the trainer
67+
# add all the available trainer options to argparse
68+
# ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
2469
parser = pl.Trainer.add_argparse_args(parser)
2570
26-
args = parser.parse_args()
71+
hparams = parser.parse_args()
2772
28-
Now we can parametrize the LightningModule.
73+
Now you can call run your program like so
74+
75+
.. code-block:: bash
76+
77+
python trainer_main.py --gpus 2 --num_nodes 2 --conda_env 'my_env' --encoder_layers 12
78+
79+
Finally, make sure to start the training like so:
80+
81+
.. code-block:: bash
82+
83+
hparams = parser.parse_args()
84+
85+
# YES
86+
model = LitModel(hparams)
87+
88+
# NO
89+
# model = LitModel(learning_rate=hparams.learning_rate, ...)
90+
91+
# YES
92+
trainer = Trainer.from_argparse_args(hparams, early_stopping_callback=...)
93+
94+
# NO
95+
trainer = Trainer(gpus=hparams.gpus, ...)
96+
97+
98+
LightiningModule hparams
99+
^^^^^^^^^^^^^^^^^^^^^^^^
100+
101+
Normally, we don't hard-code the values to a model. We usually use the command line to
102+
modify the network and read those values in the LightningModule
29103

30104
.. code-block:: python
31-
:emphasize-lines: 5,6,7,12,14
32105
33106
class LitMNIST(pl.LightningModule):
34107
def __init__(self, hparams):
35108
super().__init__()
109+
110+
# do this to save all arguments in any logger (tensorboard)
36111
self.hparams = hparams
37112
38113
self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim)
@@ -49,86 +124,44 @@ Now we can parametrize the LightningModule.
49124
def configure_optimizers(self):
50125
return Adam(self.parameters(), lr=self.hparams.learning_rate)
51126
52-
hparams = parse_args()
53-
model = LitMNIST(hparams)
127+
@staticmethod
128+
def add_model_specific_args(parent_parser):
129+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
54130
55-
.. note:: Bonus! if (hparams) is in your module, Lightning will save it into the checkpoint and restore your
56-
model using those hparams exactly.
131+
parser.add_argument('--layer_1_dim', type=int, default=128)
132+
parser.add_argument('--layer_2_dim', type=int, default=256)
133+
parser.add_argument('--batch_size', type=int, default=64)
134+
parser.add_argument('--learning_rate', type=float, default=0.002)
135+
return parser
57136
58-
And we can also add all the flags available in the Trainer to the Argparser.
137+
Now pass in the params when you init your model
59138

60139
.. code-block:: python
61140
62-
# add all the available Trainer options to the ArgParser
63-
parser = pl.Trainer.add_argparse_args(parser)
64-
args = parser.parse_args()
65-
66-
And now you can start your program with
141+
hparams = parse_args()
142+
model = LitMNIST(hparams)
67143
68-
.. code-block:: bash
144+
The line `self.hparams = hparams` is very special. This line assigns your hparams to the LightningModule.
145+
This does two things:
69146

70-
# now you can use any trainer flag
71-
$ python main.py --num_nodes 2 --gpus 8
147+
1. It adds them automatically to tensorboard logs under the hparams tab.
148+
2. Lightning will save those hparams to the checkpoint and use them to restore the module correctly.
72149

73150
Trainer args
74151
^^^^^^^^^^^^
75-
76-
It also gets annoying to map each argument into the Argparser. Luckily we have
77-
a default parser
152+
To recap, add ALL possible trainer flags to the argparser and init the Trainer this way
78153

79154
.. code-block:: python
80155
81156
parser = ArgumentParser()
82-
83-
# add all options available in the trainer such as (max_epochs, etc...)
84157
parser = Trainer.add_argparse_args(parser)
158+
hparams = parser.parse_args()
85159
86-
We set up the main training entry point file like this:
87-
88-
.. code-block:: python
89-
90-
def main(args):
91-
model = LitMNIST(hparams=args)
92-
trainer = Trainer(max_epochs=args.max_epochs)
93-
trainer.fit(model)
160+
trainer = Trainer.from_argparse_args(hparams)
94161
95-
if __name__ == '__main__':
96-
parser = ArgumentParser()
162+
# or if you need to pass in callbacks
163+
trainer = Trainer.from_argparse_args(hparams, checkpoint_callback=..., callbacks=[...])
97164
98-
# adds all the trainer options as default arguments (like max_epochs)
99-
parser = Trainer.add_argparse_args(parser)
100-
101-
# parametrize the network
102-
parser.add_argument('--layer_1_dim', type=int, default=128)
103-
parser.add_argument('--layer_1_dim', type=int, default=256)
104-
parser.add_argument('--batch_size', type=int, default=64)
105-
args = parser.parse_args()
106-
107-
# train
108-
main(args)
109-
110-
And now we can train like this:
111-
112-
.. code-block:: bash
113-
114-
$ python main.py --layer_1_dim 128 --layer_2_dim 256 --batch_size 64 --max_epochs 64
115-
116-
But it would also be nice to pass in any arbitrary argument to the trainer.
117-
We can do it by changing how we init the trainer.
118-
119-
.. code-block:: python
120-
121-
def main(args):
122-
model = LitMNIST(hparams=args)
123-
124-
# makes all trainer options available from the command line
125-
trainer = Trainer.from_argparse_args(args)
126-
127-
and now we can do this:
128-
129-
.. code-block:: bash
130-
131-
$ python main.py --gpus 1 --min_epochs 12 --max_epochs 64 --arbitrary_trainer_arg some_value
132165
133166
Multiple Lightning Modules
134167
^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -173,7 +206,7 @@ Now we can allow each model to inject the arguments it needs in the main.py
173206
model = LitMNIST(hparams=args)
174207
175208
model = LitMNIST(hparams=args)
176-
trainer = Trainer(max_epochs=args.max_epochs)
209+
trainer = Trainer.from_argparse_args(args)
177210
trainer.fit(model)
178211
179212
if __name__ == '__main__':
@@ -182,6 +215,8 @@ Now we can allow each model to inject the arguments it needs in the main.py
182215
183216
# figure out which model to use
184217
parser.add_argument('--model_name', type=str, default='gan', help='gan or mnist')
218+
219+
# THIS LINE IS KEY TO PULL THE MODEL NAME
185220
temp_args = parser.parse_known_args()
186221
187222
# let the model add what it wants

pytorch_lightning/trainer/logging.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from pytorch_lightning.core import memory
77
from pytorch_lightning.loggers import TensorBoardLogger, LightningLoggerBase, LoggerCollection
8-
from pytorch_lightning.utilities import memory_utils
8+
from pytorch_lightning.utilities.memory import recursive_detach
99

1010

1111
class TrainerLoggingMixin(ABC):
@@ -174,7 +174,7 @@ def process_output(self, output, train=False):
174174

175175
# detach all metrics for callbacks to prevent memory leaks
176176
# no .item() because it will slow things down
177-
callback_metrics = memory_utils.recursive_detach(callback_metrics)
177+
callback_metrics = recursive_detach(callback_metrics)
178178

179179
return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
180180

pytorch_lightning/trainer/trainer.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import distutils
21
import inspect
32
import os
43
from argparse import ArgumentParser
@@ -33,6 +32,7 @@
3332
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
3433
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3534
from pytorch_lightning.utilities import rank_zero_warn
35+
from pytorch_lightning.utilities import parsing
3636

3737

3838
try:
@@ -599,17 +599,19 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
599599
"""
600600
parser = ArgumentParser(parents=[parent_parser], add_help=False, )
601601

602-
depr_arg_names = cls.get_deprecated_arg_names()
602+
blacklist = ['kwargs']
603+
depr_arg_names = cls.get_deprecated_arg_names() + blacklist
603604

604605
allowed_types = (str, float, int, bool)
606+
605607
# TODO: get "help" from docstring :)
606608
for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types()
607609
if at[0] not in depr_arg_names):
608610

609611
for allowed_type in (at for at in allowed_types if at in arg_types):
610612
if allowed_type is bool:
611613
def allowed_type(x):
612-
return bool(distutils.util.strtobool(x))
614+
return bool(parsing.strtobool(x))
613615

614616
if arg == 'gpus':
615617
def allowed_type(x):
@@ -636,9 +638,11 @@ def arg_default(x):
636638
return parser
637639

638640
@classmethod
639-
def from_argparse_args(cls, args):
641+
def from_argparse_args(cls, args, **kwargs):
640642

641643
params = vars(args)
644+
params.update(**kwargs)
645+
642646
return cls(**params)
643647

644648
@property

pytorch_lightning/trainer/training_loop.py

-3
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ def training_step(self, batch, batch_idx):
141141
142142
"""
143143

144-
import copy
145144
from abc import ABC, abstractmethod
146145
from typing import Callable
147146
from typing import Union, List
@@ -154,11 +153,9 @@ def training_step(self, batch, batch_idx):
154153
from pytorch_lightning.callbacks.base import Callback
155154
from pytorch_lightning.core.lightning import LightningModule
156155
from pytorch_lightning.loggers import LightningLoggerBase
157-
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
158156
from pytorch_lightning.utilities.exceptions import MisconfigurationException
159157
from pytorch_lightning.trainer.supporters import TensorRunningAccum
160158
from pytorch_lightning.utilities import rank_zero_warn
161-
from pytorch_lightning.utilities import memory_utils
162159

163160
try:
164161
from apex import amp
+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
def strtobool(val):
2+
"""Convert a string representation of truth to true (1) or false (0).
3+
Copied from the python implementation distutils.utils.strtobool
4+
5+
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
6+
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
7+
'val' is anything else.
8+
9+
>>> strtobool('YES')
10+
1
11+
>>> strtobool('FALSE')
12+
0
13+
"""
14+
val = val.lower()
15+
if val in ('y', 'yes', 't', 'true', 'on', '1'):
16+
return 1
17+
elif val in ('n', 'no', 'f', 'false', 'off', '0'):
18+
return 0
19+
else:
20+
raise ValueError(f'invalid truth value {val}')

0 commit comments

Comments
 (0)