51 | 51 | is_atari = False
52 | 52 | if 'NoFrameskip' in env_id:
53 | 53 |     is_atari = True
| 54 | + is_marathon_envs = False
| 55 | + if 'Marathon' in env_id:
| 56 | +     is_marathon_envs = True
54 | 57 |
55 | 58 | print("=" * 10, env_id, "=" * 10)
56 | 59 |
57 | 60 | # Load hyperparameters from yaml file
58 | 61 | with open('hyperparams/{}.yml'.format(args.algo), 'r') as f:
59 | 62 |     if is_atari:
60 | 63 |         hyperparams = yaml.load(f)['atari']
| 64 | +     elif is_marathon_envs:
| 65 | +         hyperparams = yaml.load(f)['MarathonEnvs']
61 | 66 |     else:
62 | | -         # hyperparams = yaml.load(f)['atari']
63 | | -         hyperparams = yaml.load(f)['MarathonHopperEnv-v0']
64 | | -         # hyperparams = yaml.load(f)[env_id]
| 67 | +         hyperparams = yaml.load(f)[env_id]
65 | 68 |
66 | 69 | n_envs = hyperparams.get('n_envs', 1)
| 70 | + n_agents = hyperparams.get('n_agents', 1)
67 | 71 |
68 | 72 | print("Using {} environments".format(n_envs))
| 73 | + print("With {} agents per environment".format(n_agents))
69 | 74 |
70 | 75 | # Create learning rate schedules for ppo2
71 | 76 | if args.algo == "ppo2":
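For context, the new 'MarathonEnvs' section read by this hunk is expected to parse into a plain dict. A minimal sketch of the shape the code above assumes, as it would come back from yaml.load; the keys (n_envs, n_agents, n_timesteps, normalize, policy) are the ones this script consumes, while the values shown are illustrative assumptions, not the repo's actual settings:

    # Hypothetical result of yaml.load(f) for hyperparams/ppo2.yml;
    # values are placeholders for illustration only.
    {
        'MarathonEnvs': {
            'policy': 'MlpPolicy',   # assumed policy name
            'n_envs': 1,             # read via hyperparams.get('n_envs', 1)
            'n_agents': 16,          # read via hyperparams.get('n_agents', 1)
            'n_timesteps': 1000000,  # deleted before model construction
            'normalize': True,       # toggles VecNormalize later on
        }
    }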
91 | 96 | if 'normalize' in hyperparams.keys():
92 | 97 |     normalize = hyperparams['normalize']
93 | 98 |     del hyperparams['normalize']
| 99 | +     if args.algo in ['dqn', 'ddpg']:
| 100 | +         print("WARNING: normalization not supported yet for DDPG/DQN")
94 | 101 |
95 | 102 | # Delete keys so the dict can be passed to the model constructor
96 | 103 | if 'n_envs' in hyperparams.keys():
97 | 104 |     del hyperparams['n_envs']
| 105 | + if 'n_agents' in hyperparams.keys():
| 106 | +     del hyperparams['n_agents']
98 | 107 | del hyperparams['n_timesteps']
99 | 108 |
100 | 109 | # Create the environment and wrap it if necessary
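These deletions matter because, later in the script (outside this diff), the surviving dict is presumably unpacked straight into the algorithm's constructor, where any leftover zoo-only key would surface as an unexpected keyword argument. A minimal sketch of that pattern, assuming PPO2 and a 'policy' entry in the same yaml section:

    # Sketch, under the assumption that hyperparams is forwarded as
    # keyword arguments, as in rl-baselines-zoo's train.py.
    from stable_baselines import PPO2
    model = PPO2(hyperparams.pop('policy'), env, **hyperparams)
    # A leftover key such as 'n_agents' would raise a TypeError here.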
103 | 112 |     env = make_atari_env(env_id, num_env=n_envs, seed=args.seed)
104 | 113 |     # Frame-stacking with 4 frames
105 | 114 |     env = VecFrameStack(env, n_stack=4)
106 | | - elif args.algo in ['dqn', 'ddpg']:
107 | | -     if hyperparams.get('normalize', False):
108 | | -         print("WARNING: normalization not supported yet for DDPG/DQN")
109 | | -     env = gym.make(env_id)
110 | | -     env.seed(args.seed)
111 | 115 | elif 'Marathon' in env_id:
112 | 116 |     from UnityVecEnv import UnityVecEnv
113 | | -     env = UnityVecEnv(env_id)
| 117 | +     if n_agents == 1:
| 118 | +         from gym_unity.envs import UnityEnv
| 119 | +         env_path = UnityVecEnv.GetFilePath(env_id, n_agents=n_agents)
| 120 | +         env = UnityEnv(env_path)
| 121 | +         env = DummyVecEnv([lambda: env])  # The algorithms require a vectorized environment to run
| 122 | +     else:
| 123 | +         env = UnityVecEnv(env_id, n_agents=n_agents)
114 | 124 |     if normalize:
115 | 125 |         print("Normalizing input and return")
116 | 126 |         env = VecNormalize(env)
| 127 | + elif args.algo in ['dqn', 'ddpg']:
| 128 | +     env = gym.make(env_id)
| 129 | +     env.seed(args.seed)
117 | 130 | else:
118 | 131 |     if n_envs == 1:
119 | 132 |         env = DummyVecEnv([make_env(env_id, 0, args.seed)])
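The single-agent Marathon path above follows the usual stable-baselines recipe of wrapping one gym-style environment in DummyVecEnv. A self-contained sketch of just that wrapping step, using CartPole purely as a stand-in for UnityEnv(env_path):

    import gym
    from stable_baselines.common.vec_env import DummyVecEnv

    env = gym.make('CartPole-v1')  # stand-in for the single-agent UnityEnv
    # DummyVecEnv invokes each callable immediately in its constructor,
    # so the lambda still captures the unwrapped environment here.
    env = DummyVecEnv([lambda: env])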