
Commit a0157d6
committed Dec 3, 2018

prep for other stable_baselines algorithms

1 parent ad1202f · commit a0157d6

10 files changed: +66 -33 lines changed

‎.gitignore

+1

@@ -94,3 +94,4 @@ dist/
 build/
 summaries.001/
 filename.monitor.csv
+logs/

‎UnityVecEnv.py

+6 -6

@@ -24,21 +24,21 @@ class UnityVecEnv(VecEnv):
     """

     @staticmethod
-    def GetFilePath(env_id, inference_mode=False):
+    def GetFilePath(env_id, inference_mode=False, n_agents=1):
         import psutil
         env_name = MarathonEnvs[env_id]
-        if not inference_mode:
-            env_name = env_name + '-x16'
-        else:
+        if inference_mode:
             env_name = env_name + '-run'
+        elif n_agents is 16:
+            env_name = env_name + '-x16'
         if psutil.MACOS:
             env_path = os.path.join('envs', env_name)
         elif psutil.WINDOWS:
             env_path = os.path.join('envs', env_name, 'Unity Environment.exe')
         return env_path

-    def __init__(self, env_id):
-        env_path = UnityVecEnv.GetFilePath(env_id)
+    def __init__(self, env_id, n_agents):
+        env_path = UnityVecEnv.GetFilePath(env_id, n_agents=n_agents)
         print ("**** ", env_path)
         env = UnityEnv(env_path, multiagent=True)
         self.env = env
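
For orientation, a hedged usage sketch of the widened helper; the environment id is illustrative, and the '-x16'/'-run' suffixes come from the branch above. Note that `n_agents is 16` only happens to work because CPython caches small integers; `n_agents == 16` would be the conventional comparison.

    from UnityVecEnv import UnityVecEnv

    # 16-agent training build -> envs/<env_name>-x16 (plus 'Unity Environment.exe' on Windows)
    train_path = UnityVecEnv.GetFilePath('MarathonHopperEnv-v0', n_agents=16)
    # single-agent playback build -> envs/<env_name>-run
    play_path = UnityVecEnv.GetFilePath('MarathonHopperEnv-v0', inference_mode=True)
    # multi-agent vectorized env used by sb_train.py when n_agents != 1
    env = UnityVecEnv('MarathonHopperEnv-v0', n_agents=16)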

‎hyperparams/a2c.yml

+9

@@ -1,3 +1,12 @@
+MarathonEnvs:
+  n_agents: 16
+  n_timesteps: !!float 1e6
+  policy: 'MlpPolicy'
+  vf_coef: 0.25
+  learning_rate: !!float 3e-4
+  epsilon: !!float 1e-5
+  normalize: true
+
 atari:
   policy: 'CnnPolicy'
   n_envs: 16
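
As a rough sketch (not part of the commit), the new MarathonEnvs block is meant to be read the way sb_train.py reads it: n_agents, n_timesteps, and normalize are stripped by the launcher, and the remaining keys go straight to the A2C constructor. The stand-in Gym env below is only there to keep the snippet self-contained.

    import gym
    import yaml
    from stable_baselines import A2C

    with open('hyperparams/a2c.yml') as f:
        hp = yaml.safe_load(f)['MarathonEnvs']
    for key in ('n_agents', 'n_timesteps', 'normalize'):
        hp.pop(key, None)                  # consumed by the launcher, not by the model
    env = gym.make('Pendulum-v0')          # stand-in; real training uses the Unity build
    model = A2C(env=env, verbose=1, **hp)  # policy, vf_coef, learning_rate, epsilon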

‎hyperparams/acer.yml

+7

@@ -1,3 +1,10 @@
+MarathonEnvs:
+  n_agents: 16
+  n_timesteps: !!float 1e6
+  policy: 'MlpPolicy'
+  learning_rate: 3e-4
+  normalize: true
+
 atari:
   policy: 'CnnPolicy'
   n_envs: 16

‎hyperparams/acktr.yml

+7

@@ -1,3 +1,10 @@
+MarathonEnvs:
+  n_agents: 16
+  n_timesteps: !!float 1e6
+  policy: 'MlpPolicy'
+  learning_rate: 3e-4
+  normalize: true
+
 atari:
   policy: 'CnnPolicy'
   n_envs: 32

‎hyperparams/ddpg.yml

+7

@@ -1,3 +1,10 @@
+MarathonEnvs:
+  n_agents: 1
+  n_timesteps: !!float 1e6
+  policy: 'MlpPolicy'
+  # learning_rate: 3e-4
+  # normalize: true
+
 MountainCarContinuous-v0:
   n_timesteps: 300000
   policy: 'MlpPolicy'
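
DDPG keeps n_agents at 1, so it would use the un-suffixed single-agent build rather than the '-x16' one, and normalization stays commented out (matching the DDPG/DQN warning added in sb_train.py). A hedged sketch with an illustrative environment id:

    from stable_baselines import DDPG
    from gym_unity.envs import UnityEnv
    from UnityVecEnv import UnityVecEnv

    # single-agent Unity build (no suffix is appended when n_agents == 1 and inference_mode is False)
    env = UnityEnv(UnityVecEnv.GetFilePath('MarathonHopperEnv-v0', n_agents=1))
    model = DDPG('MlpPolicy', env, verbose=1)
    model.learn(total_timesteps=int(1e6))  # n_timesteps from the block above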

‎hyperparams/ppo2.yml

+2 -16

@@ -1,19 +1,5 @@
-MarathonHopperEnv-v0:
-  n_envs: 16
-  n_timesteps: !!float 1e6
-  policy: 'MlpPolicy'
-  normalize: true
-  n_steps: 128 # 2048 / number of agents
-  nminibatches: 32
-  lam: 0.95
-  gamma: 0.99
-  noptepochs: 10
-  ent_coef: 0.0
-  learning_rate: lin_3e-4
-  cliprange: 0.2
-
-MarathonWalker2DEnv-v0:
-  n_envs: 16
+MarathonEnvs:
+  n_agents: 16
   n_timesteps: !!float 1e6
   policy: 'MlpPolicy'
   normalize: true
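
The deleted per-environment blocks sized the PPO2 rollout from the agent count; the relationship, taken from the removed 'n_steps: 128 # 2048 / number of agents' comment, is just:

    n_agents = 16
    transitions_per_update = 2048
    n_steps = transitions_per_update // n_agents  # 128 steps collected per agent each update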

‎sb_enjoy.py

+3 -1

@@ -43,6 +43,8 @@

 if algo in ['dqn', 'ddpg']:
     args.n_envs = 1
+if 'n_agents' not in args:
+    args.n_agents = 1  # 1 agent for playback

 set_global_seeds(args.seed)

@@ -54,7 +56,7 @@

 log_dir = args.reward_log if args.reward_log != '' else None

-env = create_test_env(env_id, n_envs=args.n_envs, is_atari=is_atari,
+env = create_test_env(env_id, n_envs=args.n_envs, n_agents=args.n_agents, is_atari=is_atari,
                       stats_path=stats_path, norm_reward=args.norm_reward,
                       seed=args.seed, log_dir=log_dir, should_render=not args.no_render)

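
The new default relies on argparse.Namespace supporting membership tests, so runs that never set n_agents fall back to the single-agent playback build. A minimal sketch of that behaviour (the Namespace fields shown are illustrative):

    import argparse

    args = argparse.Namespace(algo='ppo2', n_envs=1, seed=0)
    if 'n_agents' not in args:  # Namespace.__contains__ checks for the attribute
        args.n_agents = 1       # playback defaults to a single agent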

‎sb_train.py

+22 -9

@@ -51,21 +51,26 @@
 is_atari = False
 if 'NoFrameskip' in env_id:
     is_atari = True
+is_marathon_envs = False
+if 'Marathon' in env_id:
+    is_marathon_envs = True

 print("=" * 10, env_id, "=" * 10)

 # Load hyperparameters from yaml file
 with open('hyperparams/{}.yml'.format(args.algo), 'r') as f:
     if is_atari:
         hyperparams = yaml.load(f)['atari']
+    elif is_marathon_envs:
+        hyperparams = yaml.load(f)['MarathonEnvs']
     else:
-        # hyperparams = yaml.load(f)['atari']
-        hyperparams = yaml.load(f)['MarathonHopperEnv-v0']
-        # hyperparams = yaml.load(f)[env_id]
+        hyperparams = yaml.load(f)[env_id]

 n_envs = hyperparams.get('n_envs', 1)
+n_agents = hyperparams.get('n_agents', 1)

 print("Using {} environments".format(n_envs))
+print("With {} agents per enviroment".format(n_agents))

 # Create learning rate schedules for ppo2
 if args.algo == "ppo2":

@@ -91,10 +96,14 @@
 if 'normalize' in hyperparams.keys():
     normalize = hyperparams['normalize']
     del hyperparams['normalize']
+    if args.algo in ['dqn', 'ddpg']:
+        print("WARNING: normalization not supported yet for DDPG/DQN")

 # Delete keys so the dict can be pass to the model constructor
 if 'n_envs' in hyperparams.keys():
     del hyperparams['n_envs']
+if 'n_agents' in hyperparams.keys():
+    del hyperparams['n_agents']
 del hyperparams['n_timesteps']

 # Create the environment and wrap it if necessary

@@ -103,17 +112,21 @@
     env = make_atari_env(env_id, num_env=n_envs, seed=args.seed)
     # Frame-stacking with 4 frames
     env = VecFrameStack(env, n_stack=4)
-elif args.algo in ['dqn', 'ddpg']:
-    if hyperparams.get('normalize', False):
-        print("WARNING: normalization not supported yet for DDPG/DQN")
-    env = gym.make(env_id)
-    env.seed(args.seed)
 elif 'Marathon' in env_id:
     from UnityVecEnv import UnityVecEnv
-    env = UnityVecEnv(env_id)
+    if n_agents is 1:
+        from gym_unity.envs import UnityEnv
+        env_path = UnityVecEnv.GetFilePath(env_id, n_agents=n_agents)
+        env = UnityEnv(env_path)
+        env = DummyVecEnv([lambda: env])  # The algorithms require a vectorized environment to run
+    else:
+        env = UnityVecEnv(env_id, n_agents=n_agents)
     if normalize:
         print("Normalizing input and return")
         env = VecNormalize(env)
+elif args.algo in ['dqn', 'ddpg']:
+    env = gym.make(env_id)
+    env.seed(args.seed)
 else:
     if n_envs == 1:
         env = DummyVecEnv([make_env(env_id, 0, args.seed)])
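
Condensed, the new Marathon branch amounts to the following (a sketch rather than the literal script: the imports are spelled out, and == is used where the commit writes `is`, which only happens to work for small integers in CPython):

    from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize
    from UnityVecEnv import UnityVecEnv

    def make_marathon_env(env_id, n_agents, normalize):
        if n_agents == 1:
            # single-agent build wrapped so the algorithms still see a VecEnv
            from gym_unity.envs import UnityEnv
            env_path = UnityVecEnv.GetFilePath(env_id, n_agents=n_agents)
            env = DummyVecEnv([lambda: UnityEnv(env_path)])
        else:
            # 16-agent '-x16' build, already vectorized
            env = UnityVecEnv(env_id, n_agents=n_agents)
        if normalize:
            env = VecNormalize(env)  # normalize observations and returns
        return env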

‎utils/utils.py

+2 -1

@@ -63,14 +63,15 @@ def _init():
     return _init


-def create_test_env(env_id, n_envs=1, is_atari=False,
+def create_test_env(env_id, n_envs=1, n_agents=1, is_atari=False,
                     stats_path=None, norm_reward=False, seed=0,
                     log_dir='', should_render=True):
     """
     Create environment for testing a trained agent

     :param env_id: (str)
     :param n_envs: (int) number of processes
+    :param n_agents: (int) number of agents per enviroment
     :param is_atari: (bool)
     :param stats_path: (str) path to folder containing saved running averaged
     :param norm_reward: (bool) Whether to normalize rewards or not when using Vecnormalize
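
A hedged call sketch for the widened signature, mirroring the sb_enjoy.py call site; the import path and environment id are assumptions:

    from utils.utils import create_test_env

    # existing callers keep working because n_agents defaults to 1
    env = create_test_env('MarathonHopperEnv-v0', n_envs=1, n_agents=1,
                          is_atari=False, stats_path=None, norm_reward=False,
                          seed=0, log_dir=None, should_render=True)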
