@@ -43,7 +43,7 @@ def __init__(self, trainer):
         # used to validate checkpointing logic
         self.has_trained = False

-    def restore_weights(self, model: LightningModule):
+    def restore_weights(self, model: LightningModule) -> None:
         """
         Attempt to restore a checkpoint (e.g. weights) in this priority:
         1. from HPC weights
@@ -73,11 +73,16 @@ def restore_weights(self, model: LightningModule):
         if self.trainer.on_gpu:
             torch.cuda.empty_cache()

-    def restore(self, checkpoint_path: str, on_gpu: bool):
+    def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
         """
         Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
         All restored states are listed in return value description of `dump_checkpoint`.
         """
+        # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
+        fs = get_filesystem(checkpoint_path)
+        if not fs.exists(checkpoint_path):
+            rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch")
+            return False

         # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
         checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
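The new guard resolves `checkpoint_path` through `get_filesystem`, which hands back an fsspec filesystem, so the existence check behaves the same for local paths and remote ones such as `s3://` buckets. A minimal standalone sketch of the same pattern, using `fsspec` directly; the `maybe_load_checkpoint` helper is hypothetical and not part of the Lightning API:

import fsspec
import torch


def maybe_load_checkpoint(checkpoint_path: str):
    """Return the checkpoint dict if the file exists, otherwise None (hypothetical helper)."""
    # Resolve a filesystem for the path: local, s3://, gs://, ... are all handled by fsspec.
    fs, _ = fsspec.core.url_to_fs(checkpoint_path)
    if not fs.exists(checkpoint_path):
        # Mirror the behaviour added in this diff: nothing to restore, start from scratch.
        return None
    # Load onto CPU first; the caller can move tensors to the right device afterwards.
    with fs.open(checkpoint_path, "rb") as f:
        return torch.load(f, map_location=lambda storage, loc: storage)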
@@ -94,6 +99,9 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         # restore training state
         self.restore_training_state(checkpoint)

+        rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
+        return True
+
     def restore_model_state(self, model: LightningModule, checkpoint) -> None:
         """
         Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
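Since `restore` now returns a bool, a call site can tell whether a checkpoint was actually loaded instead of inferring it from side effects or log output. A short sketch of how such a caller could branch on the result; the attribute names on `trainer` below are assumptions for illustration, not taken from this diff:

def resume_or_start_fresh(trainer, checkpoint_path: str) -> None:
    # Hypothetical call site; `trainer.checkpoint_connector` / `trainer.on_gpu` are illustrative names.
    restored = trainer.checkpoint_connector.restore(checkpoint_path, on_gpu=trainer.on_gpu)
    if restored:
        # Model weights, optimizer/LR-scheduler state and loop counters were reloaded.
        print(f"Resuming training from {checkpoint_path}")
    else:
        # No file at `checkpoint_path`: training starts with freshly initialized weights.
        print("Starting from scratch")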