Add resuming from specific checkpoint #516
Conversation
checkpoint_path = Path(checkpoint_path)
if not checkpoint_path.exists():
    return did_restore
You can simply return True/False; there's no need for the extra variable.
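A minimal sketch of what this suggestion would look like, using the names from the diff above (the loading body is elided and illustrative, not the PR's final code):

```python
from pathlib import Path

def restore_state_from_checkpoint(self, checkpoint_path):
    # Return the result directly instead of carrying a did_restore flag.
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        return False  # nothing to restore

    # ... load the checkpoint and restore model/optimizer state here ...
    return True
```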
Okay. I'll change the code accordingly.
I simplified the code and this function no longer exists.
@@ -93,6 +96,18 @@ def restore_state_if_checkpoint_exists(self, model):

        return did_restore

    def restore_state_from_checkpoint(self, checkpoint_path):
        did_restore = False
Add documentation stating what data type checkpoint_path is.
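For example, a docstring along these lines would address the comment (the exact wording and the :return: line are illustrative):

```python
def restore_state_from_checkpoint(self, checkpoint_path):
    """Restore training state from a specific checkpoint file.

    :param checkpoint_path: str or os.PathLike pointing to the checkpoint
        file; anything torch.load accepts as a file path.
    :return: True if the state was restored, False otherwise.
    """
    did_restore = False
    ...
```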
I removed this function. Please review the updated code :)
did_restore = False

checkpoint_path = Path(checkpoint_path)
if not checkpoint_path.exists():
This does not work for str as it is defined: :param resume_from_checkpoint: str or os.PathLike object.
checkpoint_path is what torch.load expects as input. It can be a file-like object or a str containing a file name.
https://pytorch.org/docs/stable/torch.html?highlight=torch%20load#torch.load
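For reference, both Path() and torch.load accept a plain str, so the existence check and the load can look roughly like this (a standalone sketch, not the PR's final code):

```python
import torch
from pathlib import Path

def load_checkpoint_if_exists(checkpoint_path):
    # Path() accepts str as well as os.PathLike, so exists() works for both.
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        return None
    # torch.load accepts a file path (str) or a file-like object;
    # map_location="cpu" loads onto CPU to avoid device mismatches.
    return torch.load(str(checkpoint_path), map_location="cpu")
```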
I updated & simplified the code. There is no longer a restore_state_from_checkpoint function; Trainer restores from the checkpoint directly. The type of the resume_from_checkpoint parameter is what torch.load expects as input: a file-like object or a string containing a file name. If both resume_from_checkpoint and the last checkpoint in checkpoint_callback.filepath exist, Trainer restores the checkpoint from resume_from_checkpoint. I chose this policy because resume_from_checkpoint is a more explicit request from the user.
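A sketch of the precedence policy described here as a Trainer method; resume_from_checkpoint and checkpoint_callback.filepath come from the discussion above, while the helper _get_last_checkpoint_path is hypothetical:

```python
def _checkpoint_to_restore(self):
    # An explicit resume_from_checkpoint is a direct user request,
    # so it takes precedence over the last auto-saved checkpoint.
    if self.resume_from_checkpoint is not None:
        return self.resume_from_checkpoint
    # Otherwise fall back to the newest checkpoint saved by the
    # checkpoint callback, if any (hypothetical helper).
    return self._get_last_checkpoint_path(self.checkpoint_callback.filepath)
```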
@dreamgonfly looks great. Merging this. We need to add docs for it, though.
What does this PR do?
Fixes #515
Did you have fun?
I solved my problem with this feature.