+ import random
+
from agents.AbstractAgent import AbstractAgent
import pandas as pd
import numpy as np
@@ -25,33 +27,64 @@ def __init__(self, train, screen_size, explore=1):
        self.states = []
        for x in range(-64, 65):
            for y in range(-64, 65):
-                self.states.append((x, y))
+                self.states.append("(" + str(x) + "," + str(y) + ")")
        self.q_table = self.init_q_table()
        self.alpha = 0.1
        self.gamma = 0.9
        self.old_state = None
        self.old_action = None

-    def step(self, obs):
+    def step(self, obs, epsilon):
        # TODO step method
        if self._MOVE_SCREEN.id in obs.observation.available_actions:
+            # derive the q_state from the marine and beacon positions
            marine = self._get_marine(obs)
            if marine is None:
                return self._NO_OP
            marine_coordinates = self._get_unit_pos(marine)
-            action = self.get_new_action(marine_coordinates)
+            beacon = self._get_beacon(obs)
+            if beacon is None:
+                return self._NO_OP
+            beacon_coordinates = self._get_unit_pos(beacon)
+
+            q_state = self.get_q_state_from_position(marine_position=marine_coordinates,
+                                                     beacon_position=beacon_coordinates)
+
+            # epsilon-greedy: exploit the Q-table with probability 1 - epsilon, otherwise explore
+            rnd = random.random()
+            if rnd > epsilon:
+                action = self.get_new_action(q_state)
+            else:
+                action = random.choice(list(self.actions))
+
            if self.train:
-                pass
+                if self.old_state is None and self.old_action is None:
+                    # first step where there is no previous state
+                    self.old_state = get_row_index_in_string_format(q_state)
+                    self.old_action = action
+                else:
+                    t = obs.reward == 1  # terminate when beacon reached
+                    self.update_q_value(self.old_state, self.old_action, q_state, obs.reward, t)  # update q_value with the new state
+
+                # set previous state and action
+                self.old_state = get_row_index_in_string_format(q_state)
+                self.old_action = action
+
+                return self._dir_to_sc2_action(action, marine_coordinates)
            else:
                return self._dir_to_sc2_action(action, marine_coordinates)
        else:
+            self.old_state = None
+            self.old_action = None
            return self._SELECT_ARMY  # initialize army in first step

    def save_model(self, path):
-        self.q_table.to_pickle(path)
+        # save model as pkl
+        self.q_table.to_pickle(path + ".pkl")

    def load_model(self, path):
-        self.q_table = pd.read_pickle(path)
+        # load model from pkl
+        self.q_table = pd.read_pickle(path + ".pkl")

    def get_new_action(self, state):
        """
@@ -65,8 +98,12 @@ def get_new_action(self, state):
        """
        # TODO get_new_action method
        index = get_row_index_in_string_format(state)
-        action = np.argmax(self.q_table.loc[index])
-        return self.actions[action]
+        options = self.q_table.loc[index]
+        m = max(options)
+        indices = [i for i, value in enumerate(options) if value == m]  # all actions with the maximal value
+        choice = random.choice(indices)  # break ties at random
+        action = list(self.actions)[choice]
+        return action

    def get_q_value(self, q_table_column_index, q_table_row_index):
        """
@@ -80,20 +117,22 @@ def get_q_value(self, q_table_column_index, q_table_row_index):
            action (float): The value for the given indices.
        """
        # TODO get_new_action method
-        q_value = self.q_table.loc[q_table_row_index, q_table_column_index]
+        q_value = self.q_table.loc[q_table_column_index][q_table_row_index]
        return float(q_value)

    def update_q_value(self, old_state, old_action, new_state, reward, terminal):
        # TODO update_q_value method
-        old_state_str = get_row_index_in_string_format(old_state)
        new_state_str = get_row_index_in_string_format(new_state)
-        q_value = self.q_table[old_state_str, old_action]
+        q_value = self.get_q_value(q_table_column_index=old_state,
+                                   q_table_row_index=old_action)
        if not terminal:
-            new_q_value = q_value + self.alpha + (reward + self.gamma * max(self.q_table[new_state_str]) + q_value)
+            max_new = max(self.q_table.loc[new_state_str])
+            new_q_value = q_value + self.alpha * (reward + (self.gamma * max_new) - q_value)
        else:
-            new_q_value = q_value + self.alpha + (reward - q_value)
+            new_q_value = q_value + self.alpha * (reward - q_value)
+            print("final", old_state, new_q_value)

-        self.q_table[old_state_str, old_action] = new_q_value
+        self.q_table.at[old_state, old_action] = new_q_value

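For reference, the lines added above implement the standard tabular Q-learning update, with the bootstrap term dropped on terminal steps:

    Q(s, a) \leftarrow Q(s, a) + \alpha \bigl( r + \gamma \max_{a'} Q(s', a') - Q(s, a) \bigr)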
@@ -122,4 +161,6 @@ def init_q_table(self):
        The row indices must be in the format '(x,y)'
        The column indices must be in the format 'action' (e.g. 'W')
        """
-        return pd.DataFrame(np.random.rand(len(self.states), len(self.actions)), index=self.states, columns=self.actions)
+        return pd.DataFrame(np.random.rand(len(self.states), len(self.actions)), index=self.states, columns=self.actions)
+
+        # return pd.DataFrame(0, index=self.states, columns=self.actions)
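As a standalone illustration of the pattern this commit builds toward, here is a minimal sketch of an epsilon-greedy lookup and tabular Q-update against a pandas Q-table keyed by '(x,y)' state strings. The action labels, hyperparameter values, and helper names below are assumptions for the example, not the project's API.

import random

import numpy as np
import pandas as pd

# Illustrative stand-ins for the agent's attributes (assumed values, not the real class)
actions = ["N", "E", "S", "W"]
states = ["(" + str(x) + "," + str(y) + ")" for x in range(-64, 65) for y in range(-64, 65)]
q_table = pd.DataFrame(np.random.rand(len(states), len(actions)), index=states, columns=actions)
alpha, gamma, epsilon = 0.1, 0.9, 0.2

def choose_action(q_state_str):
    # epsilon-greedy: act greedily with probability 1 - epsilon, otherwise explore
    if random.random() > epsilon:
        options = q_table.loc[q_state_str]
        best = max(options)
        ties = [a for a, v in options.items() if v == best]
        return random.choice(ties)
    return random.choice(actions)

def update_q(old_state_str, action, new_state_str, reward, terminal):
    # tabular Q-learning update; terminal steps drop the bootstrap term
    q_value = float(q_table.at[old_state_str, action])
    target = reward if terminal else reward + gamma * max(q_table.loc[new_state_str])
    q_table.at[old_state_str, action] = q_value + alpha * (target - q_value)

# one illustrative transition (state strings and reward chosen arbitrarily)
a = choose_action("(3,4)")
update_q("(3,4)", a, "(2,4)", reward=0, terminal=False)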