@@ -42,7 +42,7 @@ def __init__(self, trainer):
        # used to validate checkpointing logic
        self.has_trained = False

-    def restore_weights(self, model: LightningModule):
+    def restore_weights(self, model: LightningModule) -> None:
        """
        Attempt to restore a checkpoint (e.g. weights) in this priority:
        1. from HPC weights
@@ -72,11 +72,16 @@ def restore_weights(self, model: LightningModule):
        if self.trainer.on_gpu:
            torch.cuda.empty_cache()

-    def restore(self, checkpoint_path: str, on_gpu: bool):
+    def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
        """
        Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
        All restored states are listed in return value description of `dump_checkpoint`.
        """
+        # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
+        fs = get_filesystem(checkpoint_path)
+        if not fs.exists(checkpoint_path):
+            rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch")
+            return False

        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
        checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
@@ -93,6 +98,9 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
        # restore training state
        self.restore_training_state(checkpoint)

+        rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
+        return True
+
    def restore_model_state(self, model: LightningModule, checkpoint) -> None:
        """
        Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
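Taken together, the hunks turn a missing checkpoint file from a silent or fatal condition into an explicit one: `restore` now checks for the file, warns and returns False if it is absent, and returns True after a successful restore. Below is a minimal, standalone sketch of that check-then-restore pattern. The `ToyConnector` class, its `resume_path` attribute, and the checkpoint layout are illustrative assumptions for this sketch, not Lightning's actual `CheckpointConnector` API.

import os
import warnings

import torch


class ToyConnector:
    """Illustrative stand-in for a checkpoint connector; not Lightning's API."""

    def __init__(self, resume_path=None):
        self.resume_path = resume_path

    def restore(self, model: torch.nn.Module) -> bool:
        """Return True only if a checkpoint file was found and loaded."""
        # Mirror the diff's early exit: missing file -> warn and start from scratch.
        if self.resume_path is None or not os.path.exists(self.resume_path):
            warnings.warn("No checkpoint file exists at `resume_path`. Starting from scratch.")
            return False
        # Load onto CPU first, analogous to the map_location used in the diff.
        state = torch.load(self.resume_path, map_location="cpu")
        model.load_state_dict(state["state_dict"])
        return True


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    connector = ToyConnector(resume_path="missing.ckpt")
    resumed = connector.restore(model)  # False here: the file does not exist
    print(f"resumed={resumed}")

Returning a boolean instead of raising leaves it to the caller to decide whether a missing checkpoint is fatal or simply means training starts fresh.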