3
3
import pickle
4
4
5
5
class TQLearningAgent :
6
+ """
7
+ TQLearningAgent is a class that implements a Distributed Q-learning agent.
8
+
9
+ Parameters
10
+ ----------
11
+ gamma : float
12
+ The discount factor.
13
+ default_q : float
14
+ The default value for the Q-table.
15
+ """
6
16
def __init__ (self , gamma = 1. , default_q = 0. ):
7
17
self .n_actions = 3
8
18
self .gamma = gamma
9
19
self .default = [default_q ] * self .n_actions
10
20
self .q_table = {}
11
21
12
22
def __check_entry (self , state ):
23
+ """
24
+ Checks if the state is in the Q-table and adds it if it is not.
25
+
26
+ Parameters
27
+ ----------
28
+ state : int
29
+ The state to check.
30
+ """
13
31
if state not in self .q_table :
14
32
self .q_table [state ] = self .default .copy ()
15
33
16
34
def eval (self , state , action ):
35
+ """
36
+ Evaluates the Q-value of a state-action pair.
37
+
38
+ Parameters
39
+ ----------
40
+ state : int
41
+ The state.
42
+ action : int
43
+ The action.
44
+
45
+ Returns
46
+ -------
47
+ float
48
+ The Q-value of the state-action pair.
49
+ """
17
50
self .__check_entry (state )
18
51
return self .q_table [state ][action ]
19
52
20
53
def update (self , lr , state , action , reward , next_state = None ):
54
+ """
55
+ Updates the Q-value of a state-action pair.
56
+
57
+ Parameters
58
+ ----------
59
+ lr : float
60
+ The learning rate.
61
+ state : int
62
+ The state.
63
+ action : int
64
+ The action.
65
+ reward : float
66
+ The reward.
67
+ next_state : int
68
+ The next state.
69
+ """
21
70
self .__check_entry (state )
22
71
self .q_table [state ][action ] = \
23
72
(1 - lr ) * self .q_table [state ][action ] + \
24
73
lr * (reward + self .gamma * self .max_q (next_state ))
25
74
26
75
def max_q (self , state ):
76
+ """
77
+ Returns the maximum Q-value of a state.
78
+
79
+ Parameters
80
+ ----------
81
+ state : int
82
+ The state.
83
+
84
+ Returns
85
+ -------
86
+ float
87
+ The maximum Q-value of the state
88
+ """
27
89
self .__check_entry (state )
28
90
return max (self .q_table [state ]) if state is not None else 0.
29
91
30
92
def max_action (self , state ):
93
+ """
94
+ Returns the action that maximizes the Q-value of a state.
95
+
96
+ Parameters
97
+ ----------
98
+ state : int
99
+ The state.
100
+
101
+ Returns
102
+ -------
103
+ int
104
+ The action that maximizes the Q-value of the state.
105
+ """
31
106
self .__check_entry (state )
32
107
return np .argmax (self .q_table [state ])
33
108
34
109
def dump (self , filename : str , mode : str = 'pickle' ):
35
110
"""
36
111
Dumps the agent to a file.
37
- :param filename: The name of the file.
38
- :param mode: The mode of the dump. Can be 'pickle', 'csv', 'parquet'.
112
+
113
+ Parameters
114
+ ----------
115
+ filename : str
116
+ The name of the file.
117
+ mode : str
118
+ The mode of the dump (pickle, csv, parquet).
39
119
"""
40
120
if mode == 'pickle' :
41
121
self .__dump_pickle (filename )
@@ -47,7 +127,11 @@ def dump(self, filename: str, mode: str = 'pickle'):
47
127
def load (filename : str ):
48
128
"""
49
129
Loads the agent from a file (pickle).
50
- :param filename: The name of the file.
130
+
131
+ Parameters
132
+ ----------
133
+ filename : str
134
+ The name of the file.
51
135
"""
52
136
return TQLearningAgent .__load_pickle (filename )
53
137
0 commit comments