
Commit a4fb827

williamFalcon authored and tullie committed
fixes test issues on ddp (Lightning-AI#1017)
* updated checkpoint docs * updated checkpoint docs * updated checkpoint docs * updated checkpoint docs * updated checkpoint docs * updated checkpoint docs * updated checkpoint docs * updated checkpoint docs * updated checkpoint docs * updated checkpoint docs * updated checkpoint docs * updated checkpoint docs * updated checkpoint docs * updated checkpoint docs * updated checkpoint docs * updated checkpoint docs * updated checkpoint docs
1 parent 1c086fa commit a4fb827

File tree

6 files changed: +72 -1 lines changed


docs/source/weights_loading.rst
+22

@@ -5,6 +5,19 @@ Lightning can automate saving and loading checkpoints.
 
 Checkpoint saving
 -----------------
+A Lightning checkpoint has everything needed to restore a training session including:
+
+- 16-bit scaling factor (apex)
+- Current epoch
+- Global step
+- Model state_dict
+- State of all optimizers
+- State of all learningRate schedulers
+- State of all callbacks
+- The hyperparameters used for that model if passed in as hparams (Argparse.Namespace)
+
+Automatic saving
+^^^^^^^^^^^^^^^^
 
 Checkpointing is enabled by default to the current working directory.
 To change the checkpoint path pass in:
@@ -59,6 +72,15 @@ The Lightning checkpoint also saves the hparams (hyperparams) passed into the Li
     def __init__(self, hparams, ...):
         self.hparams = hparams
 
+Manual saving
+^^^^^^^^^^^^^
+
+To save your own checkpoint call:
+
+.. code-block:: python
+
+    model.save_checkpoint(PATH)
+
 Checkpoint Loading
 ------------------

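The docs added above only show the save call. A minimal sketch of the full round trip, assuming a user-defined MyLightningModule and an example path (neither is part of this commit):

    # Hedged sketch: MyLightningModule and 'example.ckpt' are placeholders.
    from pytorch_lightning import Trainer

    model = MyLightningModule(hparams)
    trainer = Trainer()
    trainer.fit(model)

    # manual save, as documented above
    model.save_checkpoint('example.ckpt')

    # restore the module (and its hparams) later
    model = MyLightningModule.load_from_checkpoint('example.ckpt')
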
pytorch_lightning/trainer/callback_config.py
+7

@@ -12,6 +12,9 @@ class TrainerCallbackConfigMixin(ABC):
     # the proper values/initialisation should be done in child class
     default_save_path: str
     logger: Union[LightningLoggerBase, bool]
+    weights_save_path: str
+    ckpt_path: str
+    checkpoint_callback: ModelCheckpoint
 
     @property
     @abstractmethod
@@ -29,6 +32,7 @@ def configure_checkpoint_callback(self):
         User provided weights_saved_path
         Otherwise use os.getcwd()
         """
+        ckpt_path = self.default_save_path
         if self.checkpoint_callback is True:
             # init a default one
             if self.logger is not None:
@@ -44,12 +48,15 @@ def configure_checkpoint_callback(self):
             else:
                 ckpt_path = os.path.join(self.default_save_path, "checkpoints")
 
+            self.ckpt_path = ckpt_path
             self.checkpoint_callback = ModelCheckpoint(
                 filepath=ckpt_path
             )
         elif self.checkpoint_callback is False:
             self.checkpoint_callback = None
 
+        self.ckpt_path = ckpt_path
+
         if self.checkpoint_callback:
             # set the path for the callbacks
             self.checkpoint_callback.save_function = self.save_checkpoint

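A standalone sketch (not Lightning code) of the path-resolution rule this change enforces: ckpt_path now starts from default_save_path and is recorded in every branch, so self.ckpt_path is never left undefined. The function name and arguments below are illustrative only.

    import os

    def resolve_ckpt_path(default_save_path, use_checkpoint_callback, logger_ckpt_path=None):
        # new behaviour: always start from default_save_path
        ckpt_path = default_save_path
        if use_checkpoint_callback:
            if logger_ckpt_path is not None:
                # logger present: save under the logger's checkpoints directory
                ckpt_path = logger_ckpt_path
            else:
                ckpt_path = os.path.join(default_save_path, "checkpoints")
        # mirrored by self.ckpt_path = ckpt_path in configure_checkpoint_callback
        return ckpt_path
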
pytorch_lightning/trainer/distrib_data_parallel.py
+30

@@ -145,6 +145,7 @@ class TrainerDDPMixin(ABC):
     use_amp: bool
     amp_level: str
     use_tpu: bool
+    default_save_path: str
 
     @property
     @abstractmethod
@@ -340,6 +341,35 @@ def ddp_train(self, gpu_idx, model):
         # continue training routine
         self.run_pretrain_routine(model)
 
+        # when ddp ends, we save the model
+        self.save_spawn_weights(model)
+
+    def save_spawn_weights(self, model):
+        """
+        Dump a temporary checkpoint after ddp ends to get weights out of the process
+        :param model:
+        :return:
+        """
+        path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
+        self.save_checkpoint(path)
+
+    def load_spawn_weights(self, original_model):
+        """
+        Load the temp weights saved in the process
+        To recover the trained model from the ddp process we load the saved weights
+        :param model:
+        :return:
+        """
+        # load weights saved in ddp
+        path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
+        loaded_model = original_model.__class__.load_from_checkpoint(path)
+
+        # copy loaded weights to old model
+        original_model.load_state_dict(loaded_model.state_dict())
+
+        # remove ddp weights
+        os.remove(path)
+
     def resolve_root_node_address(self, root_node):
         if '[' in root_node:
             name = root_node.split('[')[0]

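Why the temporary checkpoint is needed: mp.spawn runs ddp_train in child processes, so weight updates never reach the parent's copy of the model. A minimal sketch of the same hand-off pattern outside Lightning, assuming a plain torch.nn.Module and an illustrative train_fn:

    import os
    import torch
    import torch.multiprocessing as mp

    TMP_PATH = '__temp_weight_ddp_end.ckpt'   # illustrative, mirrors the name used above

    def train_fn(rank, model):
        # ... training happens here, inside the spawned process ...
        if rank == 0:
            # child process dumps its weights to disk before exiting
            torch.save(model.state_dict(), TMP_PATH)

    def fit(model, nprocs=2):
        mp.spawn(train_fn, nprocs=nprocs, args=(model,))
        # parent process copies the trained weights back and cleans up
        model.load_state_dict(torch.load(TMP_PATH))
        os.remove(TMP_PATH)
        return model
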
pytorch_lightning/trainer/distrib_parts.py
+2

@@ -496,6 +496,8 @@ def tpu_train(self, tpu_core_idx, model):
             log.info(m)
         self.run_pretrain_routine(model)
 
+        self.save_spawn_weights(model)
+
     def dp_train(self, model):
 
         # CHOOSE OPTIMIZER

pytorch_lightning/trainer/evaluation_loop.py
+5 -1

@@ -349,7 +349,11 @@ def run_evaluation(self, test_mode: bool = False):
 
         # log results of test
         if test_mode:
-            model.print(prog_bar_metrics)
+            if self.proc_rank == 0:
+                print('-' * 100)
+                print('TEST RESULTS')
+                print(prog_bar_metrics)
+                print('-' * 100)
 
         # log metrics
         self.log_metrics(log_metrics, {})

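The guard above exists because run_evaluation executes in every DDP process; without it the test summary would be printed once per GPU. A minimal sketch of the same idiom, assuming proc_rank is the process's rank (0 on the main process):

    def print_test_results(prog_bar_metrics, proc_rank):
        # only the rank-0 process reports, so multi-process runs print one summary
        if proc_rank == 0:
            print('-' * 100)
            print('TEST RESULTS')
            print(prog_bar_metrics)
            print('-' * 100)
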
pytorch_lightning/trainer/trainer.py
+6

@@ -960,6 +960,8 @@ def fit(
         else:
             self.__set_random_port()
             mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))
+            self.load_spawn_weights(model)
+            self.model = model
 
         # 1 gpu or dp option triggers training using DP module
         # easier to avoid NCCL issues
@@ -975,6 +977,8 @@ def fit(
             # COLAB_GPU is an env var available by default in Colab environments.
             start_method = 'fork' if os.getenv('COLAB_GPU') else 'spawn'
             xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)
+            self.load_spawn_weights(model)
+            self.model = model
 
         # ON CPU
         else:
@@ -1192,6 +1196,8 @@ def test(self, model: Optional[LightningModule] = None):
         if model is not None:
             self.model = model
             self.fit(model)
+        elif self.model is not None and (self.use_ddp or self.use_tpu):
+            self.fit(self.model)
         else:
             self.run_evaluation(test_mode=True)

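Taken together, the trainer changes make the common fit-then-test flow work under spawned backends: after mp.spawn/xmp.spawn returns, the trained weights are loaded back into the parent process, and a bare test() call on a ddp/tpu trainer is routed through fit so evaluation runs in the right processes. A hedged usage sketch; MyLightningModule and the Trainer arguments are illustrative:

    from pytorch_lightning import Trainer

    model = MyLightningModule(hparams)              # placeholder module
    trainer = Trainer(gpus=2, distributed_backend='ddp')

    trainer.fit(model)     # spawns DDP processes, then load_spawn_weights() restores weights
    trainer.test()         # no model passed: with ddp/tpu this now re-enters fit(self.model)
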