@@ -123,6 +123,47 @@ def __call__(
123
123
},
124
124
},
125
125
),
126
+ pc_classic_control = RunSacredAsTrial (
127
+ sacred_ex = imitation .scripts .train_preference_comparisons .train_preference_comparisons_ex ,
128
+ suggest_named_configs = lambda _ : ["reward.reward_ensemble" ],
129
+ suggest_config_updates = lambda trial : {
130
+ "seed" : trial .number ,
131
+ "environment" : {"num_vec" : 8 },
132
+ "total_timesteps" : 1e6 ,
133
+ "total_comparisons" : 1000 ,
134
+ "active_selection" : True ,
135
+ "active_selection_oversampling" : trial .suggest_int ("active_selection_oversampling" , 1 , 11 ),
136
+ "comparison_queue_size" : trial .suggest_int ("comparison_queue_size" , 1 , 1001 ), # upper bound determined by total_comparisons=1000
137
+ "exploration_frac" : trial .suggest_float ("exploration_frac" , 0.0 , 0.5 ),
138
+ "fragment_length" : trial .suggest_int ("fragment_length" , 1 , 1001 ), # trajectories are 1000 steps long
139
+ "gatherer_kwargs" : {
140
+ "temperature" : trial .suggest_float ("gatherer_temperature" , 0.0 , 2.0 ),
141
+ "discount_factor" : trial .suggest_float ("gatherer_discount_factor" , 0.95 , 1.0 ),
142
+ "sample" : trial .suggest_categorical ("gatherer_sample" , [True , False ]),
143
+ },
144
+ "initial_epoch_multiplier" : trial .suggest_float ("initial_epoch_multiplier" , 1 , 200.0 ),
145
+ "initial_comparison_frac" : trial .suggest_float ("initial_comparison_frac" , 0.01 , 1.0 ),
146
+ "num_iterations" : trial .suggest_int ("num_iterations" , 1 , 51 ),
147
+ "preference_model_kwargs" : {
148
+ "noise_prob" : trial .suggest_float ("preference_model_noise_prob" , 0.0 , 0.1 ),
149
+ "discount_factor" : trial .suggest_float ("preference_model_discount_factor" , 0.95 , 1.0 ),
150
+ },
151
+ "query_schedule" : trial .suggest_categorical ("query_schedule" , ["hyperbolic" , "constant" , "inverse_quadratic" ]),
152
+ "trajectory_generator_kwargs" : {
153
+ "switch_prob" : trial .suggest_float ("tr_gen_switch_prob" , 0.1 , 1 ),
154
+ "random_prob" : trial .suggest_float ("tr_gen_random_prob" , 0.1 , 0.9 ),
155
+ },
156
+ "transition_oversampling" : trial .suggest_float ("transition_oversampling" , 0.9 , 2.0 ),
157
+ "reward_trainer_kwargs" : {
158
+ "epochs" : trial .suggest_int ("reward_trainer_epochs" , 1 , 11 ),
159
+ },
160
+ "rl" : {
161
+ "rl_kwargs" : {
162
+ "ent_coef" : trial .suggest_float ("rl_ent_coef" , 1e-7 , 1e-3 , log = True ),
163
+ },
164
+ },
165
+ },
166
+ ),
126
167
sqil = RunSacredAsTrial (
127
168
sacred_ex = imitation .scripts .train_imitation .train_imitation_ex ,
128
169
command_name = "sqil" ,
0 commit comments