Commit dab3b96

Authored by Donal Byrne (djbyrne), with Jirka Borovec (Borda), William Falcon (williamFalcon) and Adrian Wälchli
Example: Simple RL example using DQN/Lightning (#1232)
* Example: Simple RL example using DQN/Lightning
* DQN RL Agent using Lightning
* Uses Iterable Dataset for Replay Buffer
* Buffer is populated by agent as training is carried out, updating the dataset
* Applied autopep8 fixes
* Updated line length from 120 to 110
* Update pl_examples/domain_templates/dqn.py: simplify get_device method (Co-Authored-By: Jirka Borovec <[email protected]>)
* Update pl_examples/domain_templates/dqn.py: re-ordered imports (Co-Authored-By: Jirka Borovec <[email protected]>)
* CI: split tests-examples (#990)
* CI: split tests-examples
* tests without template
* comment depends
* CircleCI typo
* add doctest
* update test req.
* CI tests
* setup macOS
* longer train
* lower pred acc
* fix model
* rename default model
* lower tests acc
* typo
* imports
* fix test optimizer
* update calls
* fix Win
* lower Drone image
* fix call
* pytorch image
* fix test
* add dev image
* add dev image
* update image
* drone volume
* lint
* update test notes
* rename tests/models >> tests/base
* group models
* conftest
* optim imports
* typos
* fix import
* fix tests
* install AMP
* tests
* fix import
* Clean up
* added module docstring
* renamed variables to be more descriptive
* Added missing docstrings and type annotations
* Added gym to example requirements
* Added note to changelog
* updated example image
* update types
* rename script
* Update CHANGELOG.md (Co-Authored-By: Jirka Borovec <[email protected]>)
* another rename
* Disable validation when val_percent_check=0 (#1251)
* fix disable validation
* add test
* update changelog
* update docs for val_percent_check
* make "fast training" docs consistent
* calling self.forward() -> self() (#1211)
* self.forward() -> self()
* update changelog (Co-authored-by: Jirka Borovec <[email protected]>)
* Fix requirements-extra.txt Trains package to release version (#1229)
* Fix requirement-extra use released Trains package
* Update README.md: add Trains and links to the external Visualization section (Co-authored-by: Jirka Borovec <[email protected]>)
* Remove unnecessary parameters to super() in documentation and source code (#1240) (Co-authored-by: Jirka Borovec <[email protected]>)
* update deprecation warning (#1258)
* update docs for progress bar values (#1253)
* lower timeouts for inactive issues (#1250)
* update contrib list (#1241) (Co-authored-by: William Falcon <[email protected]>)
* Fix outdated docs (#1227)
* Fix typo (#1224)
* drop unused Tox (#1242)
* system info (#1234)
* system info
* update big info
* test script
* update config
* rename script
* import path
* Changed smoothing in tqdm to decrease variability of time remaining between training / eval (#1194)
* Example: Simple RL example using DQN/Lightning
* DQN RL Agent using Lightning
* Uses Iterable Dataset for Replay Buffer
* Buffer is populated by agent as training is carried out, updating the dataset
* Applied autopep8 fixes
* Updated line length from 120 to 110
* Update pl_examples/domain_templates/dqn.py: simplify get_device method (Co-Authored-By: Jirka Borovec <[email protected]>)
* Update pl_examples/domain_templates/dqn.py: re-ordered imports (Co-Authored-By: Jirka Borovec <[email protected]>)
* Clean up
* added module docstring
* renamed variables to be more descriptive
* Added missing docstrings and type annotations
* Added gym to example requirements
* Added note to changelog
* update types
* rename script
* Update CHANGELOG.md (Co-Authored-By: Jirka Borovec <[email protected]>)
* another rename

Co-authored-by: Donal Byrne <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: William Falcon <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Jeremy Jordan <[email protected]>
Co-authored-by: Martin.B <[email protected]>
Co-authored-by: Tyler Yep <[email protected]>
Co-authored-by: Shunta Komatsu <[email protected]>
Co-authored-by: Jack Pertschuk <[email protected]>
1 parent 4e0d0ab commit dab3b96

File tree (3 files changed, +363 -1 lines changed)

- CHANGELOG.md
- pl_examples/domain_templates/dqn.py
- pl_examples/requirements.txt


CHANGELOG.md

+1

@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added Reinforcement Learning - Deep Q-network (DQN) lightning example ([#1232](https://github.com/PyTorchLightning/pytorch-lightning/pull/1232))
 - Added support for hierarchical `dict` ([#1152](https://github.com/PyTorchLightning/pytorch-lightning/pull/1152))
 - Added `TrainsLogger` class ([#1122](https://github.com/PyTorchLightning/pytorch-lightning/pull/1122))
 - Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
pl_examples/domain_templates/dqn.py

+360 (new file)

@@ -0,0 +1,360 @@
"""
# Deep Reinforcement Learning: Deep Q-network (DQN)

This example is based on
https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter06/02_dqn_pong.py

The template illustrates using Lightning for Reinforcement Learning. The example builds a basic DQN using the
classic CartPole environment.

To run the template, just run:
python dqn.py

After ~1500 steps, you will see the total_reward hitting the max score of 200. Open up TensorBoard to
see the metrics:

tensorboard --logdir default
"""
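# The hyperparameters registered in the __main__ block at the bottom of this
# file can be overridden from the command line, for example (values here are
# only illustrative):
#     python dqn.py --lr 1e-3 --batch_size 32 --episode_length 500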

import pytorch_lightning as pl

from typing import Iterator, List, Tuple

import argparse
from collections import OrderedDict, deque, namedtuple

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset


class DQN(nn.Module):
    """
    Simple MLP network

    Args:
        obs_size: observation/state size of the environment
        n_actions: number of discrete actions available in the environment
        hidden_size: size of hidden layers
    """

    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions)
        )

    def forward(self, x):
        return self.net(x.float())


# Named tuple for storing experience steps gathered in training
Experience = namedtuple(
    'Experience', field_names=['state', 'action', 'reward',
                               'done', 'new_state'])


class ReplayBuffer:
    """
    Replay Buffer for storing past experiences, allowing the agent to learn from them

    Args:
        capacity: size of the buffer
    """

    def __init__(self, capacity: int) -> None:
        self.buffer = deque(maxlen=capacity)

    def __len__(self) -> int:
        return len(self.buffer)

    def append(self, experience: Experience) -> None:
        """
        Add experience to the buffer

        Args:
            experience: tuple (state, action, reward, done, new_state)
        """
        self.buffer.append(experience)

    def sample(self, batch_size: int) -> Tuple:
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])

        return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32),
                np.array(dones, dtype=bool), np.array(next_states))


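# Note: ReplayBuffer.sample draws transitions uniformly at random and without
# replacement, so a sample size larger than the current buffer contents would
# fail; this is why the buffer is warm-started with `warm_start_steps` random
# steps before training begins (see DQNLightning.populate below).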
class RLDataset(IterableDataset):
    """
    Iterable Dataset containing the ExperienceBuffer
    which will be updated with new experiences during training

    Args:
        buffer: replay buffer
        sample_size: number of experiences to sample at a time
    """

    def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self) -> Iterator[Tuple]:
        states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
        for i in range(len(dones)):
            yield states[i], actions[i], rewards[i], dones[i], new_states[i]


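# Design note: the DataLoader built from RLDataset re-invokes __iter__ at the
# start of every epoch, so an "epoch" here is simply a fresh draw of
# `sample_size` transitions from the shared ReplayBuffer. Since the Agent
# below keeps appending new transitions to that same buffer during
# training_step, later epochs are sampled from progressively newer experience.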
class Agent:
    """
    Base Agent class handling the interaction with the environment

    Args:
        env: training environment
        replay_buffer: replay buffer storing experiences
    """

    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
        self.env = env
        self.replay_buffer = replay_buffer
        self.reset()
        self.state = self.env.reset()

    def reset(self) -> None:
        """Resets the environment and updates the state"""
        self.state = self.env.reset()

    def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
        """
        Using the given network, decide what action to carry out
        using an epsilon-greedy policy

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device

        Returns:
            action
        """
        if np.random.random() < epsilon:
            action = self.env.action_space.sample()
        else:
            state = torch.tensor([self.state])

            if device not in ['cpu']:
                state = state.cuda(device)

            q_values = net(state)
            _, action = torch.max(q_values, dim=1)
            action = int(action.item())

        return action

    @torch.no_grad()
    def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') -> Tuple[float, bool]:
        """
        Carries out a single interaction step between the agent and the environment

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device

        Returns:
            reward, done
        """
        action = self.get_action(net, epsilon, device)

        # do step in the environment
        new_state, reward, done, _ = self.env.step(action)

        exp = Experience(self.state, action, reward, done, new_state)

        self.replay_buffer.append(exp)

        self.state = new_state
        if done:
            self.reset()
        return reward, done


class DQNLightning(pl.LightningModule):
    """ Basic DQN Model """

    def __init__(self, hparams: argparse.Namespace) -> None:
        super().__init__()
        self.hparams = hparams

        self.env = gym.make(self.hparams.env)
        obs_size = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n

        self.net = DQN(obs_size, n_actions)
        self.target_net = DQN(obs_size, n_actions)

        self.buffer = ReplayBuffer(self.hparams.replay_size)
        self.agent = Agent(self.env, self.buffer)
        self.total_reward = 0
        self.episode_reward = 0
        self.populate(self.hparams.warm_start_steps)

    def populate(self, steps: int = 1000) -> None:
        """
        Carries out several random steps through the environment to initially fill
        up the replay buffer with experiences

        Args:
            steps: number of random steps to populate the buffer with
        """
        for i in range(steps):
            self.agent.play_step(self.net, epsilon=1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Passes in a state x through the network and gets the q_values of each action as an output

        Args:
            x: environment state

        Returns:
            q values
        """
        output = self.net(x)
        return output

    def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """
        Calculates the mse loss using a mini batch from the replay buffer

        Args:
            batch: current mini batch of replay data

        Returns:
            loss
        """
        states, actions, rewards, dones, next_states = batch

        state_action_values = self.net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

        with torch.no_grad():
            next_state_values = self.target_net(next_states).max(1)[0]
            next_state_values[dones] = 0.0
            next_state_values = next_state_values.detach()

        expected_state_action_values = next_state_values * self.hparams.gamma + rewards

        return nn.MSELoss()(state_action_values, expected_state_action_values)

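    # dqn_mse_loss above implements the one-step Bellman backup used by DQN:
    #     target = reward + gamma * max_a Q_target(next_state, a)
    # with target = reward on terminal transitions, and it regresses the online
    # network's Q(state, action) towards that target with mean squared error.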
    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict:
        """
        Carries out a single step through the environment to update the replay buffer.
        Then calculates loss based on the minibatch received

        Args:
            batch: current mini batch of replay data
            nb_batch: batch number

        Returns:
            Training loss and log metrics
        """
        device = self.get_device(batch)

        # anneal epsilon linearly from eps_start to eps_end over eps_last_frame steps
        epsilon = max(self.hparams.eps_end,
                      self.hparams.eps_start - (self.global_step + 1) / self.hparams.eps_last_frame)

        # step through environment with agent
        reward, done = self.agent.play_step(self.net, epsilon, device)
        self.episode_reward += reward

        # calculates training loss
        loss = self.dqn_mse_loss(batch)

        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss = loss.unsqueeze(0)

        if done:
            self.total_reward = self.episode_reward
            self.episode_reward = 0

        # update the target network: copy the online network weights every `sync_rate` steps
        if self.global_step % self.hparams.sync_rate == 0:
            self.target_net.load_state_dict(self.net.state_dict())

        log = {'total_reward': torch.tensor(self.total_reward).to(device),
               'reward': torch.tensor(reward).to(device),
               'steps': torch.tensor(self.global_step).to(device)}

        return OrderedDict({'loss': loss, 'log': log, 'progress_bar': log})

    def configure_optimizers(self) -> List[Optimizer]:
        """ Initialize Adam optimizer"""
        optimizer = optim.Adam(self.net.parameters(), lr=self.hparams.lr)
        return [optimizer]

    def __dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = RLDataset(self.buffer, self.hparams.episode_length)
        dataloader = DataLoader(dataset=dataset,
                                batch_size=self.hparams.batch_size,
                                sampler=None
                                )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader"""
        return self.__dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch"""
        return batch[0].device.index if self.on_gpu else 'cpu'


def main(hparams) -> None:
    model = DQNLightning(hparams)

    trainer = pl.Trainer(
        gpus=1,
        distributed_backend='dp',
        early_stop_callback=False,
        val_check_interval=100
    )

    trainer.fit(model)


if __name__ == '__main__':
    torch.manual_seed(0)
    np.random.seed(0)

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag")
    parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
    parser.add_argument("--sync_rate", type=int, default=10,
                        help="how many frames between updates of the target network")
    parser.add_argument("--replay_size", type=int, default=1000,
                        help="capacity of the replay buffer")
    parser.add_argument("--warm_start_size", type=int, default=1000,
                        help="how many samples do we use to fill our buffer at the start of training")
    parser.add_argument("--eps_last_frame", type=int, default=1000,
                        help="at what frame epsilon should stop decaying")
    parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon")
    parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon")
    parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode")
    parser.add_argument("--max_episode_reward", type=int, default=200,
                        help="max episode reward in the environment")
    parser.add_argument("--warm_start_steps", type=int, default=1000,
                        help="how many random steps to run to fill the replay buffer before training starts")

    args = parser.parse_args()

    main(args)
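
As a quick usage sketch (not part of this commit), the LightningModule above can also be driven programmatically instead of through the CLI. The import path and Trainer options below are assumptions: the hyperparameter values simply mirror the argparse defaults, and the Trainer is reduced to a plain single-process run.

from argparse import Namespace

import pytorch_lightning as pl

# assumed import path; adjust to wherever the script lives
from pl_examples.domain_templates.dqn import DQNLightning

hparams = Namespace(
    env="CartPole-v0", lr=1e-2, batch_size=16, gamma=0.99, sync_rate=10,
    replay_size=1000, warm_start_steps=1000, eps_last_frame=1000,
    eps_start=1.0, eps_end=0.01, episode_length=200,
)

model = DQNLightning(hparams)
trainer = pl.Trainer(max_epochs=500)  # CPU run; mirror main() for GPU/dp settings
trainer.fit(model)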

pl_examples/requirements.txt

+2 -1

@@ -1 +1,2 @@
-torchvision>=0.4.0
+torchvision>=0.4.0
+gym>=0.17.0
