@@ -76,6 +76,7 @@ def fit(self, env=None, logging_path='', silent=False, verbose=True):
             print('env should be gym.Env')
             return
         self.last_checkpoint_time_step = 0
+        self.mean_reward = -10
         self.logdir = logging_path
         if isinstance(self.env, DummyVecEnv):
             self.env = self.env.envs[0]
@@ -85,6 +86,7 @@ def fit(self, env=None, logging_path='', silent=False, verbose=True):
         self.env = DummyVecEnv([lambda: self.env])
         self.agent.set_env(self.env)
         self.agent.learn(total_timesteps=self.iters, callback=self.callback)
+        return {"last_20_episodes_mean_reward": self.mean_reward}

     def eval(self, env):
         """
@@ -108,7 +110,7 @@ def eval(self, env):
             sum_of_rewards += rewards
             if dones:
                 break
-        return sum_of_rewards
+        return {"rewards_collected": sum_of_rewards}

     def save(self, path):
         """
@@ -161,14 +163,14 @@ def callback(self, _locals, _globals):
         x, y = ts2xy(load_results(self.logdir), 'timesteps')

         if len(y) > 20:
-            mean_reward = np.mean(y[-20:])
+            self.mean_reward = np.mean(y[-20:])
         else:
             return True

         if x[-1] - self.last_checkpoint_time_step > self.checkpoint_after_iter:
             self.last_checkpoint_time_step = x[-1]
             check_point_path = Path(self.logdir,
-                                    'checkpoint_save' + str(x[-1]) + 'with_mean_rew' + str(mean_reward))
+                                    'checkpoint_save' + str(x[-1]) + 'with_mean_rew' + str(self.mean_reward))
             self.save(str(check_point_path))

         return True
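
For context, a minimal usage sketch of the interface after this change. The learner class name (Learner) and its constructor arguments are placeholders, not part of this diff; the diff itself only changes fit() and eval() to return dictionaries instead of nothing and a bare reward sum, respectively.

    # Hypothetical usage sketch; Learner and its constructor arguments are
    # placeholders, only fit()/eval() return values reflect this diff.
    import gym

    env = gym.make("CartPole-v1")
    learner = Learner()

    train_stats = learner.fit(env=env, logging_path="./logs")
    # mean reward over the last 20 training episodes (-10 if fewer than 20 completed)
    print(train_stats["last_20_episodes_mean_reward"])

    eval_stats = learner.eval(env)
    # total reward collected during the evaluation rollout
    print(eval_stats["rewards_collected"])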