Commit e69dbd6

More formatting fixes.
Parent: 60fc75a

4 files changed (+130, -49 lines)

src/imitation/scripts/config/tuning.py (+3, -4)
@@ -2,7 +2,6 @@
 
 import ray.tune as tune
 import sacred
-from torch import nn
 
 from imitation.algorithms import dagger as dagger_alg
 from imitation.scripts.parallel import parallel_ex
@@ -200,11 +199,11 @@ def pc():
             "config_updates": {
                 "active_selection_oversampling": tune.randint(1, 11),
                 "comparison_queue_size": tune.randint(
-                    1, 1001
+                    1, 1001,
                 ),  # upper bound determined by total_comparisons=1000
                 "exploration_frac": tune.uniform(0.0, 0.5),
                 "fragment_length": tune.randint(
-                    1, 1001
+                    1, 1001,
                 ),  # trajectories are 1000 steps long
                 "gatherer_kwargs": {
                     "temperature": tune.uniform(0.0, 2.0),
@@ -218,7 +217,7 @@ def pc():
                     "discount_factor": tune.uniform(0.95, 1.0),
                 },
                 "query_schedule": tune.choice(
-                    ["hyperbolic", "constant", "inverse_quadratic"]
+                    ["hyperbolic", "constant", "inverse_quadratic",]
                 ),
                 "trajectory_generator_kwargs": {
                     "switch_prob": tune.uniform(0.1, 1),

tuning/hp_search_spaces.py (+111, -38)
@@ -14,14 +14,13 @@
 """
 
 import dataclasses
-from typing import Any, Callable, Dict, List, Mapping
+from typing import Any, Callable, Dict, List, Mapping, Optional
 
 import optuna
 import sacred
-import stable_baselines3 as sb3
 
 import imitation.scripts.train_imitation
-import imitation.scripts.train_preference_comparisons
+import imitation.scripts.train_preference_comparisons as train_pc_script
 
 
 @dataclasses.dataclass
@@ -42,19 +41,27 @@ class RunSacredAsTrial:
     suggest_config_updates: Callable[[optuna.Trial], Mapping[str, Any]]
 
     """Command name to pass to sacred.run."""
-    command_name: str = None
+    command_name: Optional[str] = None
 
     def __call__(
-        self, trial: optuna.Trial, run_options: Dict, extra_named_configs: List[str]
+        self,
+        trial: optuna.Trial,
+        run_options: Dict,
+        extra_named_configs: List[str],
     ) -> float:
         """Run the sacred experiment and return the performance.
 
         Args:
             trial: The optuna trial to sample hyperparameters for.
             run_options: Options to pass to sacred.run(options=).
             extra_named_configs: Additional named configs to pass to sacred.run.
-        """
 
+        Returns:
+            The performance of the trial.
+
+        Raises:
+            RuntimeError: If the trial fails.
+        """
         config_updates = self.suggest_config_updates(trial)
         named_configs = self.suggest_named_configs(trial) + extra_named_configs
 
@@ -71,15 +78,16 @@ def __call__(
         )
         if result.status != "COMPLETED":
             raise RuntimeError(
-                f"Trial failed with {result.fail_trace()} and status {result.status}."
+                f"Trial failed with {result.fail_trace()} and status {result.status}.",
             )
         return result.result["imit_stats"]["monitor_return_mean"]
 
 
-"""A mapping from algorithm names to functions that run the algorithm as an optuna trial."""
+"""A mapping from algorithm names to functions that run the algorithm as an optuna
+trial."""
 objectives_by_algo = dict(
     pc=RunSacredAsTrial(
-        sacred_ex=imitation.scripts.train_preference_comparisons.train_preference_comparisons_ex,
+        sacred_ex=train_pc_script.train_preference_comparisons_ex,
         suggest_named_configs=lambda _: ["reward.reward_ensemble"],
         suggest_config_updates=lambda trial: {
             "seed": trial.number,
@@ -88,61 +96,87 @@ def __call__(
             "total_comparisons": 1000,
             "active_selection": True,
             "active_selection_oversampling": trial.suggest_int(
-                "active_selection_oversampling", 1, 11
+                "active_selection_oversampling",
+                1,
+                11,
             ),
             "comparison_queue_size": trial.suggest_int(
-                "comparison_queue_size", 1, 1001
+                "comparison_queue_size",
+                1,
+                1001,
             ),  # upper bound determined by total_comparisons=1000
             "exploration_frac": trial.suggest_float("exploration_frac", 0.0, 0.5),
             "fragment_length": trial.suggest_int(
-                "fragment_length", 1, 1001
+                "fragment_length",
+                1,
+                1001,
             ),  # trajectories are 1000 steps long
             "gatherer_kwargs": {
                 "temperature": trial.suggest_float("gatherer_temperature", 0.0, 2.0),
                 "discount_factor": trial.suggest_float(
-                    "gatherer_discount_factor", 0.95, 1.0
+                    "gatherer_discount_factor",
+                    0.95,
+                    1.0,
                 ),
                 "sample": trial.suggest_categorical("gatherer_sample", [True, False]),
             },
             "initial_epoch_multiplier": trial.suggest_float(
-                "initial_epoch_multiplier", 1, 200.0
+                "initial_epoch_multiplier",
+                1,
+                200.0,
             ),
             "initial_comparison_frac": trial.suggest_float(
-                "initial_comparison_frac", 0.01, 1.0
+                "initial_comparison_frac",
+                0.01,
+                1.0,
             ),
             "num_iterations": trial.suggest_int("num_iterations", 1, 51),
             "preference_model_kwargs": {
                 "noise_prob": trial.suggest_float(
-                    "preference_model_noise_prob", 0.0, 0.1
+                    "preference_model_noise_prob",
+                    0.0,
+                    0.1,
                 ),
                 "discount_factor": trial.suggest_float(
-                    "preference_model_discount_factor", 0.95, 1.0
+                    "preference_model_discount_factor",
+                    0.95,
+                    1.0,
                 ),
             },
             "query_schedule": trial.suggest_categorical(
-                "query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]
+                "query_schedule",
+                [
+                    "hyperbolic",
+                    "constant",
+                    "inverse_quadratic",
+                ],
             ),
             "trajectory_generator_kwargs": {
                 "switch_prob": trial.suggest_float("tr_gen_switch_prob", 0.1, 1),
                 "random_prob": trial.suggest_float("tr_gen_random_prob", 0.1, 0.9),
             },
             "transition_oversampling": trial.suggest_float(
-                "transition_oversampling", 0.9, 2.0
+                "transition_oversampling",
+                0.9,
+                2.0,
             ),
             "reward_trainer_kwargs": {
                "epochs": trial.suggest_int("reward_trainer_epochs", 1, 11),
            },
             "rl": {
                 "rl_kwargs": {
                     "ent_coef": trial.suggest_float(
-                        "rl_ent_coef", 1e-7, 1e-3, log=True
+                        "rl_ent_coef",
+                        1e-7,
+                        1e-3,
+                        log=True,
                     ),
                 },
             },
         },
     ),
     pc_classic_control=RunSacredAsTrial(
-        sacred_ex=imitation.scripts.train_preference_comparisons.train_preference_comparisons_ex,
+        sacred_ex=train_pc_script.train_preference_comparisons_ex,
         suggest_named_configs=lambda _: ["reward.reward_ensemble"],
         suggest_config_updates=lambda trial: {
             "seed": trial.number,
@@ -151,54 +185,80 @@ def __call__(
             "total_comparisons": 1000,
             "active_selection": True,
             "active_selection_oversampling": trial.suggest_int(
-                "active_selection_oversampling", 1, 11
+                "active_selection_oversampling",
+                1,
+                11,
             ),
             "comparison_queue_size": trial.suggest_int(
-                "comparison_queue_size", 1, 1001
+                "comparison_queue_size",
+                1,
+                1001,
             ),  # upper bound determined by total_comparisons=1000
             "exploration_frac": trial.suggest_float("exploration_frac", 0.0, 0.5),
             "fragment_length": trial.suggest_int(
-                "fragment_length", 1, 201
+                "fragment_length",
+                1,
+                201,
             ),  # trajectories are 1000 steps long
             "gatherer_kwargs": {
                 "temperature": trial.suggest_float("gatherer_temperature", 0.0, 2.0),
                 "discount_factor": trial.suggest_float(
-                    "gatherer_discount_factor", 0.95, 1.0
+                    "gatherer_discount_factor",
+                    0.95,
+                    1.0,
                 ),
                 "sample": trial.suggest_categorical("gatherer_sample", [True, False]),
             },
             "initial_epoch_multiplier": trial.suggest_float(
-                "initial_epoch_multiplier", 1, 200.0
+                "initial_epoch_multiplier",
+                1,
+                200.0,
             ),
             "initial_comparison_frac": trial.suggest_float(
-                "initial_comparison_frac", 0.01, 1.0
+                "initial_comparison_frac",
+                0.01,
+                1.0,
             ),
             "num_iterations": trial.suggest_int("num_iterations", 1, 51),
             "preference_model_kwargs": {
                 "noise_prob": trial.suggest_float(
-                    "preference_model_noise_prob", 0.0, 0.1
+                    "preference_model_noise_prob",
+                    0.0,
+                    0.1,
                 ),
                 "discount_factor": trial.suggest_float(
-                    "preference_model_discount_factor", 0.95, 1.0
+                    "preference_model_discount_factor",
+                    0.95,
+                    1.0,
                 ),
             },
             "query_schedule": trial.suggest_categorical(
-                "query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]
+                "query_schedule",
+                [
+                    "hyperbolic",
+                    "constant",
+                    "inverse_quadratic",
+                ],
             ),
             "trajectory_generator_kwargs": {
                 "switch_prob": trial.suggest_float("tr_gen_switch_prob", 0.1, 1),
                 "random_prob": trial.suggest_float("tr_gen_random_prob", 0.1, 0.9),
             },
             "transition_oversampling": trial.suggest_float(
-                "transition_oversampling", 0.9, 2.0
+                "transition_oversampling",
+                0.9,
+                2.0,
             ),
             "reward_trainer_kwargs": {
                 "epochs": trial.suggest_int("reward_trainer_epochs", 1, 11),
             },
             "rl": {
                 "rl_kwargs": {
                     "ent_coef": trial.suggest_float(
-                        "rl_ent_coef", 1e-7, 1e-3, log=True
+                        "rl_ent_coef",
+                        1e-7,
+                        1e-3,
+                        log=True,
                     ),
                 },
             },
@@ -217,28 +277,41 @@ def __call__(
             "rl": {
                 "rl_kwargs": {
                     "learning_rate": trial.suggest_float(
-                        "learning_rate", 1e-6, 1e-2, log=True
+                        "learning_rate",
+                        1e-6,
+                        1e-2,
+                        log=True,
                     ),
                     "buffer_size": trial.suggest_int("buffer_size", 1000, 100000),
                     "learning_starts": trial.suggest_int(
-                        "learning_starts", 1000, 10000
+                        "learning_starts",
+                        1000,
+                        10000,
                     ),
                     "batch_size": trial.suggest_int("batch_size", 32, 128),
                     "tau": trial.suggest_float("tau", 0.0, 1.0),
                     "gamma": trial.suggest_float("gamma", 0.9, 0.999),
                     "train_freq": trial.suggest_int("train_freq", 1, 40),
                     "gradient_steps": trial.suggest_int("gradient_steps", 1, 10),
                     "target_update_interval": trial.suggest_int(
-                        "target_update_interval", 1, 10000
+                        "target_update_interval",
+                        1,
+                        10000,
                     ),
                     "exploration_fraction": trial.suggest_float(
-                        "exploration_fraction", 0.01, 0.5
+                        "exploration_fraction",
+                        0.01,
+                        0.5,
                     ),
                     "exploration_final_eps": trial.suggest_float(
-                        "exploration_final_eps", 0.01, 1.0
+                        "exploration_final_eps",
+                        0.01,
+                        1.0,
                     ),
                     "exploration_initial_eps": trial.suggest_float(
-                        "exploration_initial_eps", 0.01, 0.5
+                        "exploration_initial_eps",
+                        0.01,
+                        0.5,
                     ),
                     "max_grad_norm": trial.suggest_float("max_grad_norm", 0.1, 10.0),
                 },
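
For orientation, a minimal sketch (not from this commit) of how an entry of objectives_by_algo can be plugged into an optuna study. The empty run_options dict, the trial count, and the study direction are illustrative placeholders; the actual tuning script assembles these from its command-line arguments.

import optuna

import hp_search_spaces

objective = hp_search_spaces.objectives_by_algo["pc"]


def optuna_objective(trial: optuna.Trial) -> float:
    # RunSacredAsTrial.__call__ runs the sacred experiment for this trial and
    # returns imit_stats["monitor_return_mean"], raising RuntimeError on failure.
    return objective(trial, run_options={}, extra_named_configs=[])


if __name__ == "__main__":
    study = optuna.create_study(direction="maximize")
    study.optimize(optuna_objective, n_trials=5)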

tuning/rerun_best_trial.py (+9, -4)
@@ -1,7 +1,6 @@
 """Script to re-run the best trials from a previous hyperparameter tuning run."""
 import argparse
 import random
-from typing import List, Tuple
 
 import hp_search_spaces
 import optuna
@@ -11,7 +10,7 @@
 def make_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(
         description="Re-run the best trial from a previous tuning run.",
-        epilog=f"Example usage:\n" f"python rerun_best_trials.py tuning_run.json\n",
+        epilog="Example usage:\npython rerun_best_trials.py tuning_run.json\n",
         formatter_class=argparse.RawDescriptionHelpFormatter,
     )
     parser.add_argument(
@@ -40,6 +39,12 @@ def infer_algo_name(study: optuna.Study) -> str:
     """Infer the algo name from the study name.
 
     Assumes that the study name is of the form "tuning_{algo}_with_{named_configs}".
+
+    Args:
+        study: The optuna study.
+
+    Returns:
+        algo name
     """
     assert study.study_name.startswith("tuning_")
     assert "_with_" in study.study_name
@@ -51,7 +56,7 @@ def main():
     args = parser.parse_args()
     study: optuna.Study = optuna.load_study(
         storage=optuna.storages.JournalStorage(
-            optuna.storages.JournalFileStorage(args.journal_log)
+            optuna.storages.JournalFileStorage(args.journal_log),
         ),
         # in our case, we have one journal file per study so the study name can be
         # inferred
@@ -73,7 +78,7 @@ def main():
     )
     if result.status != "COMPLETED":
         raise RuntimeError(
-            f"Trial failed with {result.fail_trace()} and status {result.status}."
+            f"Trial failed with {result.fail_trace()} and status {result.status}.",
         )
 
 
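
As a usage note, a minimal sketch (assuming a single-study journal file such as the tuning_run.json shown in the epilog above; the path is illustrative) of loading the study this script consumes and inspecting its best trial:

import optuna

storage = optuna.storages.JournalStorage(
    optuna.storages.JournalFileStorage("tuning_run.json"),  # illustrative path
)
# study_name=None is allowed because each journal file holds exactly one study,
# mirroring the comment in main() above.
study = optuna.load_study(study_name=None, storage=storage)
print(study.study_name)  # e.g. "tuning_pc_with_..."
print(study.best_trial.number, study.best_trial.params)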