from flatland_tools.env import Environment, Node
from flatland_tools.agent import TQLearningAgent
from training_utils.training_utils import train, eval_once, log2file
from typing import Tuple
from flatland.envs.agent_utils import TrainState
import pandas as pd
import numpy as np
import concurrent.futures
import os
import uuid
import time
import multiprocessing


class ModifiedEnv(Environment):
    """
    A wrapper around the Flatland environment that produces a slightly different observation:
    - station_id: the station identifier of the train's target station
    - node_id: the node identifier of the current node
    - delay: the (discretized) delay of the train
    - semaphore_edge_0:
        - 0: no train coming towards this node on edge 0
        - 1: a train is coming towards this node on edge 0
        - 2: no oncoming train, but a malfunctioning train is present
    - semaphore_edge_1: same as semaphore_edge_0, but for edge 1
    """
    def observe(self):
        """
        Returns a list of (train_id, node_id, obs, old_reward) tuples,
        one for each train that is currently on a node.
        """
observations = []
for train_id, node in self.graph.trainsThatAreOnNodes(self.rail_env.agents).items():
node_id = self.node_to_id[node]
obs = self.observe_node_v2(node, train_id)
self.last_obs[node_id] = obs
old_reward = self._get_reward_for_node(node)
observations.append((train_id, node_id, obs, old_reward))
return observations
    def observe_node_v2(self, node: Node, train_id: int) -> Tuple[int, int, int, int, int]:
        """
        Preconditions: train_id exists and is on a node.
        Returns (station_id, node_id, delay, semaphore edge 0, semaphore edge 1).
        """
train = self.rail_env.agents[train_id]
# Station identifier
station_id = self.stations[train.target]
# Train delay
delay = self._discretize_delay(train_id, self._compute_delay(train_id))
# Semaphores
sem_0 = self.compute_semaphore_v2(node, 0)
sem_1 = self.compute_semaphore_v2(node, 1)
return station_id, self.node_to_id[node], delay, sem_0, sem_1
    def compute_semaphore_v2(self, node: Node, edge_idx: int) -> int:
        """
        Returns the semaphore value for the given node and edge index:
        - 0: no train coming towards the node on this edge
        - 1: a train is coming towards the node on this edge
        - 2: no oncoming train, but a malfunctioning train is present on the map
        """
_, edge_data = self.graph.fw_star(node)[edge_idx]
malf_present = False
for train in self.rail_env.agents:
row, col = train.position if train.position is not None else train.initial_position
if row == node.row and col == node.col:
continue
            # The semaphore is raised (1) if a train currently on this edge
            # moves in the direction opposite to the one available on the edge
            # when leaving the given node; if no oncoming train is found but
            # some train is malfunctioning, the semaphore reports 2 instead.
train_on_edge = edge_data.direction_mask[row, col] != -1
opposite_dir = train.direction != edge_data.direction_mask[row, col]
malfunction = train.state == TrainState.MALFUNCTION
if train_on_edge and opposite_dir:
return 1
elif malfunction:
malf_present = True
return 2 if malf_present else 0
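

# Illustrative sketch (comments only, not executed): one step of the
# observation loop, given the tuple layout produced by ModifiedEnv.observe().
# The `act` call is an assumption about TQLearningAgent's interface, not
# something defined in this file.
#
#   for train_id, node_id, obs, old_reward in env.observe():
#       station_id, _, delay, sem_0, sem_1 = obs
#       action = agent.act(obs)  # hypothetical method name
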
def generate_env(malf_rate, malf_min, malf_max, malf_seed=0):
"""
Returns a modified environment with the given malfunction parameters
Parameters
----------
malf_rate: float
The malfunction rate
malf_min: int
The minimum malfunction duration
malf_max: int
The maximum malfunction duration
malf_seed: int
The seed for the malfunction generator
Returns
-------
ModifiedEnv
The modified environment
"""
return ModifiedEnv(
env_width=40,
env_height=40,
env_n_cities=7,
env_n_trains=5,
seed=13,
destination_bonus=200,
deadlock_penalty=-200,
delay_threshold=0.2,
malfunction_rate=malf_rate,
malfunction_min_duration=malf_min,
malfunction_max_duration=malf_max,
malfunction_seed=malf_seed
)
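

# Example (comments only, not executed): the active configuration below is
# deterministic (no malfunctions); a stochastic variant would use the
# commented-out rows of malf_configs, e.g.:
#
#   det_env = generate_env(malf_rate=0., malf_min=0, malf_max=0)
#   stoch_env = generate_env(malf_rate=1e-3, malf_min=5, malf_max=15, malf_seed=42)
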
def train_launcher(*args, **kwargs):
    """
    Wrapper around the train function: pops the 'metadata' keyword argument,
    runs training, and returns the metadata so that the main process can
    match each completed future back to its experiment.
    """
    metadata = kwargs.pop('metadata')
    train(*args, **kwargs)
    return metadata
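

# Sketch of the metadata round-trip (comments only, not executed):
#
#   fut = executor.submit(train_launcher, env, agent, ..., metadata={'exp_id': 1})
#   meta = fut.result()  # blocks until training ends; identifies the experiment
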
def eval_batch(
malf_rate: float,
malf_min: int,
malf_max: int,
agent: TQLearningAgent,
    malf_seeds: np.ndarray,
exp_id: int,
first_eval_id: int,
out_dir: str
):
"""
Evaluates the agent on a batch of environments with the given malfunction parameters
Parameters
----------
malf_rate: float
The malfunction rate
malf_min: int
The minimum malfunction duration
malf_max: int
The maximum malfunction duration
agent: TQLearningAgent
The agent to evaluate
    malf_seeds: np.ndarray of int
        The seeds for the malfunction generators
exp_id: int
The experiment id
first_eval_id: int
The id of the first evaluation in this batch
out_dir: str
The output directory
"""
env = generate_env(malf_rate, malf_min, malf_max)
eval_df = []
eval_df_columns = \
['Experiment id', 'Eval id', 'Eval seed', 'Cumulative reward'] + \
[f'Delay {i}' for i in range(env.rail_env.get_num_agents())] + \
['# arrived', '# arrived on time']
for idx, malf_seed in enumerate(malf_seeds):
cumulative_reward, delays, n_arrived, \
n_arrived_on_time, _ = eval_once(
env, agent, malf_seed, False
)
eval_df.append([
exp_id,
first_eval_id + idx,
malf_seed,
cumulative_reward,
] + [
            delays.get(train_id) for train_id in range(env.rail_env.get_num_agents())
] + [
n_arrived,
n_arrived_on_time
])
pd.DataFrame(eval_df, columns=eval_df_columns)\
.to_parquet(os.path.join(out_dir, f'{uuid.uuid4()}.parquet'))
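

# Sketch (comments only, not executed): the per-batch parquet files can be
# recombined afterwards; assumes a parquet engine such as pyarrow is available.
#
#   import glob
#   results = pd.concat(
#       pd.read_parquet(p) for p in glob.glob(os.path.join(out_dir, '*.parquet'))
#   )
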
# Malfunction configurations: only the deterministic (no-malfunction) row is
# active; the stochastic variants are kept, commented out, for reference.
malf_configs = pd.DataFrame([
[0., 0, 0],
# [1e-3, 5, 15],
# [1e-3, 15, 30],
# [5e-3, 5, 15],
# [5e-3, 15, 30]
], columns=['rate', 'min', 'max'])
# Hyperparameters
hp_configs = pd.DataFrame([
    # Best hps from exp 12
    [1.0, 0.99997,  0.1,  0.99999, int(4e5)],  # blue
    [1.0, 0.999965, 0.1,  0.99999, int(4e5)],  # orange
    [1.0, 0.99997,  0.01, 1.0,     int(4e5)],  # green
    [1.0, 0.999965, 0.01, 1.0,     int(4e5)],  # red
    [1.0, 0.999975, 0.1,  0.99999, int(4e5)],  # purple
    [1.0, 0.999975, 0.01, 1.0,     int(4e5)]   # brown
], columns=['epsilon', 'epsilon decay', 'alpha', 'alpha decay', 'n_episodes'])
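
# For reference, the decayed value after t episodes is value * decay**t.
# With epsilon decay 0.99997 over 4e5 episodes:
#   0.99997**4e5 = exp(4e5 * ln(0.99997)) ≈ exp(-12.0) ≈ 6.1e-6
# so exploration is effectively annealed away by the end of training, while
# an alpha decay of 1.0 keeps the learning rate constant.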
# Other parameters
out_dir = 'experiments/reproduce_deterministic'
n_workers = multiprocessing.cpu_count()
master_seed = 666
log_every = 10_000
save_every = 10_000
n_evals_calc = lambda _: 1  # one eval per experiment in the deterministic setting; e.g. lambda malf_rate: int(300 / malf_rate) for stochastic runs
eval_batch_size = 10_000
if __name__ == '__main__':
start_time = time.time()
os.makedirs(out_dir, exist_ok=True)
master_rng = np.random.RandomState(master_seed)
master_log_file = os.path.join(out_dir, 'log.txt')
log2file(master_log_file, f'Starting training with {n_workers} workers.')
executor = concurrent.futures.ProcessPoolExecutor(n_workers)
tasks = []
exp_id = 0
for hp_idx in range(len(hp_configs)):
for malf_idx in range(len(malf_configs)):
exp_id += 1
hp = hp_configs.iloc[hp_idx]
malf = malf_configs.iloc[malf_idx]
env = generate_env(malf['rate'], malf['min'], malf['max'])
agent = TQLearningAgent()
# Experiment directory
exp_dir = os.path.join(out_dir, f'{exp_id}')
os.makedirs(exp_dir, exist_ok=True)
# Save config
pd.concat([hp, malf])\
.to_csv(os.path.join(exp_dir, f'{exp_id}_config.csv'))
log2file(master_log_file, f'Experiment {exp_id} started, with hyperparameters:\n{hp}\nand malfunction config:\n{malf}\n')
# Log file
log_file = os.path.join(exp_dir, 'log.txt')
# Q-tables directory
qtables_out_dir = os.path.join(exp_dir, 'qtables')
os.makedirs(qtables_out_dir, exist_ok=True)
# Rng seeds
train_seed = master_rng.randint(2**32)
eval_seed = master_rng.randint(2**32)
np.save(os.path.join(exp_dir, 'train_rng_seed.npy'), train_seed)
np.save(os.path.join(exp_dir, 'eval_rng_seed.npy'), eval_seed)
# Launch experiment
task = executor.submit(
train_launcher,
env,
agent,
int(hp['n_episodes']),
alphas=hp['alpha'] * hp['alpha decay'] ** np.arange(hp['n_episodes']),
epsilons=hp['epsilon'] * hp['epsilon decay'] ** np.arange(hp['n_episodes']),
with_malfunctions=True,
save_node_interactions=False,
rng_master_seed=train_seed,
log_every=log_every,
save_every=save_every,
rewards_out_dir=exp_dir,
qtables_out_dir=qtables_out_dir,
log_file=log_file,
                metadata={
'exp_id': exp_id,
'hp': hp,
'malf': malf,
'eval_seed': eval_seed,
'agent_path': os.path.join(qtables_out_dir, f'qtable_{int(hp["n_episodes"])}.pkl'),
'start_time': time.time()
}
)
tasks.append(task)
eval_tasks = []
eval_df_dir = os.path.join(out_dir, 'eval_results')
os.makedirs(eval_df_dir, exist_ok=True)
for task in concurrent.futures.as_completed(tasks):
metadata = task.result()
malf = metadata['malf']
n_evals = n_evals_calc(malf['rate'])
eval_seeds = np.random.RandomState(metadata['eval_seed']).randint(2**32, size=n_evals)
log2file(master_log_file, f'Experiment {metadata["exp_id"]} completed. Took {time.time() - metadata["start_time"]:.2f} seconds')
log2file(master_log_file, f'Launching {n_evals} evaluations for experiment {metadata["exp_id"]}')
for eval_id in range(0, n_evals, eval_batch_size):
eval_task = executor.submit(
eval_batch,
malf_rate=malf['rate'],
malf_min=malf['min'],
malf_max=malf['max'],
agent=TQLearningAgent.load(metadata['agent_path']),
malf_seeds=eval_seeds[eval_id:eval_id+eval_batch_size],
exp_id=metadata['exp_id'],
first_eval_id=eval_id,
out_dir=eval_df_dir
)
eval_tasks.append(eval_task)
concurrent.futures.wait(eval_tasks)
executor.shutdown()
log2file(master_log_file, f'All experiments completed. Training and evaluation took {time.time() - start_time:.2f} seconds with {n_workers} workers.')
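
# Expected output layout under out_dir (derived from the code above):
#   experiments/reproduce_deterministic/
#       log.txt                    master log
#       eval_results/              one <uuid>.parquet file per evaluation batch
#       <exp_id>/                  one directory per experiment
#           <exp_id>_config.csv    hyperparameters + malfunction configuration
#           log.txt                per-experiment training log
#           train_rng_seed.npy, eval_rng_seed.npy
#           qtables/               Q-table checkpoints (e.g. qtable_<n_episodes>.pkl)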