Skip to content

Commit 57018d5

Browse files
committed
Add hp search space for PC/classic control.
1 parent 9bb89c9 commit 57018d5

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

tuning/hp_search_spaces.py

+41
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,47 @@ def __call__(
123123
},
124124
},
125125
),
126+
pc_classic_control=RunSacredAsTrial(
127+
sacred_ex=imitation.scripts.train_preference_comparisons.train_preference_comparisons_ex,
128+
suggest_named_configs=lambda _: ["reward.reward_ensemble"],
129+
suggest_config_updates=lambda trial: {
130+
"seed": trial.number,
131+
"environment": {"num_vec": 8},
132+
"total_timesteps": 1e6,
133+
"total_comparisons": 1000,
134+
"active_selection": True,
135+
"active_selection_oversampling": trial.suggest_int("active_selection_oversampling", 1, 11),
136+
"comparison_queue_size": trial.suggest_int("comparison_queue_size", 1, 1001), # upper bound determined by total_comparisons=1000
137+
"exploration_frac": trial.suggest_float("exploration_frac", 0.0, 0.5),
138+
"fragment_length": trial.suggest_int("fragment_length", 1, 1001), # trajectories are 1000 steps long
139+
"gatherer_kwargs": {
140+
"temperature": trial.suggest_float("gatherer_temperature", 0.0, 2.0),
141+
"discount_factor": trial.suggest_float("gatherer_discount_factor", 0.95, 1.0),
142+
"sample": trial.suggest_categorical("gatherer_sample", [True, False]),
143+
},
144+
"initial_epoch_multiplier": trial.suggest_float("initial_epoch_multiplier", 1, 200.0),
145+
"initial_comparison_frac": trial.suggest_float("initial_comparison_frac", 0.01, 1.0),
146+
"num_iterations": trial.suggest_int("num_iterations", 1, 51),
147+
"preference_model_kwargs": {
148+
"noise_prob": trial.suggest_float("preference_model_noise_prob", 0.0, 0.1),
149+
"discount_factor": trial.suggest_float("preference_model_discount_factor", 0.95, 1.0),
150+
},
151+
"query_schedule": trial.suggest_categorical("query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]),
152+
"trajectory_generator_kwargs": {
153+
"switch_prob": trial.suggest_float("tr_gen_switch_prob", 0.1, 1),
154+
"random_prob": trial.suggest_float("tr_gen_random_prob", 0.1, 0.9),
155+
},
156+
"transition_oversampling": trial.suggest_float("transition_oversampling", 0.9, 2.0),
157+
"reward_trainer_kwargs": {
158+
"epochs": trial.suggest_int("reward_trainer_epochs", 1, 11),
159+
},
160+
"rl": {
161+
"rl_kwargs": {
162+
"ent_coef": trial.suggest_float("rl_ent_coef", 1e-7, 1e-3, log=True),
163+
},
164+
},
165+
},
166+
),
126167
sqil=RunSacredAsTrial(
127168
sacred_ex=imitation.scripts.train_imitation.train_imitation_ex,
128169
command_name="sqil",

0 commit comments

Comments
 (0)