14
14
"""
15
15
16
16
import dataclasses
17
- from typing import Any , Callable , Dict , List , Mapping
17
+ from typing import Any , Callable , Dict , List , Mapping , Optional
18
18
19
19
import optuna
20
20
import sacred
21
- import stable_baselines3 as sb3
22
21
23
22
import imitation .scripts .train_imitation
24
- import imitation .scripts .train_preference_comparisons
23
+ import imitation .scripts .train_preference_comparisons as train_pc_script
25
24
26
25
27
26
@dataclasses .dataclass
@@ -42,19 +41,27 @@ class RunSacredAsTrial:
42
41
suggest_config_updates : Callable [[optuna .Trial ], Mapping [str , Any ]]
43
42
44
43
"""Command name to pass to sacred.run."""
45
- command_name : str = None
44
+ command_name : Optional [ str ] = None
46
45
47
46
def __call__ (
48
- self , trial : optuna .Trial , run_options : Dict , extra_named_configs : List [str ]
47
+ self ,
48
+ trial : optuna .Trial ,
49
+ run_options : Dict ,
50
+ extra_named_configs : List [str ],
49
51
) -> float :
50
52
"""Run the sacred experiment and return the performance.
51
53
52
54
Args:
53
55
trial: The optuna trial to sample hyperparameters for.
54
56
run_options: Options to pass to sacred.run(options=).
55
57
extra_named_configs: Additional named configs to pass to sacred.run.
56
- """
57
58
59
+ Returns:
60
+ The performance of the trial.
61
+
62
+ Raises:
63
+ RuntimeError: If the trial fails.
64
+ """
58
65
config_updates = self .suggest_config_updates (trial )
59
66
named_configs = self .suggest_named_configs (trial ) + extra_named_configs
60
67
@@ -71,15 +78,16 @@ def __call__(
71
78
)
72
79
if result .status != "COMPLETED" :
73
80
raise RuntimeError (
74
- f"Trial failed with { result .fail_trace ()} and status { result .status } ."
81
+ f"Trial failed with { result .fail_trace ()} and status { result .status } ." ,
75
82
)
76
83
return result .result ["imit_stats" ]["monitor_return_mean" ]
77
84
78
85
79
- """A mapping from algorithm names to functions that run the algorithm as an optuna trial."""
86
+ """A mapping from algorithm names to functions that run the algorithm as an optuna
87
+ trial."""
80
88
objectives_by_algo = dict (
81
89
pc = RunSacredAsTrial (
82
- sacred_ex = imitation . scripts . train_preference_comparisons .train_preference_comparisons_ex ,
90
+ sacred_ex = train_pc_script .train_preference_comparisons_ex ,
83
91
suggest_named_configs = lambda _ : ["reward.reward_ensemble" ],
84
92
suggest_config_updates = lambda trial : {
85
93
"seed" : trial .number ,
@@ -88,61 +96,87 @@ def __call__(
88
96
"total_comparisons" : 1000 ,
89
97
"active_selection" : True ,
90
98
"active_selection_oversampling" : trial .suggest_int (
91
- "active_selection_oversampling" , 1 , 11
99
+ "active_selection_oversampling" ,
100
+ 1 ,
101
+ 11 ,
92
102
),
93
103
"comparison_queue_size" : trial .suggest_int (
94
- "comparison_queue_size" , 1 , 1001
104
+ "comparison_queue_size" ,
105
+ 1 ,
106
+ 1001 ,
95
107
), # upper bound determined by total_comparisons=1000
96
108
"exploration_frac" : trial .suggest_float ("exploration_frac" , 0.0 , 0.5 ),
97
109
"fragment_length" : trial .suggest_int (
98
- "fragment_length" , 1 , 1001
110
+ "fragment_length" ,
111
+ 1 ,
112
+ 1001 ,
99
113
), # trajectories are 1000 steps long
100
114
"gatherer_kwargs" : {
101
115
"temperature" : trial .suggest_float ("gatherer_temperature" , 0.0 , 2.0 ),
102
116
"discount_factor" : trial .suggest_float (
103
- "gatherer_discount_factor" , 0.95 , 1.0
117
+ "gatherer_discount_factor" ,
118
+ 0.95 ,
119
+ 1.0 ,
104
120
),
105
121
"sample" : trial .suggest_categorical ("gatherer_sample" , [True , False ]),
106
122
},
107
123
"initial_epoch_multiplier" : trial .suggest_float (
108
- "initial_epoch_multiplier" , 1 , 200.0
124
+ "initial_epoch_multiplier" ,
125
+ 1 ,
126
+ 200.0 ,
109
127
),
110
128
"initial_comparison_frac" : trial .suggest_float (
111
- "initial_comparison_frac" , 0.01 , 1.0
129
+ "initial_comparison_frac" ,
130
+ 0.01 ,
131
+ 1.0 ,
112
132
),
113
133
"num_iterations" : trial .suggest_int ("num_iterations" , 1 , 51 ),
114
134
"preference_model_kwargs" : {
115
135
"noise_prob" : trial .suggest_float (
116
- "preference_model_noise_prob" , 0.0 , 0.1
136
+ "preference_model_noise_prob" ,
137
+ 0.0 ,
138
+ 0.1 ,
117
139
),
118
140
"discount_factor" : trial .suggest_float (
119
- "preference_model_discount_factor" , 0.95 , 1.0
141
+ "preference_model_discount_factor" ,
142
+ 0.95 ,
143
+ 1.0 ,
120
144
),
121
145
},
122
146
"query_schedule" : trial .suggest_categorical (
123
- "query_schedule" , ["hyperbolic" , "constant" , "inverse_quadratic" ]
147
+ "query_schedule" ,
148
+ [
149
+ "hyperbolic" ,
150
+ "constant" ,
151
+ "inverse_quadratic" ,
152
+ ],
124
153
),
125
154
"trajectory_generator_kwargs" : {
126
155
"switch_prob" : trial .suggest_float ("tr_gen_switch_prob" , 0.1 , 1 ),
127
156
"random_prob" : trial .suggest_float ("tr_gen_random_prob" , 0.1 , 0.9 ),
128
157
},
129
158
"transition_oversampling" : trial .suggest_float (
130
- "transition_oversampling" , 0.9 , 2.0
159
+ "transition_oversampling" ,
160
+ 0.9 ,
161
+ 2.0 ,
131
162
),
132
163
"reward_trainer_kwargs" : {
133
164
"epochs" : trial .suggest_int ("reward_trainer_epochs" , 1 , 11 ),
134
165
},
135
166
"rl" : {
136
167
"rl_kwargs" : {
137
168
"ent_coef" : trial .suggest_float (
138
- "rl_ent_coef" , 1e-7 , 1e-3 , log = True
169
+ "rl_ent_coef" ,
170
+ 1e-7 ,
171
+ 1e-3 ,
172
+ log = True ,
139
173
),
140
174
},
141
175
},
142
176
},
143
177
),
144
178
pc_classic_control = RunSacredAsTrial (
145
- sacred_ex = imitation . scripts . train_preference_comparisons .train_preference_comparisons_ex ,
179
+ sacred_ex = train_pc_script .train_preference_comparisons_ex ,
146
180
suggest_named_configs = lambda _ : ["reward.reward_ensemble" ],
147
181
suggest_config_updates = lambda trial : {
148
182
"seed" : trial .number ,
@@ -151,54 +185,80 @@ def __call__(
151
185
"total_comparisons" : 1000 ,
152
186
"active_selection" : True ,
153
187
"active_selection_oversampling" : trial .suggest_int (
154
- "active_selection_oversampling" , 1 , 11
188
+ "active_selection_oversampling" ,
189
+ 1 ,
190
+ 11 ,
155
191
),
156
192
"comparison_queue_size" : trial .suggest_int (
157
- "comparison_queue_size" , 1 , 1001
193
+ "comparison_queue_size" ,
194
+ 1 ,
195
+ 1001 ,
158
196
), # upper bound determined by total_comparisons=1000
159
197
"exploration_frac" : trial .suggest_float ("exploration_frac" , 0.0 , 0.5 ),
160
198
"fragment_length" : trial .suggest_int (
161
- "fragment_length" , 1 , 201
199
+ "fragment_length" ,
200
+ 1 ,
201
+ 201 ,
162
202
), # trajectories are 1000 steps long
163
203
"gatherer_kwargs" : {
164
204
"temperature" : trial .suggest_float ("gatherer_temperature" , 0.0 , 2.0 ),
165
205
"discount_factor" : trial .suggest_float (
166
- "gatherer_discount_factor" , 0.95 , 1.0
206
+ "gatherer_discount_factor" ,
207
+ 0.95 ,
208
+ 1.0 ,
167
209
),
168
210
"sample" : trial .suggest_categorical ("gatherer_sample" , [True , False ]),
169
211
},
170
212
"initial_epoch_multiplier" : trial .suggest_float (
171
- "initial_epoch_multiplier" , 1 , 200.0
213
+ "initial_epoch_multiplier" ,
214
+ 1 ,
215
+ 200.0 ,
172
216
),
173
217
"initial_comparison_frac" : trial .suggest_float (
174
- "initial_comparison_frac" , 0.01 , 1.0
218
+ "initial_comparison_frac" ,
219
+ 0.01 ,
220
+ 1.0 ,
175
221
),
176
222
"num_iterations" : trial .suggest_int ("num_iterations" , 1 , 51 ),
177
223
"preference_model_kwargs" : {
178
224
"noise_prob" : trial .suggest_float (
179
- "preference_model_noise_prob" , 0.0 , 0.1
225
+ "preference_model_noise_prob" ,
226
+ 0.0 ,
227
+ 0.1 ,
180
228
),
181
229
"discount_factor" : trial .suggest_float (
182
- "preference_model_discount_factor" , 0.95 , 1.0
230
+ "preference_model_discount_factor" ,
231
+ 0.95 ,
232
+ 1.0 ,
183
233
),
184
234
},
185
235
"query_schedule" : trial .suggest_categorical (
186
- "query_schedule" , ["hyperbolic" , "constant" , "inverse_quadratic" ]
236
+ "query_schedule" ,
237
+ [
238
+ "hyperbolic" ,
239
+ "constant" ,
240
+ "inverse_quadratic" ,
241
+ ],
187
242
),
188
243
"trajectory_generator_kwargs" : {
189
244
"switch_prob" : trial .suggest_float ("tr_gen_switch_prob" , 0.1 , 1 ),
190
245
"random_prob" : trial .suggest_float ("tr_gen_random_prob" , 0.1 , 0.9 ),
191
246
},
192
247
"transition_oversampling" : trial .suggest_float (
193
- "transition_oversampling" , 0.9 , 2.0
248
+ "transition_oversampling" ,
249
+ 0.9 ,
250
+ 2.0 ,
194
251
),
195
252
"reward_trainer_kwargs" : {
196
253
"epochs" : trial .suggest_int ("reward_trainer_epochs" , 1 , 11 ),
197
254
},
198
255
"rl" : {
199
256
"rl_kwargs" : {
200
257
"ent_coef" : trial .suggest_float (
201
- "rl_ent_coef" , 1e-7 , 1e-3 , log = True
258
+ "rl_ent_coef" ,
259
+ 1e-7 ,
260
+ 1e-3 ,
261
+ log = True ,
202
262
),
203
263
},
204
264
},
@@ -217,28 +277,41 @@ def __call__(
217
277
"rl" : {
218
278
"rl_kwargs" : {
219
279
"learning_rate" : trial .suggest_float (
220
- "learning_rate" , 1e-6 , 1e-2 , log = True
280
+ "learning_rate" ,
281
+ 1e-6 ,
282
+ 1e-2 ,
283
+ log = True ,
221
284
),
222
285
"buffer_size" : trial .suggest_int ("buffer_size" , 1000 , 100000 ),
223
286
"learning_starts" : trial .suggest_int (
224
- "learning_starts" , 1000 , 10000
287
+ "learning_starts" ,
288
+ 1000 ,
289
+ 10000 ,
225
290
),
226
291
"batch_size" : trial .suggest_int ("batch_size" , 32 , 128 ),
227
292
"tau" : trial .suggest_float ("tau" , 0.0 , 1.0 ),
228
293
"gamma" : trial .suggest_float ("gamma" , 0.9 , 0.999 ),
229
294
"train_freq" : trial .suggest_int ("train_freq" , 1 , 40 ),
230
295
"gradient_steps" : trial .suggest_int ("gradient_steps" , 1 , 10 ),
231
296
"target_update_interval" : trial .suggest_int (
232
- "target_update_interval" , 1 , 10000
297
+ "target_update_interval" ,
298
+ 1 ,
299
+ 10000 ,
233
300
),
234
301
"exploration_fraction" : trial .suggest_float (
235
- "exploration_fraction" , 0.01 , 0.5
302
+ "exploration_fraction" ,
303
+ 0.01 ,
304
+ 0.5 ,
236
305
),
237
306
"exploration_final_eps" : trial .suggest_float (
238
- "exploration_final_eps" , 0.01 , 1.0
307
+ "exploration_final_eps" ,
308
+ 0.01 ,
309
+ 1.0 ,
239
310
),
240
311
"exploration_initial_eps" : trial .suggest_float (
241
- "exploration_initial_eps" , 0.01 , 0.5
312
+ "exploration_initial_eps" ,
313
+ 0.01 ,
314
+ 0.5 ,
242
315
),
243
316
"max_grad_norm" : trial .suggest_float ("max_grad_norm" , 0.1 , 10.0 ),
244
317
},
0 commit comments