Skip to content

Commit a28e19a

Browse files
author
gvlos
committed
doc
1 parent 0292e68 commit a28e19a

File tree

4 files changed

+434
-5
lines changed

4 files changed

+434
-5
lines changed

distributed_q_learning/flatland_tools/agent.py

+87-3
Original file line numberDiff line numberDiff line change
@@ -3,39 +3,119 @@
33
import pickle
44

55
class TQLearningAgent:
    """
    Tabular (distributed) Q-learning agent.

    The Q-table is a dict mapping ``state -> list`` of Q-values, one per
    action; rows for unseen states are created lazily with ``default_q``.

    Parameters
    ----------
    gamma : float
        The discount factor.
    default_q : float
        The value used to initialise every entry of a new Q-table row.
    n_actions : int
        Number of actions per state. Defaults to 3 (previously hard-coded),
        so existing callers are unaffected.
    """

    def __init__(self, gamma=1., default_q=0., n_actions=3):
        # Generalised: the action count was hard-coded to 3; it is now a
        # parameter with the same default, keeping the interface compatible.
        self.n_actions = n_actions
        self.gamma = gamma
        # Template row copied into the table for each unseen state.
        self.default = [default_q] * self.n_actions
        self.q_table = {}

    def __check_entry(self, state):
        """
        Ensure ``state`` has a row in the Q-table, adding a copy of the
        default row if it is missing.

        Parameters
        ----------
        state : int
            The state to check.
        """
        if state not in self.q_table:
            # .copy() so every state gets an independent, mutable row.
            self.q_table[state] = self.default.copy()

    def eval(self, state, action):
        """
        Evaluate the Q-value of a state-action pair.

        Parameters
        ----------
        state : int
            The state.
        action : int
            The action.

        Returns
        -------
        float
            The Q-value of the state-action pair.
        """
        self.__check_entry(state)
        return self.q_table[state][action]

    def update(self, lr, state, action, reward, next_state=None):
        """
        Apply one Q-learning update to a state-action pair:

            Q(s, a) <- (1 - lr) * Q(s, a) + lr * (r + gamma * max_a' Q(s', a'))

        Parameters
        ----------
        lr : float
            The learning rate.
        state : int
            The state.
        action : int
            The action.
        reward : float
            The reward.
        next_state : int, optional
            The next state; ``None`` marks a terminal transition, whose
            bootstrap value is 0 (see ``max_q``).
        """
        self.__check_entry(state)
        target = reward + self.gamma * self.max_q(next_state)
        self.q_table[state][action] = \
            (1 - lr) * self.q_table[state][action] + lr * target

    def max_q(self, state):
        """
        Return the maximum Q-value of a state, or 0 for ``state is None``
        (terminal transition).

        Parameters
        ----------
        state : int or None
            The state.

        Returns
        -------
        float
            The maximum Q-value of the state.
        """
        # Fix: guard BEFORE touching the table. The original called
        # __check_entry(state) first, so every terminal update
        # (next_state=None) inserted a spurious ``None`` key with default
        # Q-values into the Q-table.
        if state is None:
            return 0.
        self.__check_entry(state)
        return max(self.q_table[state])

    def max_action(self, state):
        """
        Return the action that maximises the Q-value of a state
        (ties broken in favour of the lowest action index).

        Parameters
        ----------
        state : int
            The state.

        Returns
        -------
        int
            The greedy action for the state.
        """
        self.__check_entry(state)
        return np.argmax(self.q_table[state])
33108

34109
def dump(self, filename: str, mode: str = 'pickle'):
35110
"""
36111
Dumps the agent to a file.
37-
:param filename: The name of the file.
38-
:param mode: The mode of the dump. Can be 'pickle', 'csv', 'parquet'.
112+
113+
Parameters
114+
----------
115+
filename : str
116+
The name of the file.
117+
mode : str
118+
The mode of the dump (pickle, csv, parquet).
39119
"""
40120
if mode == 'pickle':
41121
self.__dump_pickle(filename)
@@ -47,7 +127,11 @@ def dump(self, filename: str, mode: str = 'pickle'):
47127
def load(filename: str):
48128
"""
49129
Loads the agent from a file (pickle).
50-
:param filename: The name of the file.
130+
131+
Parameters
132+
----------
133+
filename : str
134+
The name of the file.
51135
"""
52136
return TQLearningAgent.__load_pickle(filename)
53137

0 commit comments

Comments
 (0)