Problem resuming training with RectifiedAdam+Lookahead (Ranger) #1911

Closed
gtg740x opened this issue Jun 4, 2020 · 4 comments
@gtg740x

gtg740x commented Jun 4, 2020

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 18.04
  • TensorFlow version and how it was installed (source or binary): TensorFlow 2.1 from the official TensorFlow Docker container
  • TensorFlow-Addons version and how it was installed (source or binary): 0.9.1
  • Python version: 3.6.9
  • Is GPU used? (yes/no): Yes.

Describe the bug

If I train a model using the Ranger scheme (a RectifiedAdam optimizer wrapped in a Lookahead optimizer), I cannot interrupt and resume training normally. With exactly the same code but a standard Adam optimizer, training resumes as expected; with the Ranger scheme, it does not.

When training resumes, the model is restored to the accuracy it had when paused. Once training steps resume, however, accuracy drops for many steps before slowly climbing back to the upward trend it was following before the pause. The result is much slower and choppier convergence than in a run that is never paused.

If the Ranger setup is used for an uninterrupted run, training progresses smoothly and converges to the expected accuracy in the expected number of steps.

Provide a reproducible test case that is the bare minimum necessary to generate the problem.

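import tensorflow as tf
import tensorflow_addons as tfa

# learning_rate, min_learning_rate, max_train_steps, checkpoint_path and
# train_step() are defined elsewhere in the training script.

# Set up a model
model = tf.keras.Model(...)

# Set up the optimizer: RectifiedAdam wrapped in Lookahead ("Ranger")
optimizer = tfa.optimizers.RectifiedAdam(learning_rate=learning_rate,
                                         total_steps=max_train_steps,
                                         warmup_proportion=0.1,
                                         min_lr=min_learning_rate)
optimizer = tfa.optimizers.Lookahead(optimizer, sync_period=6, slow_step_size=0.5)

# Create the checkpoint manager
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=2)

total_steps = 0
while total_steps < max_train_steps:
    total_steps += 1
    # Run a generic train step on the model using the optimizer
    train_step()
    # Save a checkpoint every 1000 steps, numbered by the global step
    if total_steps % 1000 == 0:
        ckpt_save_path = ckpt_manager.save(checkpoint_number=total_steps)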
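The RectifiedAdam arguments in the setup (total_steps, warmup_proportion, min_lr) make the learning rate a function of the optimizer's step counter. The pure-Python sketch below mirrors that schedule as we read the TFA source for the total_steps > 0 case: linear warmup over the first total_steps * warmup_proportion steps, then linear decay to min_lr. It is an illustration under that assumption, not TFA's actual code, and the numbers in the usage lines are hypothetical.

```python
def radam_lr(step, learning_rate, total_steps, warmup_proportion, min_lr):
    """Learning rate at a 1-based `step`, assuming the warmup/decay rule of
    tfa.optimizers.RectifiedAdam when total_steps > 0 (our reading of it)."""
    warmup_steps = total_steps * warmup_proportion
    decay_steps = max(total_steps - warmup_steps, 1)
    decay_rate = (min_lr - learning_rate) / decay_steps
    if step <= warmup_steps:
        # Linear warmup from near zero up to learning_rate.
        return learning_rate * (step / warmup_steps)
    # Linear decay towards min_lr; held at min_lr after total_steps.
    return learning_rate + decay_rate * min(step - warmup_steps, decay_steps)


# Hypothetical values: 10,000 total steps, 10% warmup, lr 1e-3, min_lr 1e-5.
for step in (1, 500, 1500, 10000):
    print(step, radam_lr(step, 1e-3, 10000, 0.1, 1e-5))
```

With these hypothetical values, a run resumed at step 1500 should continue at about 9.45e-4. If the step counter were to restart from zero on resume (the traceback later in this thread reports `optimizer.iter` as unresolved, although that alone does not prove the counter resets), the first resumed step would instead use 1e-6 and the whole warmup would replay, which would resemble the "warm-up again" behaviour reported here.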

Run training and stop it at a global step greater than 1000 but less than max_train_steps.

Then I try to resume training from the saved checkpoint:

# Set up a model:
model = tf.keras.Model(...)

# Set up the optimizer:
optimizer = tfa.optimizers.RectifiedAdam(learning_rate=learning_rate,
                                         total_steps=max_train_steps,
                                         warmup_proportion=0.1,
                                         min_lr=min_learning_rate)
optimizer = tfa.optimizers.Lookahead(optimizer, sync_period=6, slow_step_size=0.5)

# Restore the model and optimizer state from the latest checkpoint:
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=2)
status = ckpt.restore(ckpt_manager.latest_checkpoint)

When I then re-enter the training loop above, with total_steps starting at the restored number of steps, the model comes back at its previous accuracy. However, as soon as training steps resume, accuracy dips immediately, as if the optimizer had to "warm up" again. Could this be caused by the Lookahead slow weights?
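Lookahead keeps a "slow" copy of every weight: the fast weights take sync_period inner-optimizer steps, then the slow weights move slow_step_size of the way toward them and the fast weights are reset to the slow ones. The pure-Python sketch below illustrates why a slow weight that does not match the restored model could produce the dip described above. It is a toy under stated assumptions: plain SGD stands in for RectifiedAdam, the weight is a scalar with loss (w - 3)**2, the sync rule on the 1-based step mirrors our reading of TFA's Lookahead, and the zero-valued slow weight in the second call is a hypothetical worst case, not a confirmed description of what TFA does with an unrestored slot.

```python
def lookahead_sgd(w, slow, grad_fn, steps, start_step=0, lr=0.1,
                  sync_period=6, slow_step_size=0.5):
    """Run `steps` Lookahead iterations around plain SGD on a scalar weight.

    `start_step` is the step counter restored from a checkpoint. A sync
    happens whenever the 1-based step index is a multiple of sync_period.
    Returns the updated (fast, slow) pair.
    """
    for i in range(steps):
        step = start_step + i + 1
        w = w - lr * grad_fn(w)  # inner ("fast") optimizer update
        if step % sync_period == 0:
            # Move the slow weight part of the way toward the fast weight,
            # then reset the fast weight to the new slow weight.
            slow = slow + slow_step_size * (w - slow)
            w = slow
    return w, slow


def grad(w):
    # Gradient of the toy loss (w - 3)**2, minimised at w = 3.
    return 2.0 * (w - 3.0)


# Resume at a converged weight w = 3.0 from step 1200.
# Slow weight restored correctly: the model stays put.
print(lookahead_sgd(3.0, 3.0, grad, steps=6, start_step=1200))
# Slow weight lost and (hypothetically) zero: the sync at step 1206
# pulls the weight halfway to zero.
print(lookahead_sgd(3.0, 0.0, grad, steps=6, start_step=1200))
```

With sync_period=6 and slow_step_size=0.5, as in the report, the first call returns (3.0, 3.0) while the second returns (1.5, 1.5): the weight jumps away from the optimum at the first sync after resuming, and training then has to recover over the following steps.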


gtg740x changed the title from "Problem resuming training with RAdam+Lookahead (RANGER)" to "Problem resuming training with RectifiedAdam+Lookahead (Ranger)" on Jun 4, 2020
@steven5401

I have exactly the same problem. After calling

status = ckpt.restore(ckpt_manager.latest_checkpoint)
status.assert_consumed()

I get the following log:

Traceback (most recent call last):
  File "main.py", line 69, in <module>
    app.run(main)
  File "/home/euphoria_yang/.conda/envs/tf2/lib/python3.7/site-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/home/euphoria_yang/.conda/envs/tf2/lib/python3.7/site-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "main.py", line 66, in main
    monitor.run()
  File "/home/euphoria_yang/Desktop/plane_detection/mvs-tf2/monitor/trainer.py", line 177, in run
    restore_status.assert_consumed()
  File "/home/euphoria_yang/.conda/envs/tf2/lib/python3.7/site-packages/tensorflow/python/training/tracking/util.py", line 722, in assert_consumed
    .format(pretty_printer.node_names[node_id], node))
AssertionError: Unresolved object in checkpoint (root).optimizer.iter: attributes {
  name: "VARIABLE_VALUE"
  checkpoint_key: "optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE"
}

WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
W0617 10:41:08.183369 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-1.gamma2
W0617 10:41:08.183589 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-1.gamma2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-1.ell
W0617 10:41:08.183701 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-1.ell
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-1.sigma2
W0617 10:41:08.183793 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-1.sigma2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv1.conv.conv.kernel
W0617 10:41:08.183856 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv1.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv1.conv.norm.gamma
W0617 10:41:08.183921 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv1.conv.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv1.conv.norm.beta
W0617 10:41:08.183984 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv1.conv.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv1.downscale.conv.kernel
W0617 10:41:08.184047 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv1.downscale.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv1.downscale.norm.gamma
W0617 10:41:08.184109 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv1.downscale.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv1.downscale.norm.beta
W0617 10:41:08.184170 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv1.downscale.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv2.conv.conv.kernel
W0617 10:41:08.184232 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv2.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv2.conv.norm.gamma
W0617 10:41:08.184294 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv2.conv.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv2.conv.norm.beta
W0617 10:41:08.184356 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv2.conv.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv2.downscale.conv.kernel
W0617 10:41:08.184422 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv2.downscale.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv2.downscale.norm.gamma
W0617 10:41:08.184482 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv2.downscale.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv2.downscale.norm.beta
W0617 10:41:08.184522 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv2.downscale.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv3.conv.conv.kernel
W0617 10:41:08.184558 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv3.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv3.conv.norm.gamma
W0617 10:41:08.184594 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv3.conv.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv3.conv.norm.beta
W0617 10:41:08.184630 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv3.conv.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv3.downscale.conv.kernel
W0617 10:41:08.184666 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv3.downscale.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv3.downscale.norm.gamma
W0617 10:41:08.184701 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv3.downscale.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv3.downscale.norm.beta
W0617 10:41:08.184736 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv3.downscale.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv4.conv.conv.kernel
W0617 10:41:08.184772 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv4.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv4.conv.norm.gamma
W0617 10:41:08.184807 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv4.conv.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv4.conv.norm.beta
W0617 10:41:08.184842 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv4.conv.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv4.downscale.conv.kernel
W0617 10:41:08.184877 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv4.downscale.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv4.downscale.norm.gamma
W0617 10:41:08.184912 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv4.downscale.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv4.downscale.norm.beta
W0617 10:41:08.184947 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv4.downscale.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv5.conv.conv.kernel
W0617 10:41:08.184982 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv5.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv5.conv.norm.gamma
W0617 10:41:08.185016 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv5.conv.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv5.conv.norm.beta
W0617 10:41:08.185051 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv5.conv.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv5.downscale.conv.kernel
W0617 10:41:08.185087 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv5.downscale.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv5.downscale.norm.gamma
W0617 10:41:08.185122 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv5.downscale.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv5.downscale.norm.beta
W0617 10:41:08.185157 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-0.conv5.downscale.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv5.conv.conv.kernel
W0617 10:41:08.185191 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv5.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv5.conv.norm.gamma
W0617 10:41:08.185226 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv5.conv.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv5.conv.norm.beta
W0617 10:41:08.185261 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv5.conv.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv5.conv.conv.kernel
W0617 10:41:08.185296 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv5.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv5.conv.norm.gamma
W0617 10:41:08.185330 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv5.conv.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv5.conv.norm.beta
W0617 10:41:08.185365 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv5.conv.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv4.conv.conv.kernel
W0617 10:41:08.185401 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv4.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv4.conv.norm.gamma
W0617 10:41:08.185436 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv4.conv.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv4.conv.norm.beta
W0617 10:41:08.185471 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv4.conv.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv4.conv.conv.kernel
W0617 10:41:08.185506 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv4.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv4.conv.norm.gamma
W0617 10:41:08.185541 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv4.conv.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv4.conv.norm.beta
W0617 10:41:08.185576 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv4.conv.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp4.conv.conv.kernel
W0617 10:41:08.185611 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp4.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp4.conv.conv.bias
W0617 10:41:08.185646 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp4.conv.conv.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv3.conv.conv.kernel
W0617 10:41:08.185681 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv3.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv3.conv.norm.gamma
W0617 10:41:08.185716 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv3.conv.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv3.conv.norm.beta
W0617 10:41:08.185751 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv3.conv.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv3.conv.conv.kernel
W0617 10:41:08.185786 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv3.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv3.conv.norm.gamma
W0617 10:41:08.185821 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv3.conv.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv3.conv.norm.beta
W0617 10:41:08.185857 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv3.conv.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp3.conv.conv.kernel
W0617 10:41:08.185892 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp3.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp3.conv.conv.bias
W0617 10:41:08.185927 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp3.conv.conv.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv2.conv.conv.kernel
W0617 10:41:08.185962 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv2.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv2.conv.norm.gamma
W0617 10:41:08.185997 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv2.conv.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv2.conv.norm.beta
W0617 10:41:08.186032 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv2.conv.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv2.conv.conv.kernel
W0617 10:41:08.186066 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv2.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv2.conv.norm.gamma
W0617 10:41:08.186101 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv2.conv.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv2.conv.norm.beta
W0617 10:41:08.186137 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv2.conv.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp2.conv.conv.kernel
W0617 10:41:08.186172 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp2.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp2.conv.conv.bias
W0617 10:41:08.186207 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp2.conv.conv.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv1.conv.conv.kernel
W0617 10:41:08.186242 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv1.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv1.conv.norm.gamma
W0617 10:41:08.186277 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv1.conv.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv1.conv.norm.beta
W0617 10:41:08.186311 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.upconv1.conv.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv1.conv.conv.kernel
W0617 10:41:08.186347 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv1.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv1.conv.norm.gamma
W0617 10:41:08.186381 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv1.conv.norm.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv1.conv.norm.beta
W0617 10:41:08.186416 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.iconv1.conv.norm.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp1.conv.conv.kernel
W0617 10:41:08.186451 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp1.conv.conv.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp1.conv.conv.bias
W0617 10:41:08.186486 140164126582528 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'slow' for (root).model.layer_with_weights-2.disp1.conv.conv.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
W0617 10:41:08.193641 140164126582528 util.py:152] A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.

It seems that Lookahead does not restore its "slow" slot variables correctly. @gtg740x, can you check the output of status.assert_consumed()?

@bhack
Contributor

bhack commented Jul 22, 2020

/cc @CyberZHG

@bhack
Contributor

bhack commented Aug 27, 2020

Can you check whether the problem persists with #2126?

@seanpmorgan
Member

TensorFlow Addons is transitioning to a minimal maintenance and release mode. New features will not be added to this repository. For more information, please see our public messaging on this decision:
TensorFlow Addons Wind Down

Please consider sending feature requests / contributions to other repositories in the TF community with similar charters to TFA:
Keras
Keras-CV
Keras-NLP
