| 1 | +""" |
| 2 | +# Deep Reinforcement Learning: Deep Q-network (DQN) |
| 3 | +
|
| 4 | +this example is based off https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On- |
| 5 | +Second-Edition/blob/master/Chapter06/02_dqn_pong.py |
| 6 | +
|
| 7 | +The template illustrates using Lightning for Reinforcement Learning. The example builds a basic DQN using the |
| 8 | +classic CartPole environment. |
| 9 | +
|
| 10 | +to run the template just run: |
| 11 | +python dqn.py |
| 12 | +
|
| 13 | +After ~1500 steps, you will see the total_reward hitting the max score of 200. Open up tensor boards to |
| 14 | +see the metrics. |
| 15 | +
|
| 16 | +tensorboard --logdir default |
| 17 | +""" |

import argparse
from collections import OrderedDict, deque, namedtuple
from typing import Iterator, List, Tuple

import gym
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset


class DQN(nn.Module):
    """
    Simple MLP network

    Args:
        obs_size: observation/state size of the environment
        n_actions: number of discrete actions available in the environment
        hidden_size: size of hidden layers
    """

    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions)
        )

    def forward(self, x):
        return self.net(x.float())


# Named tuple for storing experience steps gathered in training
Experience = namedtuple(
    'Experience', field_names=['state', 'action', 'reward',
                               'done', 'new_state'])


class ReplayBuffer:
    """
    Replay Buffer for storing past experiences, allowing the agent to learn from them

    Args:
        capacity: size of the buffer
    """

    def __init__(self, capacity: int) -> None:
        self.buffer = deque(maxlen=capacity)

    def __len__(self) -> int:
        return len(self.buffer)

    def append(self, experience: Experience) -> None:
        """
        Add experience to the buffer

        Args:
            experience: tuple (state, action, reward, done, new_state)
        """
        self.buffer.append(experience)

    def sample(self, batch_size: int) -> Tuple:
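        # uniformly sample a batch of stored transitions without replacement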
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])

        return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32),
                np.array(dones, dtype=bool), np.array(next_states))


class RLDataset(IterableDataset):
    """
    Iterable Dataset containing the ExperienceBuffer
    which will be updated with new experiences during training

    Args:
        buffer: replay buffer
        sample_size: number of experiences to sample at a time
    """

    def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self) -> Iterator[Tuple]:
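        # each pass over the dataset draws a fresh sample from the replay buffer,
        # so the data seen by the DataLoader changes as training adds new experiences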
        states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
        for i in range(len(dones)):
            yield states[i], actions[i], rewards[i], dones[i], new_states[i]


class Agent:
    """
    Base Agent class handling the interaction with the environment

    Args:
        env: training environment
        replay_buffer: replay buffer storing experiences
    """

    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
        self.env = env
        self.replay_buffer = replay_buffer
        self.reset()

    def reset(self) -> None:
        """ Resets the environment and updates the state"""
        self.state = self.env.reset()

    def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
        """
        Using the given network, decide what action to carry out
        using an epsilon-greedy policy

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device

        Returns:
            action
        """
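        # explore with probability epsilon, otherwise act greedily w.r.t. the current Q-values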
        if np.random.random() < epsilon:
            action = self.env.action_space.sample()
        else:
            state = torch.tensor([self.state])

            if device not in ['cpu']:
                state = state.cuda(device)

            q_values = net(state)
            _, action = torch.max(q_values, dim=1)
            action = int(action.item())

        return action

    @torch.no_grad()
    def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') -> Tuple[float, bool]:
        """
        Carries out a single interaction step between the agent and the environment

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device

        Returns:
            reward, done
        """

        action = self.get_action(net, epsilon, device)

        # do step in the environment
        new_state, reward, done, _ = self.env.step(action)

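        # store the transition in the replay buffer so it can be sampled for later updates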
        exp = Experience(self.state, action, reward, done, new_state)

        self.replay_buffer.append(exp)

        self.state = new_state
        if done:
            self.reset()
        return reward, done


class DQNLightning(pl.LightningModule):
    """ Basic DQN Model """

    def __init__(self, hparams: argparse.Namespace) -> None:
        super().__init__()
        self.hparams = hparams

        self.env = gym.make(self.hparams.env)
        obs_size = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n

        self.net = DQN(obs_size, n_actions)
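        # target network: a periodically synced copy of the online network, used to compute
        # stable bootstrap targets in dqn_mse_loss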
        self.target_net = DQN(obs_size, n_actions)

        self.buffer = ReplayBuffer(self.hparams.replay_size)
        self.agent = Agent(self.env, self.buffer)
        self.total_reward = 0
        self.episode_reward = 0
        self.populate(self.hparams.warm_start_steps)

    def populate(self, steps: int = 1000) -> None:
        """
        Carries out several random steps through the environment to initially fill
        up the replay buffer with experiences

        Args:
            steps: number of random steps to populate the buffer with
        """
        for _ in range(steps):
            self.agent.play_step(self.net, epsilon=1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Passes in a state x through the network and gets the q_values of each action as an output

        Args:
            x: environment state

        Returns:
            q values
        """
        output = self.net(x)
        return output

    def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """
        Calculates the MSE loss using a mini batch from the replay buffer

        Args:
            batch: current mini batch of replay data

        Returns:
            loss
        """
        states, actions, rewards, dones, next_states = batch

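        # Standard DQN loss: compare Q(s, a) from the online network against the Bellman
        # target r + gamma * max_a' Q_target(s', a') computed with the frozen target network
        # (terminal transitions contribute only their reward)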
        state_action_values = self.net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

        with torch.no_grad():
            next_state_values = self.target_net(next_states).max(1)[0]
            next_state_values[dones] = 0.0
            next_state_values = next_state_values.detach()

        expected_state_action_values = next_state_values * self.hparams.gamma + rewards

        return nn.MSELoss()(state_action_values, expected_state_action_values)

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict:
        """
        Carries out a single step through the environment to update the replay buffer.
        Then calculates loss based on the minibatch received

        Args:
            batch: current mini batch of replay data
            nb_batch: batch number

        Returns:
            Training loss and log metrics
        """
        device = self.get_device(batch)
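        # linearly anneal epsilon from eps_start down to eps_end over eps_last_frame steps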
        epsilon = max(self.hparams.eps_end,
                      self.hparams.eps_start - (self.global_step + 1) / self.hparams.eps_last_frame)

        # step through environment with agent
        reward, done = self.agent.play_step(self.net, epsilon, device)
        self.episode_reward += reward

        # calculates training loss
        loss = self.dqn_mse_loss(batch)

        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss = loss.unsqueeze(0)

        if done:
            self.total_reward = self.episode_reward
            self.episode_reward = 0

        # Hard update of target network: copy the online weights every sync_rate steps
        if self.global_step % self.hparams.sync_rate == 0:
            self.target_net.load_state_dict(self.net.state_dict())

        log = {'total_reward': torch.tensor(self.total_reward).to(device),
               'reward': torch.tensor(reward).to(device),
               'steps': torch.tensor(self.global_step).to(device)}

        return OrderedDict({'loss': loss, 'log': log, 'progress_bar': log})

    def configure_optimizers(self) -> List[Optimizer]:
        """ Initialize Adam optimizer"""
        optimizer = optim.Adam(self.net.parameters(), lr=self.hparams.lr)
        return [optimizer]

    def __dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = RLDataset(self.buffer, self.hparams.episode_length)
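        # with an IterableDataset the DataLoader simply chunks the streamed experiences
        # into batches, so no sampler or shuffling is needed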
        dataloader = DataLoader(dataset=dataset,
                                batch_size=self.hparams.batch_size,
                                sampler=None
                                )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader"""
        return self.__dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch"""
        return batch[0].device.index if self.on_gpu else 'cpu'


def main(hparams) -> None:
    model = DQNLightning(hparams)

    trainer = pl.Trainer(
        gpus=1,
        distributed_backend='dp',
        early_stop_callback=False,
        val_check_interval=100
    )

    trainer.fit(model)


if __name__ == '__main__':
    torch.manual_seed(0)
    np.random.seed(0)

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag")
    parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
    parser.add_argument("--sync_rate", type=int, default=10,
                        help="how many frames between syncing the target network")
    parser.add_argument("--replay_size", type=int, default=1000,
                        help="capacity of the replay buffer")
    parser.add_argument("--warm_start_size", type=int, default=1000,
                        help="how many samples do we use to fill our buffer at the start of training")
    parser.add_argument("--eps_last_frame", type=int, default=1000,
                        help="what frame should epsilon stop decaying")
    parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon")
    parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon")
    parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode")
    parser.add_argument("--max_episode_reward", type=int, default=200,
                        help="max episode reward in the environment")
    parser.add_argument("--warm_start_steps", type=int, default=1000,
                        help="how many random steps to take to fill the buffer before training")

    args = parser.parse_args()

    main(args)