
Commit abdb7e9

update

committed · 1 parent 55c71de

File tree

4 files changed: +79 -22 lines


agents/BasicAgent.py

+6
@@ -4,6 +4,8 @@
 class BasicAgent(AbstractAgent):
     def __init__(self, train, screen_size):
         super(BasicAgent, self).__init__(screen_size)
+        self.old_state = None
+        self.old_action = None

     def step(self, obs):
         if self._MOVE_SCREEN.id in obs.observation.available_actions:
@@ -32,6 +34,10 @@ def step(self, obs):

             assert move != ""

+            self.old_state = marine_coordinates
+            self.old_action = move
+            x = obs.reward
+
             return self._dir_to_sc2_action(move, marine_coordinates)
         else:
             return self._SELECT_ARMY

agents/QLearningAgent.py

+56 -15
@@ -1,3 +1,5 @@
+import random
+
 from agents.AbstractAgent import AbstractAgent
 import pandas as pd
 import numpy as np
@@ -25,33 +27,64 @@ def __init__(self, train, screen_size, explore=1):
         self.states = []
         for x in range(-64, 65):
             for y in range(-64, 65):
-                self.states.append((x, y))
+                self.states.append("("+str(x) + "," + str(y) + ")")
         self.q_table = self.init_q_table()
         self.alpha = 0.1
         self.gamma = 0.9
         self.old_state = None
         self.old_action = None

-    def step(self, obs):
+    def step(self, obs, epsilon):
         # TODO step method
         if self._MOVE_SCREEN.id in obs.observation.available_actions:
+            # get q_state from position
             marine = self._get_marine(obs)
             if marine is None:
                 return self._NO_OP
             marine_coordinates = self._get_unit_pos(marine)
-            action = self.get_new_action(marine_coordinates)
+            beacon = self._get_beacon(obs)
+            if beacon is None:
+                return self._NO_OP
+            beacon_coordinates = self._get_unit_pos(beacon)
+
+            q_state = self.get_q_state_from_position(marine_position=marine_coordinates,
+                                                     beacon_position=beacon_coordinates)
+
+            # epsilon integration
+            rnd = random.random()
+            if rnd > epsilon:
+                action = self.get_new_action(q_state)
+            else:
+                action = random.choice(list(self.actions))
+
             if self.train:
-                pass
+                if self.old_state == None and self.old_action == None:
+                    # first step where there is no previous state
+                    self.old_state = get_row_index_in_string_format(q_state)
+                    self.old_action = action
+                else:
+                    t = obs.reward == 1  # terminate when beacon reached
+                    self.update_q_value(self.old_state, self.old_action, marine_coordinates, obs.reward, t)  # update q_value
+
+                    # set previous state and action
+                    self.old_state = get_row_index_in_string_format(q_state)
+                    self.old_action = action
+
+                return self._dir_to_sc2_action(action, marine_coordinates)
             else:
                 return self._dir_to_sc2_action(action, marine_coordinates)
         else:
+            self.old_state = None
+            self.old_action = None
             return self._SELECT_ARMY  # initialize army in first step

     def save_model(self, path):
-        self.q_table.to_pickle(path)
+        # save model as pkl
+        self.q_table.to_pickle(path + ".pkl")

     def load_model(self, path):
-        self.q_table = pd.read_pickle(path)
+        # load model from pkl
+        self.q_table = pd.read_pickle(path + ".pkl")

     def get_new_action(self, state):
         """
@@ -65,8 +98,12 @@ def get_new_action(self, state):
         """
         # TODO get_new_action method
         index = get_row_index_in_string_format(state)
-        action = np.argmax(self.q_table.loc[index])
-        return self.actions[action]
+        options = self.q_table.loc[index]
+        m = max(options)
+        indices = [index for index, value in enumerate(options) if value == m]
+        choice = random.choice(indices)
+        action = list(self.actions)[choice]
+        return action

     def get_q_value(self, q_table_column_index, q_table_row_index):
         """
@@ -80,20 +117,22 @@ def get_q_value(self, q_table_column_index, q_table_row_index):
             action (float): The value for the given indices.
         """
         # TODO get_new_action method
-        q_value = self.q_table.loc[q_table_row_index, q_table_column_index]
+        q_value = self.q_table.loc[q_table_column_index][q_table_row_index]
         return float(q_value)

     def update_q_value(self, old_state, old_action, new_state, reward, terminal):
         # TODO update_q_value method
-        old_state_str = get_row_index_in_string_format(old_state)
         new_state_str = get_row_index_in_string_format(new_state)
-        q_value = self.q_table[old_state_str, old_action]
+        q_value = self.get_q_value(q_table_column_index=old_state,
+                                   q_table_row_index=old_action)
         if not terminal:
-            new_q_value = q_value + self.alpha + (reward + self.gamma * max(self.q_table[new_state_str]) + q_value)
+            max_new = max(self.q_table.loc[new_state_str])
+            new_q_value = q_value + self.alpha * (reward + (self.gamma * max_new) - q_value)
         else:
-            new_q_value = q_value + self.alpha + (reward - q_value)
+            new_q_value = q_value + self.alpha * (reward - q_value)
+            print("final", old_state, new_q_value)

-        self.q_table[old_state_str, old_action] = new_q_value
+        self.q_table.at[old_state, old_action] = new_q_value


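The corrected update_q_value now implements the standard tabular Q-learning rule (the previous version added alpha instead of multiplying by it and used the wrong sign on the old value):

    Q(s, a) <- Q(s, a) + alpha * (reward + gamma * max_a' Q(s', a') - Q(s, a))

with the bootstrap term gamma * max_a' Q(s', a') dropped on terminal steps. As a quick worked example using the values set in __init__ (alpha = 0.1, gamma = 0.9): for Q(s, a) = 0.5, reward = 0 and max_a' Q(s', a') = 0.8, the new value is 0.5 + 0.1 * (0 + 0.9 * 0.8 - 0.5) = 0.522.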
@@ -122,4 +161,6 @@ def init_q_table(self):
         The row indices must be in the format '(x,y)'
         The column indices must be in the format 'action' (e.g. 'W')
         """
-        return pd.DataFrame(np.random.rand(len(self.states), len(self.actions)), index=self.states, columns=self.actions)
+        return pd.DataFrame(np.random.rand(len(self.states), len(self.actions)), index=self.states, columns=self.actions)
+
+        # return pd.DataFrame(0, index=self.states, columns=self.actions)
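For reference, the table built by init_q_table is a plain pandas DataFrame with one '(x,y)' string per row and one action per column, which is what the .loc reads and the .at write above operate on. A toy-sized sketch (the two states and four compass actions are assumptions for illustration; the real table spans x, y in [-64, 64], i.e. 129 x 129 rows, presumably the beacon position relative to the marine, and the agent's own action set):

    import numpy as np
    import pandas as pd

    states = ["(0,0)", "(0,1)"]        # row labels: relative (x,y) encoded as strings
    actions = ["N", "E", "S", "W"]     # column labels: one per movement action (assumed)

    q_table = pd.DataFrame(np.random.rand(len(states), len(actions)),
                           index=states, columns=actions)

    row = q_table.loc["(0,1)"]         # Series of action values for one state
    q_table.at["(0,1)", "E"] = 0.5     # scalar write, as update_q_value does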

runners/main_runner.py

+15 -4
@@ -15,7 +15,9 @@ def __init__(self, agent, env, train):

         self.moving_avg = collections.deque(maxlen=50)
         self.score = 0
+        self.scores = []
         self.episode = 1
+        self.epsilon = 1.0

         self.graph_path = Path("..", "graphs", type(agent).__name__, datetime.datetime.now().strftime("%y%m%d_%H%M"),
                                'train' if self.train else 'run')
@@ -28,31 +30,40 @@ def __init__(self, agent, env, train):

         if not self.train and os.path.isdir(self.weights_path_load):
             self.agent.load_model(str(self.weights_path_load))
+            pass
         else:
             self.weights_path_load.mkdir(parents=True, exist_ok=True)

     def summarize(self):
         # Graphs in tensorboard
         self.writer.add_scalar('Score per Episode', self.score, global_step=self.episode)
-
+        self.writer.add_scalar('Epsilon', self.epsilon, global_step=self.episode)
         if self.train and self.episode % 10 == 0:
             self.agent.save_model(str(self.weights_path_save))
             try:
                 self.agent.update_target_model()
             except AttributeError:
                 ...
+        self.scores.append(self.score)
+        if len(self.scores) > 50:
+            self.scores.pop(0)
+        self.writer.add_scalar('Moving Average', (sum(self.scores) / 50), global_step=self.episode)

         self.episode += 1
         self.score = 0
         self.writer.flush()

-    def run(self, episodes):
-        while self.episode <= episodes:
+    def run(self):
+        while self.score < 20:
             obs = self.env.reset()
             while True:
-                action = self.agent.step(obs)
+                action = self.agent.step(obs, self.epsilon)
                 if obs.last():
+                    # epsilon decreases linearly until it reaches 0.1
+                    if self.epsilon > 0.1:
+                        self.epsilon -= 0.0001
                     break
+
                 obs = self.env.step(action)
                 self.score += obs.reward

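The runner now owns the exploration schedule: epsilon starts at 1.0 and is decremented by 0.0001 at the end of every episode until it reaches the 0.1 floor, so the decay phase lasts (1.0 - 0.1) / 0.0001 = 9000 episodes. A minimal sketch of that schedule in isolation (variable names are illustrative):

    epsilon = 1.0
    for episode in range(10000):
        # ... play one episode, passing epsilon into agent.step() ...
        if epsilon > 0.1:
            epsilon -= 0.0001   # linear decay; hits the floor after 9000 episodes

Note that the 'Moving Average' scalar logged in summarize() always divides by the fixed window of 50, so it averages over fewer actual scores during the first 50 episodes.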
trainScripts/trainQLAgent.py

+2 -3
@@ -2,11 +2,10 @@
 from absl import app

 from env import Env
-from runners.basic_runner import Runner
+from runners.main_runner import Runner
 from agents.QLearningAgent import QLearningAgent

 _CONFIG = dict(
-    episodes=100,
     screen_size=64,
     minimap_size=64,
     visualize=False,
@@ -33,7 +32,7 @@ def main(unused_argv):
         train=_CONFIG['train']
     )

-    runner.run(episodes=_CONFIG['episodes'])
+    runner.run()


 if __name__ == "__main__":
