
Commit 0fa9ce8 (1 parent: 82b0223)

FromVec broadcasting, increase torch version

File tree: 3 files changed (+4, -4 lines)


Diff for: README.md (+1, -1)

````diff
@@ -25,7 +25,7 @@ Zachary Teed and Jia Deng, CVPR 2021
 
 ### Requirements:
 * Cuda >= 10.1 (with nvcc compiler)
-* PyTorch >= 1.7
+* PyTorch >= 1.8
 
 We recommend installing within a virtual enviornment. Make sure you clone using the `--recursive` flag. If you are using Anaconda, the following command can be used to install all dependencies
 ```
````

Diff for: lietorch/group_ops.py (+2, -2)

```diff
@@ -83,7 +83,7 @@ def forward(cls, ctx, group_id, *inputs):
     def backward(cls, ctx, grad):
         inputs = ctx.saved_tensors
         J = lietorch_backends.projector(ctx.group_id, *inputs)
-        return None, torch.matmul(grad.unsqueeze(-2), torch.linalg.pinv(J))
+        return None, torch.matmul(grad.unsqueeze(-2), torch.linalg.pinv(J)).squeeze(-2)
 
 class ToVec(torch.autograd.Function):
     """ convert group object to vector """
@@ -98,5 +98,5 @@ def forward(cls, ctx, group_id, *inputs):
     def backward(cls, ctx, grad):
         inputs = ctx.saved_tensors
         J = lietorch_backends.projector(ctx.group_id, *inputs)
-        return None, torch.matmul(grad.unsqueeze(-2), J)
+        return None, torch.matmul(grad.unsqueeze(-2), J).squeeze(-2)
```
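The commit appends `.squeeze(-2)` to both backward passes. A minimal numpy sketch of the shape logic, with purely hypothetical batch sizes and dimensions (lietorch's actual `projector` is not used here): unsqueezing the incoming gradient to a row vector makes the vector-Jacobian product a batched matmul, but it also leaves a singleton dimension that must be squeezed away so the returned gradient matches the shape of a batched input.

```python
import numpy as np

# Hypothetical shapes standing in for a batched Lie-group gradient:
# grad has shape (*batch, k); the Jacobian J has shape (*batch, k, n).
batch, k, n = (1, 2), 7, 6

rng = np.random.default_rng(0)
grad = rng.standard_normal((*batch, k))
J = rng.standard_normal((*batch, k, n))

# Vector-Jacobian product via batched matmul: grad is first
# unsqueezed to a row vector of shape (*batch, 1, k).
vjp = np.matmul(grad[..., None, :], J)
print(vjp.shape)            # (1, 2, 1, 6): an extra singleton dim remains

# Squeezing that dim (the commit's fix) yields a gradient whose
# shape (*batch, n) lines up with a batched input.
vjp = vjp.squeeze(-2)
print(vjp.shape)            # (1, 2, 6)
```

Without the squeeze, the gradient returned for a batched input carries a stray size-1 axis, which is exactly what the broadened `gradcheck` below would catch.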

Diff for: lietorch/run_tests.py (+1, -1)

```diff
@@ -218,7 +218,7 @@ def fn(a):
         return Group.InitFromVec(a).vec()
 
     D = Group.embedded_dim
-    a = torch.randn(1, D, requires_grad=True, device=device).double()
+    a = torch.randn(1, 2, D, requires_grad=True, device=device).double()
 
     analytical, numerical = gradcheck(fn, [a], eps=1e-4)
```
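The test now feeds `gradcheck` a batched input of shape `(1, 2, D)` instead of `(1, D)`, so the broadcasting fix is actually exercised. As a self-contained sketch of what such a check does, here is a central finite-difference gradient compared against an analytical one on a batched array; `fn` is a trivial stand-in (not lietorch's `Group.InitFromVec(a).vec()`), and `D = 7` is an arbitrary illustrative size.

```python
import numpy as np

def fn(a):
    # Stand-in scalar-valued function; gradient is 2 * a.
    return (a ** 2).sum()

def numerical_grad(f, a, eps=1e-4):
    # Central finite differences, perturbing one element at a time.
    g = np.zeros_like(a)
    it = np.nditer(a, flags=["multi_index"])
    for _ in it:
        i = it.multi_index
        ap, am = a.copy(), a.copy()
        ap[i] += eps
        am[i] -= eps
        g[i] = (f(ap) - f(am)) / (2 * eps)
    return g

D = 7                              # hypothetical embedded_dim
a = np.random.randn(1, 2, D)       # batched, mirroring the new test input
analytical = 2 * a                 # exact gradient of sum(a**2)
assert np.allclose(numerical_grad(fn, a), analytical, atol=1e-5)
```

A backward pass that returned a gradient with a leftover singleton axis (the pre-fix behavior) would fail this kind of shape-sensitive comparison on batched inputs.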
