Commit e6a6e97

eval/fit return dictionary
1 parent cc1fbac commit e6a6e97

2 files changed: +6 -4 lines

Diff for: src/opendr/planning/end_to_end_planning/e2e_planning_learner.py (+5 -3)

@@ -76,6 +76,7 @@ def fit(self, env=None, logging_path='', silent=False, verbose=True):
             print('env should be gym.Env')
             return
         self.last_checkpoint_time_step = 0
+        self.mean_reward = -10
         self.logdir = logging_path
         if isinstance(self.env, DummyVecEnv):
             self.env = self.env.envs[0]
@@ -85,6 +86,7 @@ def fit(self, env=None, logging_path='', silent=False, verbose=True):
         self.env = DummyVecEnv([lambda: self.env])
         self.agent.set_env(self.env)
         self.agent.learn(total_timesteps=self.iters, callback=self.callback)
+        return {"last_20_episodes_mean_reward": self.mean_reward}

     def eval(self, env):
         """
@@ -108,7 +110,7 @@ def eval(self, env):
             sum_of_rewards += rewards
             if dones:
                 break
-        return sum_of_rewards
+        return {"rewards_collected": sum_of_rewards}

     def save(self, path):
         """
@@ -161,14 +163,14 @@ def callback(self, _locals, _globals):
         x, y = ts2xy(load_results(self.logdir), 'timesteps')

         if len(y) > 20:
-            mean_reward = np.mean(y[-20:])
+            self.mean_reward = np.mean(y[-20:])
         else:
             return True

         if x[-1] - self.last_checkpoint_time_step > self.checkpoint_after_iter:
             self.last_checkpoint_time_step = x[-1]
             check_point_path = Path(self.logdir,
-                                    'checkpoint_save' + str(x[-1]) + 'with_mean_rew' + str(mean_reward))
+                                    'checkpoint_save' + str(x[-1]) + 'with_mean_rew' + str(self.mean_reward))
             self.save(str(check_point_path))

         return True
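With this commit, fit() and eval() report their results as dictionaries instead of a bare value: fit() returns {"last_20_episodes_mean_reward": ...} (filled in by the checkpoint callback, and left at its initial value of -10 if fewer than 20 episodes were logged) and eval() returns {"rewards_collected": ...}. Below is a minimal sketch of how calling code might consume the new return values; the learner and environment construction are placeholders not shown in this diff, and only the dictionary keys and the fit()/eval() signatures come from the commit:

    # Sketch only: `learner` is assumed to be an instance of the learner class defined in
    # e2e_planning_learner.py and `env` a gym.Env; neither is constructed in this diff.

    # fit() now reports training progress as a dictionary instead of returning nothing.
    train_stats = learner.fit(env=env, logging_path='./logs')
    print(train_stats["last_20_episodes_mean_reward"])  # mean reward over the last 20 logged episodes

    # eval() now wraps the summed episode reward in a dictionary.
    eval_stats = learner.eval(env)
    print(eval_stats["rewards_collected"])  # total reward collected during the evaluation episode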

Diff for: tests/sources/tools/planning/end_to_end_planning/test_end_to_end_planning.py (+1 -1)

@@ -58,7 +58,7 @@ def test_infer(self):
         self.assertTrue((action < self.env.action_space.n), "Actions above discrete action space dimensions")

     def test_eval(self):
-        episode_reward = self.learner.eval(self.env)
+        episode_reward = self.learner.eval(self.env)["rewards_collected"]
         self.assertTrue((episode_reward > -100), "Episode reward cannot be lower than -100")
         self.assertTrue((episode_reward < 100), "Episode reward cannot pass 100")
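The updated test only exercises eval(); a hypothetical companion test for the new fit() return value could follow the same pattern (test name, logging path, and assertion below are illustrative and not part of this commit; only the dictionary key comes from the diff):

    # Hypothetical addition to the same TestCase; not part of this commit.
    def test_fit(self):
        results = self.learner.fit(env=self.env, logging_path='./test_logs')  # path is a placeholder
        self.assertIn("last_20_episodes_mean_reward", results,
                      "fit() should return the mean reward of the last 20 episodes")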
