From 575101955b0d1af4704d15579342503b6307bff1 Mon Sep 17 00:00:00 2001 From: vedrenne Date: Thu, 23 May 2024 17:14:57 +0200 Subject: [PATCH 1/5] fix __getitem__ for classes inheriting Transform3d --- pytorch3d/transforms/transform3d.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index b2ee2593..444c86ce 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -198,7 +198,9 @@ def __getitem__( """ if isinstance(index, int): index = [index] - return self.__class__(matrix=self.get_matrix()[index]) + instance = self.__class__.__new__(self.__class__) + instance._matrix = self.get_matrix()[index] + return instance def compose(self, *others: "Transform3d") -> "Transform3d": """ From f1ba05e4f1af7a980554a9db75686feaf641dd43 Mon Sep 17 00:00:00 2001 From: vedrenne Date: Tue, 28 May 2024 15:08:15 +0200 Subject: [PATCH 2/5] fix __getitem__ for classes inheriting Transform3d --- pytorch3d/transforms/transform3d.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index 444c86ce..aa522538 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -200,6 +200,8 @@ def __getitem__( index = [index] instance = self.__class__.__new__(self.__class__) instance._matrix = self.get_matrix()[index] + for attr in ('_transforms', '_lu', 'device', 'dtype'): + setattr(instance, attr, getattr(self, attr)) return instance def compose(self, *others: "Transform3d") -> "Transform3d": From ec0c7826e2a8d5140909b18d1ba0f8378e7be28e Mon Sep 17 00:00:00 2001 From: vedrenne Date: Mon, 10 Jun 2024 14:51:49 +0200 Subject: [PATCH 3/5] fix transforms indexing by implementing __getitem__ for each subclass --- pytorch3d/transforms/transform3d.py | 58 ++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 5 deletions(-) diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index aa522538..fef29845 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -198,11 +198,7 @@ def __getitem__( """ if isinstance(index, int): index = [index] - instance = self.__class__.__new__(self.__class__) - instance._matrix = self.get_matrix()[index] - for attr in ('_transforms', '_lu', 'device', 'dtype'): - setattr(instance, attr, getattr(self, attr)) - return instance + return self.__class__(matrix=self.get_matrix()[index]) def compose(self, *others: "Transform3d") -> "Transform3d": """ @@ -568,6 +564,22 @@ def _get_matrix_inverse(self) -> torch.Tensor: i_matrix = self._matrix * inv_mask return i_matrix + def __getitem__( + self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor] + ) -> "Transform3d": + """ + Args: + index: Specifying the index of the transform to retrieve. + Can be an int, slice, list of ints, boolean, long tensor. + Supports negative indices. + + Returns: + Transform3d object with selected transforms. The tensors are not cloned. + """ + if isinstance(index, int): + index = [index] + return self.__class__(self.get_matrix()[index, 3, :3]) + class Scale(Transform3d): def __init__( @@ -617,6 +629,26 @@ def _get_matrix_inverse(self) -> torch.Tensor: imat = torch.diag_embed(ixyz, dim1=1, dim2=2) return imat + def __getitem__( + self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor] + ) -> "Transform3d": + """ + Args: + index: Specifying the index of the transform to retrieve. + Can be an int, slice, list of ints, boolean, long tensor. + Supports negative indices. + + Returns: + Transform3d object with selected transforms. The tensors are not cloned. + """ + if isinstance(index, int): + index = [index] + mat = self.get_matrix()[index] + x = mat[:, 0, 0] + y = mat[:, 1, 1] + z = mat[:, 2, 2] + return self.__class__(x, y, z) + class Rotate(Transform3d): def __init__( @@ -659,6 +691,22 @@ def _get_matrix_inverse(self) -> torch.Tensor: """ return self._matrix.permute(0, 2, 1).contiguous() + def __getitem__( + self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor] + ) -> "Transform3d": + """ + Args: + index: Specifying the index of the transform to retrieve. + Can be an int, slice, list of ints, boolean, long tensor. + Supports negative indices. + + Returns: + Transform3d object with selected transforms. The tensors are not cloned. + """ + if isinstance(index, int): + index = [index] + return self.__class__(self.get_matrix()[index, :3, :3]) + class RotateAxisAngle(Rotate): def __init__( From 23d3cb50bac9ad37475ba38096d680f8346abd5b Mon Sep 17 00:00:00 2001 From: vedrenne Date: Tue, 11 Jun 2024 09:11:32 +0200 Subject: [PATCH 4/5] add getitem tests for Transform3d subclasses --- tests/test_transforms.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 5a2d729f..6851afbf 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -685,6 +685,15 @@ def test_inverse(self): self.assertTrue(torch.allclose(im, im_comp)) self.assertTrue(torch.allclose(im, im_2)) + def test_get_item(self, batch_size=5): + device = torch.device("cuda:0") + xyz = torch.randn(size=[batch_size, 3], device=device, dtype=torch.float32) + t3d = Translate(xyz) + index = 1 + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), 1) + self.assertIsInstance(t3d_selected, Translate) + class TestScale(unittest.TestCase): def test_single_python_scalar(self): @@ -871,6 +880,15 @@ def test_inverse(self): self.assertTrue(torch.allclose(im, im_comp)) self.assertTrue(torch.allclose(im, im_2)) + def test_get_item(self, batch_size=5): + device = torch.device("cuda:0") + s = torch.randn(size=[batch_size, 3], device=device, dtype=torch.float32) + t3d = Scale(s) + index = 1 + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), 1) + self.assertIsInstance(t3d_selected, Scale) + class TestTransformBroadcast(unittest.TestCase): def test_broadcast_transform_points(self): @@ -986,6 +1004,15 @@ def test_inverse(self, batch_size=5): self.assertTrue(torch.allclose(im, im_comp, atol=1e-4)) self.assertTrue(torch.allclose(im, im_2, atol=1e-4)) + def test_get_item(self, batch_size=5): + device = torch.device("cuda:0") + r = random_rotations(batch_size, dtype=torch.float32, device=device) + t3d = Rotate(r) + index = 1 + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), 1) + self.assertIsInstance(t3d_selected, Rotate) + class TestRotateAxisAngle(unittest.TestCase): def test_rotate_x_python_scalar(self): From 2d78e182a5bb9fd0e02bcaab563504ff8b7fe021 Mon Sep 17 00:00:00 2001 From: vedrenne Date: Fri, 13 Sep 2024 17:03:13 +0200 Subject: [PATCH 5/5] add trimming option in icp --- pytorch3d/ops/points_alignment.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/pytorch3d/ops/points_alignment.py b/pytorch3d/ops/points_alignment.py index 96e4b410..786795ca 100644 --- a/pytorch3d/ops/points_alignment.py +++ b/pytorch3d/ops/points_alignment.py @@ -39,6 +39,7 @@ def iterative_closest_point( X: Union[torch.Tensor, "Pointclouds"], Y: Union[torch.Tensor, "Pointclouds"], init_transform: Optional[SimilarityTransform] = None, + trim_fraction: Union[float, torch.Tensor] = 0., max_iterations: int = 100, relative_rmse_thr: float = 1e-6, estimate_scale: bool = False, @@ -67,6 +68,11 @@ def iterative_closest_point( shape `(minibatch, d, d)`, `T` is a batch of translations of shape `(minibatch, d)` and `s` is a batch of scaling factors of shape `(minibatch,)`. + **trim_fraction**: A float or 1d `Tensor` of shape `(minibatch,)` in [0, 1] + specifying the ratio of outliers in each point cloud. If float, assume + the same outliers ratio for all point clouds in the batch. Outliers will + be detected by taking the `trim_fraction * num_points_X` highest values of + `s[i] X[i] R[i] + T[i] = Y[NN[i]]`. **max_iterations**: The maximum number of ICP iterations. **relative_rmse_thr**: A threshold on the relative root mean squared error used to terminate the algorithm. @@ -152,6 +158,17 @@ def iterative_closest_point( T = Xt.new_zeros((b, dim)) s = Xt.new_ones(b) + # initialize trim fraction parameter + if isinstance(trim_fraction, float): + trim_fraction = torch.as_tensor(trim_fraction) + trim_fraction = trim_fraction.to(Xt.device) # type: ignore + if trim_fraction.ndim == 0: + trim_fraction = trim_fraction.repeat(b) + trim = trim_fraction.min() > 0.0 + + # initial mask: no trim considered, only padding + mask = mask_X.bool().clone() + prev_rmse = None rmse = None iteration = -1 @@ -170,7 +187,7 @@ def iterative_closest_point( R, T, s = corresponding_points_alignment( Xt_init, Xt_nn_points, - weights=mask_X, + weights=mask, estimate_scale=estimate_scale, allow_reflection=allow_reflection, ) @@ -184,7 +201,15 @@ def iterative_closest_point( # compute the root mean squared error # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2) - rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0] + + # trimming: select `1 - trim_fraction` lowest distances. + if trim: + diff_thresholds = Xt_sq_diff[mask_X.bool()].quantile(1 - trim_fraction) + mask_trim = Xt_sq_diff < diff_thresholds[:, None] + # final mask is (trim_mask AND pad_mask) + mask = torch.logical_and(mask_trim, mask_X) + + rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask).sqrt()[:, 0, 0] # compute the relative rmse if prev_rmse is None: