Skip to content

Commit 2a9132b

Browse files
committed
2 parents 1506d52 + 001856d commit 2a9132b

File tree

6 files changed

+56
-47
lines changed

6 files changed

+56
-47
lines changed

.github/workflows/ci-testing.yml

+6-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name: CI testing
22

3-
# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows
3+
# https://help.github.com/en/actions/reference/events-that-trigger-workflows
44
on:
55
# Trigger the workflow on push or pull request,
66
# but only for the master branch
@@ -63,12 +63,12 @@ jobs:
6363
# Note: This uses an internal pip API and may not always work
6464
# https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
6565
- name: Get pip cache
66-
id: pip-cache
67-
run: |
68-
python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)"
66+
id: pip-cache
67+
run: |
68+
python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)"
6969
7070
- name: Cache pip
71-
- uses: actions/cache@v1
71+
uses: actions/cache@v1
7272
with:
7373
path: ${{ steps.pip-cache.outputs.dir }}
7474
key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-extra.txt') }}
@@ -116,7 +116,6 @@ jobs:
116116
python setup.py check --metadata --strict
117117
python setup.py sdist
118118
twine check dist/*
119-
120119
#- name: Try install package
121120
# if: ! startsWith(matrix.os, 'windows')
122121
# run: |
@@ -127,4 +126,4 @@ jobs:
127126
- name: Statistics
128127
if: success()
129128
run: |
130-
coverage report
129+
coverage report

pytorch_lightning/trainer/data_loading.py

+1
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
133133
world_size = {
134134
'ddp': self.num_nodes * self.num_processes,
135135
'ddp2': self.num_nodes,
136+
'ddp_cpu': self.num_processes * self.num_nodes
136137
}
137138
sampler = DistributedSampler(
138139
dataloader.dataset,

pytorch_lightning/trainer/trainer.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -614,17 +614,8 @@ def allowed_type(x):
614614
return bool(parsing.strtobool(x))
615615

616616
if arg == 'gpus':
617-
def allowed_type(x):
618-
if ',' in x:
619-
return str(x)
620-
else:
621-
return int(x)
622-
623-
def arg_default(x):
624-
if ',' in x:
625-
return str(x)
626-
else:
627-
return int(x)
617+
allowed_type = Trainer.allowed_type
618+
arg_default = Trainer.arg_default
628619

629620
parser.add_argument(
630621
f'--{arg}',
@@ -637,6 +628,18 @@ def arg_default(x):
637628

638629
return parser
639630

631+
def allowed_type(x):
    """Argparse ``type`` callable for the ``--gpus`` flag.

    A comma-separated device list (e.g. ``"0,1"``) is kept as a string,
    while a bare count (e.g. ``"2"``) is converted to an int.

    :param x: raw command-line value
    :return: ``str(x)`` if *x* contains a comma, otherwise ``int(x)``
    """
    return str(x) if ',' in x else int(x)
636+
637+
def arg_default(x):
    """Default-value coercion for the ``--gpus`` flag.

    Mirrors :func:`allowed_type`: a comma-separated device list stays a
    string, a bare count becomes an int.

    :param x: raw default value
    :return: ``str(x)`` if *x* contains a comma, otherwise ``int(x)``
    """
    return str(x) if ',' in x else int(x)
642+
640643
@classmethod
641644
def from_argparse_args(cls, args, **kwargs):
642645

@@ -711,6 +714,10 @@ def fit(
711714
model.logger = self.logger
712715
self.copy_trainer_model_properties(model)
713716

717+
# clean hparams
718+
if hasattr(model, 'hparams'):
719+
parsing.clean_namespace(model.hparams)
720+
714721
# set up the passed in dataloaders (if needed)
715722
self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
716723

pytorch_lightning/trainer/training_io.py

+2-27
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
LightningDistributedDataParallel,
102102
LightningDataParallel,
103103
)
104-
from pytorch_lightning.utilities import rank_zero_warn
104+
from pytorch_lightning.utilities import rank_zero_warn, parsing
105105

106106
try:
107107
import torch_xla
@@ -325,7 +325,7 @@ def dump_checkpoint(self):
325325
checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
326326

327327
if hasattr(model, "hparams"):
328-
self.__clean_namespace(model.hparams)
328+
parsing.clean_namespace(model.hparams)
329329
is_namespace = isinstance(model.hparams, Namespace)
330330
checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams
331331
checkpoint['hparams_type'] = 'namespace' if is_namespace else 'dict'
@@ -339,31 +339,6 @@ def dump_checkpoint(self):
339339

340340
return checkpoint
341341

342-
def __clean_namespace(self, hparams):
343-
"""
344-
Removes all functions from hparams so we can pickle
345-
:param hparams:
346-
:return:
347-
"""
348-
349-
if isinstance(hparams, Namespace):
350-
del_attrs = []
351-
for k in hparams.__dict__:
352-
if callable(getattr(hparams, k)):
353-
del_attrs.append(k)
354-
355-
for k in del_attrs:
356-
delattr(hparams, k)
357-
358-
elif isinstance(hparams, dict):
359-
del_attrs = []
360-
for k, v in hparams.items():
361-
if callable(v):
362-
del_attrs.append(k)
363-
364-
for k in del_attrs:
365-
del hparams[k]
366-
367342
# --------------------
368343
# HPC IO
369344
# --------------------

pytorch_lightning/utilities/parsing.py

+29
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from argparse import Namespace
2+
3+
14
def strtobool(val):
25
"""Convert a string representation of truth to true (1) or false (0).
36
Copied from the python implementation distutils.utils.strtobool
@@ -18,3 +21,29 @@ def strtobool(val):
1821
return 0
1922
else:
2023
raise ValueError(f'invalid truth value {val}')
24+
25+
26+
def clean_namespace(hparams):
    """Strip every callable entry from *hparams* so it can be pickled.

    Works in place on either an :class:`argparse.Namespace` or a plain
    ``dict``; any other type is left untouched.

    :param hparams: hyperparameter container to clean
    :return: None (mutation happens in place)
    """
    if isinstance(hparams, Namespace):
        # Collect first, then delete, so we never mutate while iterating.
        doomed = [name for name, value in vars(hparams).items() if callable(value)]
        for name in doomed:
            delattr(hparams, name)
    elif isinstance(hparams, dict):
        doomed = [name for name, value in hparams.items() if callable(value)]
        for name in doomed:
            del hparams[name]

tests/trainer/test_trainer_cli.py

-2
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@ def test_add_argparse_args_redefined(cli_args):
5151
assert depr_name not in args
5252

5353
trainer = Trainer.from_argparse_args(args=args)
54-
55-
# make sure we can pickle trainer
5654
pickle.dumps(trainer)
5755

5856
assert isinstance(trainer, Trainer)

0 commit comments

Comments
 (0)