Fixes resuming checkpoints rerunning last epoch #866
Conversation
Hello @MattPainter01! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2020-02-19 15:57:39 UTC
@@ -307,8 +307,8 @@ def restore(self, checkpoint_path, on_gpu):
     def dump_checkpoint(self):

         checkpoint = {
-            'epoch': self.current_epoch,
-            'global_step': self.global_step
+            'epoch': self.current_epoch + 1,
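For context, a minimal sketch of the loading side this change pairs with (not the PR's actual restore code; the attribute and key names simply follow the diff above). Because 'epoch' is now saved as current_epoch + 1, the resume path can assign the value directly and continue with the next epoch instead of rerunning the one that just finished:

```python
def restore_training_state(trainer, checkpoint):
    # Sketch only: the saved value already points at the next epoch to run,
    # so no extra increment is needed on load.
    trainer.current_epoch = checkpoint['epoch']
    trainer.global_step = checkpoint['global_step']
```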
The problem is that a checkpoint can be saved not only at the end of a training epoch. For example, if you set val_check_interval=0.1 and training is interrupted after 0.15 of the training batches, you will continue from the second epoch even though only 10% of the first epoch was actually processed.
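To make the concern concrete, here is a small worked example with hypothetical numbers, assuming the checkpoint stores current_epoch + 1 as in the diff above:

```python
num_training_batches = 1000
val_check_interval = 0.1     # validation (and possible checkpointing) every 100 batches
current_epoch = 0            # still inside the first epoch
batches_processed = 150      # training interrupted after 15% of the epoch

# The last checkpoint was written at batch 100 (10% of the epoch),
# but dump_checkpoint records the *next* epoch:
saved_epoch = current_epoch + 1
print(saved_epoch)           # 1 -> resuming starts at epoch 1 and silently
                             #      skips the remaining ~90% of epoch 0
```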
That's a good point; I'll look into better ways of dealing with this.
Looking at it, I think resuming in such a way is probably a new feature, which should go in a new PR and have more discussion there. I've updated the test to allow for testing mid-epoch checkpointing / resuming, but commented out the checkpoints from mid-epochs so that it passes with the current resume method. When it is properly implemented we can use the full test.
For now I've added a warning if you load a mid-epoch checkpoint to alert the user that it will be unreliable to resume.
# Deals with peculiarity of different global step for odd vs even num_training_batches
if abs((self.global_step + 1) % self.num_training_batches) > 1:
Is this known?
The abs looks suspicious...
The check should be != 0 if the global step matched for odd vs even numbers of training steps, so if that's worked out then we shouldn't need it. Unless it's intended, in which case I think just (self.global_step + 1) % self.num_training_batches is sufficient, since the global step matches num_training_batches in the even case.
Thinking about it, I should probably just remove the abs; it is fine without it.
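A minimal sketch of the arithmetic under discussion, assuming global_step is 0-indexed and incremented once per optimizer step (my assumption, not stated in the PR):

```python
def at_epoch_boundary(global_step, num_training_batches):
    # global_step + 1 is the number of steps completed so far; at an epoch
    # boundary it is an exact multiple of num_training_batches, whether that
    # count is odd or even. Since the modulo of non-negative ints is never
    # negative, abs() adds nothing and a plain == 0 / != 0 test is enough.
    return (global_step + 1) % num_training_batches == 0

print(at_epoch_boundary(4, 5))   # True  -> end of first epoch, 5 batches/epoch
print(at_epoch_boundary(5, 6))   # True  -> end of first epoch, 6 batches/epoch
print(at_epoch_boundary(6, 5))   # False -> mid-epoch
```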
I've removed the abs and made sure it handles accumulated batches properly. Can you think of anything else that might change the global step?
Currently the only test that throws this warning is test_restore_models/test_dp_resume, since it changes the percentage of training data used when resuming in a new trainer. Not much we can do about that.
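A hedged sketch of what an accumulation-aware version of the check could look like (the names mirror the Trainer attributes shown above; the exact expression merged in the PR may differ):

```python
import warnings

def warn_if_mid_epoch(global_step, num_training_batches, accumulate_grad_batches=1):
    # With gradient accumulation, one optimizer step covers several batches,
    # so the number of optimizer steps per epoch shrinks accordingly.
    steps_per_epoch = num_training_batches // accumulate_grad_batches
    if steps_per_epoch and (global_step + 1) % steps_per_epoch != 0:
        warnings.warn(
            "Resuming from a mid-epoch checkpoint; training will restart from "
            "the next epoch, so results may not be reproducible."
        )
```

The leading steps_per_epoch guard also covers runs with zero training batches, which comes up below.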
LGTM 🚀 just check the update on callbacks from #776
@MattPainter01 welcome!
My bad, I hadn't run the DDP tests, some of which have 0 training batches. I've put in a check to skip the warning when we have no training batches. It passes all the tests on my machine now, except the SLURM tests, which I can't run.
* Properly restore current epoch and global step on resume
* Add test
* Move increment to saving rather than loading
* Fix other tests that refer to current epoch
* Formatting
* Add warning for mid-epoch resuming
* Formatting
* Fix warning check for accumulated batches
* Add variable to init
* Formatting
* Add check for 0 training steps
* Make check more readable
Fixes #850