Start accumulate gradients schedule at epoch 0 (continued) #2513

Merged: 7 commits, Jul 9, 2020.
Showing changes from 5 commits.
pytorch_lightning/trainer/training_tricks.py (2 changes: 1 addition & 1 deletion)

@@ -99,7 +99,7 @@ def configure_accumulated_gradients(self, accumulate_grad_batches):
         if isinstance(accumulate_grad_batches, dict):
             self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
         elif isinstance(accumulate_grad_batches, int):
-            schedule = {1: accumulate_grad_batches}
+            schedule = {0: accumulate_grad_batches}
             self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
         else:
             raise TypeError("Gradient accumulation supports only int and dict types")
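Note on the change: with an integer `accumulate_grad_batches`, the generated schedule now starts at epoch 0, so accumulation is active from the very first epoch; the old `{1: ...}` expansion left epoch 0 running at the default factor of 1. A minimal sketch of the epoch-to-factor lookup, assuming the scheduler applies the entry with the largest starting epoch not exceeding the current one (the `resolve_factor` helper is hypothetical, for illustration only, not the library's code):

    # Hypothetical helper: the factor for an epoch is the value of the
    # largest schedule key that does not exceed that epoch.
    def resolve_factor(schedule: dict, epoch: int, default: int = 1) -> int:
        factor = default
        for start_epoch in sorted(schedule):
            if epoch >= start_epoch:
                factor = schedule[start_epoch]
        return factor

    # Before the fix an int expanded to {1: n}, so epoch 0 fell back to 1.
    assert resolve_factor({1: 4}, epoch=0) == 1
    # After the fix it expands to {0: n}: accumulation starts immediately.
    assert resolve_factor({0: 4}, epoch=0) == 4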
tests/trainer/test_lr_finder.py (2 changes: 1 addition & 1 deletion)

@@ -154,7 +154,7 @@ def test_accumulation_and_early_stopping(tmpdir):
         'Learning rate was not altered after running learning rate finder'
     assert len(lrfinder.results['lr']) == 100, \
         'Early stopping for learning rate finder did not work'
-    assert lrfinder._total_batch_idx == 190, \
+    assert lrfinder._total_batch_idx == 100 * 2, \
         'Accumulation parameter did not work'
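Note on the updated expectation: the test asserts that the learning rate finder takes 100 steps, and with `accumulate_grad_batches=2` now active from epoch 0 each step consumes two batches, giving 100 * 2 = 200. The old value of 190 is consistent with the pre-fix behaviour in which the first epoch ran unaccumulated; the 10-batch first epoch below is an assumption used only to show the arithmetic:

    # Arithmetic behind the new expectation (names are illustrative).
    num_steps = 100      # lr finder steps, per the assertion above
    accumulation = 2     # batches consumed per optimizer step
    assert num_steps * accumulation == 200  # the new `100 * 2`

    # Plausible reading of the old 190: a 10-batch epoch 0 at factor 1,
    # then the remaining 90 steps at factor 2.
    first_epoch_steps = 10
    assert first_epoch_steps * 1 + (num_steps - first_epoch_steps) * 2 == 190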
tests/trainer/test_trainer.py (59 changes: 31 additions & 28 deletions)

@@ -103,7 +103,8 @@ def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
     model_2.eval()


-def test_gradient_accumulation_scheduling(tmpdir):
+@pytest.mark.parametrize('schedule_expected', [({1: 2, 3: 4}, [1, 2, 4]), (3, [3, 3, 3]), (4, [4, 4, 4])])
+def test_gradient_accumulation_scheduling(tmpdir, schedule_expected):
     """
     Test grad accumulation by the freq of optimizer updates
     """

@@ -123,59 +124,61 @@ def test_gradient_accumulation_scheduling(tmpdir):
     with pytest.raises(TypeError):
         assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5})

+    model = EvalModelTemplate()
+    schedule = schedule_expected[0]
+    expected = schedule_expected[1]
+
+    trainer = Trainer(accumulate_grad_batches=schedule,
+                      limit_train_batches=0.8,
+                      limit_val_batches=0.8,
+                      max_epochs=4,
+                      default_root_dir=tmpdir)

     # test optimizer call freq matches scheduler
-    def _optimizer_step(self, epoch, batch_idx, optimizer,
-                        optimizer_idx, second_order_closure=None):
+    def _optimizer_step(epoch, batch_idx, optimizer, optimizer_idx,
+                        second_order_closure=None, on_tpu=False,
+                        using_native_amp=False, using_lbfgs=False):
         # only test the first 12 batches in epoch
         if batch_idx < 12:
             if epoch == 0:
                 # reset counter when starting epoch
-                if batch_idx == 0:
-                    self.prev_called_batch_idx = 0
+                if batch_idx == expected[0] - 1:
+                    model.prev_called_batch_idx = expected[0] - 1

                     # use this opportunity to test once
-                    assert self.trainer.accumulate_grad_batches == 1
+                    assert trainer.accumulate_grad_batches == expected[0]

-                assert batch_idx == self.prev_called_batch_idx
-                self.prev_called_batch_idx += 1
+                assert batch_idx == model.prev_called_batch_idx
+                model.prev_called_batch_idx += expected[0]

             elif 1 <= epoch <= 2:
                 # reset counter when starting epoch
-                if batch_idx == 1:
-                    self.prev_called_batch_idx = 1
+                if batch_idx == expected[1] - 1:
+                    model.prev_called_batch_idx = expected[1] - 1

                     # use this opportunity to test once
-                    assert self.trainer.accumulate_grad_batches == 2
+                    assert trainer.accumulate_grad_batches == expected[1]

-                assert batch_idx == self.prev_called_batch_idx
-                self.prev_called_batch_idx += 2
+                assert batch_idx == model.prev_called_batch_idx
+                model.prev_called_batch_idx += expected[1]

             else:
-                if batch_idx == 3:
-                    self.prev_called_batch_idx = 3
+                if batch_idx == expected[2] - 1:
+                    model.prev_called_batch_idx = expected[2] - 1

                     # use this opportunity to test once
-                    assert self.trainer.accumulate_grad_batches == 4
+                    assert trainer.accumulate_grad_batches == expected[2]

-                assert batch_idx == self.prev_called_batch_idx
-                self.prev_called_batch_idx += 3
+                assert batch_idx == model.prev_called_batch_idx
+                model.prev_called_batch_idx += expected[2]

         optimizer.step()

         # clear gradients
         optimizer.zero_grad()

-    model = EvalModelTemplate()
-    schedule = {1: 2, 3: 4}
-
-    trainer = Trainer(accumulate_grad_batches=schedule,
-                      limit_train_batches=0.1,
-                      limit_val_batches=0.1,
-                      max_epochs=2,
-                      default_root_dir=tmpdir)
-
     # for the test
-    trainer.optimizer_step = _optimizer_step
+    model.optimizer_step = _optimizer_step
+    model.prev_called_batch_idx = 0

     trainer.fit(model)
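Note on the parametrization: each case pairs a schedule with the accumulation factors the test expects in its three probed phases (epoch 0, epochs 1-2, epochs 3+). For `{1: 2, 3: 4}` that is [1, 2, 4], since epoch 0 has no entry and falls back to the default of 1, while the int cases 3 and 4 now expand to `{0: 3}` and `{0: 4}` and keep one factor throughout. A sketch of that mapping, reusing the hypothetical `resolve_factor` lookup from the note above:

    def resolve_factor(schedule: dict, epoch: int, default: int = 1) -> int:
        # Same hypothetical largest-key-not-exceeding lookup as above.
        factor = default
        for start_epoch in sorted(schedule):
            if epoch >= start_epoch:
                factor = schedule[start_epoch]
        return factor

    cases = [({1: 2, 3: 4}, [1, 2, 4]), (3, [3, 3, 3]), (4, [4, 4, 4])]
    for schedule, expected in cases:
        if isinstance(schedule, int):
            schedule = {0: schedule}  # the int expansion fixed in this PR
        # One probe per phase the test distinguishes: epochs 0, 1-2, 3+.
        assert [resolve_factor(schedule, e) for e in (0, 1, 3)] == expected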