forked from CyberAgentAILab/cmaes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoptuna_solver.py
59 lines (45 loc) · 1.65 KB
/
optuna_solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import argparse
import optuna
from kurobako import solver
from kurobako.solver.optuna import OptunaSolverFactory
# Map each --loglevel choice to the matching Optuna verbosity constant.
# The dict's insertion order also defines the order of the CLI choices.
_LOG_LEVELS = {
    "debug": optuna.logging.DEBUG,
    "info": optuna.logging.INFO,
    "warning": optuna.logging.WARNING,
    "error": optuna.logging.ERROR,
}

# Command-line interface: the positional ``sampler`` selects which CMA-ES
# sampler backs the solver; ``--loglevel`` sets Optuna's logging verbosity.
parser = argparse.ArgumentParser()
parser.add_argument("sampler", choices=["cmaes", "pycma"])
parser.add_argument(
    "--loglevel", choices=list(_LOG_LEVELS), default="warning"
)
args = parser.parse_args()

# Apply the requested verbosity; an unmapped level leaves Optuna untouched.
_verbosity = _LOG_LEVELS.get(args.loglevel)
if _verbosity is not None:
    optuna.logging.set_verbosity(_verbosity)
class CMASolverFactory(OptunaSolverFactory):
    """Optuna-backed kurobako solver factory advertised under the name ``cmaes``."""

    def specification(self):
        """Return the parent's solver specification with its name set to ``cmaes``."""
        solver_spec = super().specification()
        solver_spec.name = "cmaes"
        return solver_spec
class PyCMASolverFactory(OptunaSolverFactory):
    """Optuna-backed kurobako solver factory advertised under the name ``pycma``."""

    def specification(self):
        """Return the parent's solver specification with its name set to ``pycma``."""
        solver_spec = super().specification()
        solver_spec.name = "pycma"
        return solver_spec
def create_cmaes_study(seed):
    """Build an Optuna study driven by ``optuna.samplers.CmaEsSampler``.

    Args:
        seed: Random seed forwarded to the sampler.

    Returns:
        A new study created by ``optuna.create_study`` with that sampler.
    """
    return optuna.create_study(
        sampler=optuna.samplers.CmaEsSampler(
            seed=seed,
            # Warn when a parameter falls back to independent sampling.
            warn_independent_sampling=True,
        )
    )
def create_pycma_study(seed):
    """Build an Optuna study driven by ``optuna.integration.CmaEsSampler``.

    Args:
        seed: Random seed forwarded to the sampler.

    Returns:
        A new study created by ``optuna.create_study`` with that sampler.
    """
    # NOTE(review): presumably the pycma-backed integration sampler (matches the
    # "pycma" CLI choice). Newer Optuna releases expose it as
    # ``optuna.integration.PyCmaSampler``; confirm against the pinned version.
    return optuna.create_study(
        sampler=optuna.integration.CmaEsSampler(
            seed=seed,
            # Warn when a parameter falls back to independent sampling.
            warn_independent_sampling=True,
        )
    )
if __name__ == "__main__":
    # Each sampler choice maps to (factory class, study builder); only the
    # selected pair is instantiated below.
    _factories = {
        "cmaes": (CMASolverFactory, create_cmaes_study),
        "pycma": (PyCMASolverFactory, create_pycma_study),
    }
    try:
        factory_cls, study_builder = _factories[args.sampler]
    except KeyError:
        raise ValueError("unsupported sampler") from None

    factory = factory_cls(study_builder)
    runner = solver.SolverRunner(factory)
    runner.run()