
Commit 5655aa9

andrewcoh and Ervin T. authored
Prevent init normalize on --resume (#4463)
Co-authored-by: Ervin T. <[email protected]>
1 parent 4992283 · commit 5655aa9

File tree

2 files changed: +67 -1 lines changed


Diff for: ml-agents/mlagents/trainers/model_saver/tf_model_saver.py

+2

@@ -99,6 +99,8 @@ def initialize_or_load(self, policy: Optional[TFPolicy] = None) -> None:
     def _load_graph(
         self, policy: TFPolicy, model_path: str, reset_global_steps: bool = False
     ) -> None:
+        # This prevents the normalizer init from executing on load
+        policy.first_normalization_update = False
         with policy.graph.as_default():
             logger.info(f"Loading model from {model_path}.")
             ckpt = tf.train.get_checkpoint_state(model_path)
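For context on why this one-line change matters: the policy keeps running observation statistics (normalization_steps, running_mean, running_variance), and the first call to update_normalization seeds them from the first batch of observations rather than blending with zero-initialized values. On --resume, the checkpoint restored by _load_graph already carries valid statistics, so clearing first_normalization_update beforehand keeps that first-batch seeding from overwriting them. The sketch below is a minimal NumPy re-implementation of that pattern, not the library's actual TF ops; the attribute names mirror those seen in this commit's diffs, while the Welford-style update math and the load_checkpoint_stats helper are assumptions for illustration.

import numpy as np


class RunningNormalizerSketch:
    """Illustrative stand-in for the policy's running mean/variance tracking."""

    def __init__(self, size: int) -> None:
        self.normalization_steps = 0
        self.running_mean = np.zeros(size, dtype=np.float64)
        # Kept as a running sum of squared deviations; divide by steps when used
        self.running_variance = np.zeros(size, dtype=np.float64)
        # Mirrors policy.first_normalization_update from the diff above
        self.first_normalization_update = True

    def update_normalization(self, vector_obs: np.ndarray) -> None:
        if self.first_normalization_update:
            # Seed the statistics from the very first batch of observations
            self.normalization_steps = vector_obs.shape[0]
            self.running_mean = vector_obs.mean(axis=0)
            self.running_variance = ((vector_obs - self.running_mean) ** 2).sum(axis=0)
            self.first_normalization_update = False
            return
        # Incremental (Welford-style) update for every later batch
        for x in vector_obs:
            self.normalization_steps += 1
            delta = x - self.running_mean
            self.running_mean = self.running_mean + delta / self.normalization_steps
            self.running_variance = self.running_variance + delta * (x - self.running_mean)

    def load_checkpoint_stats(self, steps, mean, variance) -> None:
        # Hypothetical helper showing what _load_graph relies on: the flag must
        # already be False so a later update_normalization call does not
        # re-seed over the restored values.
        self.first_normalization_update = False
        self.normalization_steps = steps
        self.running_mean = mean
        self.running_variance = variance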

Diff for: ml-agents/mlagents/trainers/tests/test_saver.py

+65 -1

@@ -8,10 +8,11 @@
 from mlagents.tf_utils import tf
 from mlagents.trainers.model_saver.tf_model_saver import TFModelSaver
 from mlagents.trainers import __version__
-from mlagents.trainers.settings import TrainerSettings
+from mlagents.trainers.settings import TrainerSettings, NetworkSettings
 from mlagents.trainers.policy.tf_policy import TFPolicy
 from mlagents.trainers.tests import mock_brain as mb
 from mlagents.trainers.tests.test_nn_policy import create_policy_mock
+from mlagents.trainers.tests.test_trajectory import make_fake_trajectory
 from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer


@@ -113,3 +114,66 @@ def test_checkpoint_conversion(tmpdir, rnn, visual, discrete):
     model_saver.register(policy)
     model_saver.save_checkpoint("Mock_Brain", 100)
     assert os.path.isfile(model_path + "/Mock_Brain-100.nn")
+
+
+# This is the normalizer test from test_nn_policy.py but with a load
+def test_normalizer_after_load(tmp_path):
+    behavior_spec = mb.setup_test_behavior_specs(
+        use_discrete=True, use_visual=False, vector_action_space=[2], vector_obs_space=1
+    )
+    time_horizon = 6
+    trajectory = make_fake_trajectory(
+        length=time_horizon,
+        max_step_complete=True,
+        observation_shapes=[(1,)],
+        action_space=[2],
+    )
+    # Change half of the obs to 0
+    for i in range(3):
+        trajectory.steps[i].obs[0] = np.zeros(1, dtype=np.float32)
+
+    trainer_params = TrainerSettings(network_settings=NetworkSettings(normalize=True))
+    policy = TFPolicy(0, behavior_spec, trainer_params)
+
+    trajectory_buffer = trajectory.to_agentbuffer()
+    policy.update_normalization(trajectory_buffer["vector_obs"])
+
+    # Check that the running mean and variance is correct
+    steps, mean, variance = policy.sess.run(
+        [policy.normalization_steps, policy.running_mean, policy.running_variance]
+    )
+
+    assert steps == 6
+    assert mean[0] == 0.5
+    assert variance[0] / steps == pytest.approx(0.25, abs=0.01)
+    # Save ckpt and load into another policy
+    path1 = os.path.join(tmp_path, "runid1")
+    model_saver = TFModelSaver(trainer_params, path1)
+    model_saver.register(policy)
+    mock_brain_name = "MockBrain"
+    model_saver.save_checkpoint(mock_brain_name, 6)
+    assert len(os.listdir(tmp_path)) > 0
+    policy1 = TFPolicy(0, behavior_spec, trainer_params)
+    model_saver = TFModelSaver(trainer_params, path1, load=True)
+    model_saver.register(policy1)
+    model_saver.initialize_or_load(policy1)
+
+    # Make another update to new policy, this time with all 1's
+    time_horizon = 10
+    trajectory = make_fake_trajectory(
+        length=time_horizon,
+        max_step_complete=True,
+        observation_shapes=[(1,)],
+        action_space=[2],
+    )
+    trajectory_buffer = trajectory.to_agentbuffer()
+    policy1.update_normalization(trajectory_buffer["vector_obs"])
+
+    # Check that the running mean and variance is correct
+    steps, mean, variance = policy1.sess.run(
+        [policy1.normalization_steps, policy1.running_mean, policy1.running_variance]
+    )
+
+    assert steps == 16
+    assert mean[0] == 0.8125
+    assert variance[0] / steps == pytest.approx(0.152, abs=0.01)
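As a sanity check on the expected values in the asserts above, the numbers follow directly from the observation stream, assuming make_fake_trajectory fills vector observations with 1.0 (which the test relies on when it zeroes half of the first trajectory). A quick back-of-the-envelope computation:

import numpy as np

# First update: 6 observations, 3 zeroed by the test, the rest assumed to be 1.0
first_batch = np.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
print(first_batch.mean())  # 0.5      -> mean[0]
print(first_batch.var())   # 0.25     -> variance[0] / steps

# Second update after the load: 10 more observations, all 1.0
all_obs = np.concatenate([first_batch, np.ones(10)])
print(all_obs.mean())      # 0.8125   -> mean[0]
print(all_obs.var())       # ~0.1523  -> variance[0] / steps (approx 0.152)

The point of the test is that the second set of assertions only holds if the statistics restored from the checkpoint keep accumulating across the load instead of being re-seeded from the post-load batch, which is exactly what the tf_model_saver.py change guarantees.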
