Skip to content

Commit fbb9689

Browse files
add goal and collision code
1 parent a44c44a commit fbb9689

File tree

1 file changed

+45
-8
lines changed

1 file changed

+45
-8
lines changed

index.ipynb

+45-8
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 59,
5+
"execution_count": 1,
66
"metadata": {},
77
"outputs": [],
88
"source": [
@@ -15,7 +15,30 @@
1515
},
1616
{
1717
"cell_type": "code",
18-
"execution_count": null,
18+
"execution_count": 4,
19+
"metadata": {},
20+
"outputs": [],
21+
"source": [
22+
"def pairwise(x: torch.Tensor, dim=1):\n",
23+
" \"\"\"Split the N x A x 2 state vector to return x_ego (N x A x 2) and x_other (N x A x (A - 1) x 4)\n",
24+
" where each column index in x_other corresponds to -i agents for agent i\"\"\"\n",
25+
" x_ = x.unsqueeze(dim=dim)\n",
26+
" x = x.unsqueeze(dim=dim + 1)\n",
27+
" _, x_b = torch.broadcast_tensors(x, x_)\n",
28+
" # Create a mask for the diagonal\n",
29+
" mask = 1 - torch.eye(x.shape[dim], device=x.device).unsqueeze(0).unsqueeze(-1)\n",
30+
"\n",
31+
" # Expand the mask to match the input size\n",
32+
" mask = mask.expand(*x_b.shape)\n",
33+
"\n",
34+
" # Apply the mask to the input tensor\n",
35+
" result = x_b[mask.bool()].view(x.shape[:2] + (x.shape[1] - 1, x.shape[-1]))\n",
36+
" return x.squeeze(dim=dim + 1), result"
37+
]
38+
},
39+
{
40+
"cell_type": "code",
41+
"execution_count": 2,
1942
"metadata": {},
2043
"outputs": [],
2144
"source": [
@@ -61,19 +84,28 @@
6184
" self.tasks = np.array(self.tasks)\n",
6285
"\n",
6386
"class BatchedEnvironment:\n",
64-
" def __init__(self, environments):\n",
87+
" def __init__(self, environments, num_envs):\n",
6588
" self.maps = torch.stack([torch.tensor(env.map, dtype=torch.float32) for env in environments])\n",
6689
" self.curr_states = torch.stack([torch.tensor(env.agents) for env in environments])\n",
67-
" self.goal_locations = torch.stack([torch.tensor(env.tasks) for env in environments])\n"
90+
" self.goal_locations = torch.stack([torch.tensor(env.tasks) for env in environments])\n",
91+
"\n",
92+
" def compute_goal_achieved(self):\n",
93+
" return torch.lingalg.norm(self.curr_state - self.goal_locations, dim=-1) == 0\n",
94+
" \n",
95+
" def compute_collision(self):\n",
96+
" \"\"\"Construct the agent by agent distance matrix and check for two agents occupying the same cell\"\"\"\n",
97+
" x_ego, x_other = pairwise(self.curr_states)\n",
98+
" collisions = torch.norm(x_ego.unsqueeze(dim=-2) - x_other, dim=-1) < 0.01\n",
99+
" return collisions.any(dim=-1)"
68100
]
69101
},
70102
{
71103
"cell_type": "code",
72-
"execution_count": 55,
104+
"execution_count": 5,
73105
"metadata": {},
74106
"outputs": [],
75107
"source": [
76-
"p = [\"/home/aarav/aaravpandya/example_problems/city.domain/paris_200.json\", \"/home/aarav/aaravpandya/example_problems/game.domain/brc202d_200.json\"]\n",
108+
"p = [\"./example_problems/city.domain/paris_200.json\"]\n",
77109
"envs = []\n",
78110
"for e in p:\n",
79111
" envs.append(Env(e))"
@@ -107,7 +139,7 @@
107139
],
108140
"metadata": {
109141
"kernelspec": {
110-
"display_name": "amazon",
142+
"display_name": "Python 3.10.12 64-bit",
111143
"language": "python",
112144
"name": "python3"
113145
},
@@ -121,7 +153,12 @@
121153
"name": "python",
122154
"nbconvert_exporter": "python",
123155
"pygments_lexer": "ipython3",
124-
"version": "3.11.5"
156+
"version": "3.10.12"
157+
},
158+
"vscode": {
159+
"interpreter": {
160+
"hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a"
161+
}
125162
}
126163
},
127164
"nbformat": 4,

0 commit comments

Comments
 (0)