|
2 | 2 | "cells": [
|
3 | 3 | {
|
4 | 4 | "cell_type": "code",
|
5 |
| - "execution_count": 59, |
| 5 | + "execution_count": 1, |
6 | 6 | "metadata": {},
|
7 | 7 | "outputs": [],
|
8 | 8 | "source": [
|
|
15 | 15 | },
|
16 | 16 | {
|
17 | 17 | "cell_type": "code",
|
18 |
| - "execution_count": null, |
| 18 | + "execution_count": 4, |
| 19 | + "metadata": {}, |
| 20 | + "outputs": [], |
| 21 | + "source": [ |
| 22 | + "def pairwise(x: torch.Tensor, dim=1):\n", |
| 23 | + " \"\"\"Split the N x A x 2 state vector to return x_ego (N x A x 2) and x_other (N x A x (A - 1) x 4)\n", |
| 24 | + " where each column index in x_other corresponds to -i agents for agent i\"\"\"\n", |
| 25 | + " x_ = x.unsqueeze(dim=dim)\n", |
| 26 | + " x = x.unsqueeze(dim=dim + 1)\n", |
| 27 | + " _, x_b = torch.broadcast_tensors(x, x_)\n", |
| 28 | + " # Create a mask for the diagonal\n", |
| 29 | + " mask = 1 - torch.eye(x.shape[dim], device=x.device).unsqueeze(0).unsqueeze(-1)\n", |
| 30 | + "\n", |
| 31 | + " # Expand the mask to match the input size\n", |
| 32 | + " mask = mask.expand(*x_b.shape)\n", |
| 33 | + "\n", |
| 34 | + " # Apply the mask to the input tensor\n", |
| 35 | + " result = x_b[mask.bool()].view(x.shape[:2] + (x.shape[1] - 1, x.shape[-1]))\n", |
| 36 | + " return x.squeeze(dim=dim + 1), result" |
| 37 | + ] |
| 38 | + }, |
| 39 | + { |
| 40 | + "cell_type": "code", |
| 41 | + "execution_count": 2, |
19 | 42 | "metadata": {},
|
20 | 43 | "outputs": [],
|
21 | 44 | "source": [
|
|
61 | 84 | " self.tasks = np.array(self.tasks)\n",
|
62 | 85 | "\n",
|
63 | 86 | "class BatchedEnvironment:\n",
|
64 |
| - " def __init__(self, environments):\n", |
| 87 | + " def __init__(self, environments, num_envs):\n", |
65 | 88 | " self.maps = torch.stack([torch.tensor(env.map, dtype=torch.float32) for env in environments])\n",
|
66 | 89 | " self.curr_states = torch.stack([torch.tensor(env.agents) for env in environments])\n",
|
67 |
| - " self.goal_locations = torch.stack([torch.tensor(env.tasks) for env in environments])\n" |
| 90 | + " self.goal_locations = torch.stack([torch.tensor(env.tasks) for env in environments])\n", |
| 91 | + "\n", |
| 92 | + " def compute_goal_achieved(self):\n", |
| 93 | + " return torch.lingalg.norm(self.curr_state - self.goal_locations, dim=-1) == 0\n", |
| 94 | + " \n", |
| 95 | + " def compute_collision(self):\n", |
| 96 | + " \"\"\"Construct the agent by agent distance matrix and check for two agents occupying the same cell\"\"\"\n", |
| 97 | + " x_ego, x_other = pairwise(self.curr_states)\n", |
| 98 | + " collisions = torch.norm(x_ego.unsqueeze(dim=-2) - x_other, dim=-1) < 0.01\n", |
| 99 | + " return collisions.any(dim=-1)" |
68 | 100 | ]
|
69 | 101 | },
|
70 | 102 | {
|
71 | 103 | "cell_type": "code",
|
72 |
| - "execution_count": 55, |
| 104 | + "execution_count": 5, |
73 | 105 | "metadata": {},
|
74 | 106 | "outputs": [],
|
75 | 107 | "source": [
|
76 |
| - "p = [\"/home/aarav/aaravpandya/example_problems/city.domain/paris_200.json\", \"/home/aarav/aaravpandya/example_problems/game.domain/brc202d_200.json\"]\n", |
| 108 | + "p = [\"./example_problems/city.domain/paris_200.json\"]\n", |
77 | 109 | "envs = []\n",
|
78 | 110 | "for e in p:\n",
|
79 | 111 | " envs.append(Env(e))"
|
|
107 | 139 | ],
|
108 | 140 | "metadata": {
|
109 | 141 | "kernelspec": {
|
110 |
| - "display_name": "amazon", |
| 142 | + "display_name": "Python 3.10.12 64-bit", |
111 | 143 | "language": "python",
|
112 | 144 | "name": "python3"
|
113 | 145 | },
|
|
121 | 153 | "name": "python",
|
122 | 154 | "nbconvert_exporter": "python",
|
123 | 155 | "pygments_lexer": "ipython3",
|
124 |
| - "version": "3.11.5" |
| 156 | + "version": "3.10.12" |
| 157 | + }, |
| 158 | + "vscode": { |
| 159 | + "interpreter": { |
| 160 | + "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a" |
| 161 | + } |
125 | 162 | }
|
126 | 163 | },
|
127 | 164 | "nbformat": 4,
|
|
0 commit comments