From f217eb1fcc2f3ee561aacab161488b2142aafb9a Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 14 Mar 2023 17:55:25 +0000 Subject: [PATCH 1/9] moved crop_by_bbox to FrameData as method --- pytorch3d/implicitron/dataset/blob_loader.py | 181 ++---------------- pytorch3d/implicitron/dataset/dataset_base.py | 32 ++++ pytorch3d/implicitron/dataset/utils.py | 101 ++++++++++ tests/implicitron/test_bbox.py | 5 +- 4 files changed, 152 insertions(+), 167 deletions(-) diff --git a/pytorch3d/implicitron/dataset/blob_loader.py b/pytorch3d/implicitron/dataset/blob_loader.py index 6d0dc7fa4..ce59c542d 100644 --- a/pytorch3d/implicitron/dataset/blob_loader.py +++ b/pytorch3d/implicitron/dataset/blob_loader.py @@ -20,6 +20,9 @@ from pytorch3d.io import IO from pytorch3d.renderer.cameras import PerspectiveCameras from pytorch3d.structures.pointclouds import Pointclouds +from pytorch3d.implicitron.dataset.utils import ( + _get_bbox_from_mask, +) @dataclass @@ -85,9 +88,7 @@ def load_( ( frame_data.fg_probability, frame_data.mask_path, - frame_data.bbox_xywh, - clamp_bbox_xyxy, - frame_data.crop_bbox_xywh, + bbox_xywh, ) = self._load_crop_fg_probability(entry) scale = min( @@ -103,23 +104,17 @@ def load_( frame_data.image_path, frame_data.mask_crop, scale, - ) = self._load_crop_images( - entry, frame_data.fg_probability, clamp_bbox_xyxy - ) + ) = self._load_crop_images(entry, frame_data.fg_probability) if self.load_depths and entry.depth is not None: ( frame_data.depth_map, frame_data.depth_path, frame_data.depth_mask, - ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability) + ) = self._load_mask_depth(entry, frame_data.fg_probability) if entry.viewpoint is not None: - frame_data.camera = self._get_pytorch3d_camera( - entry, - scale, - clamp_bbox_xyxy, - ) + frame_data.camera = self._get_pytorch3d_camera(entry, scale) if self.load_point_clouds and seq_annotation.point_cloud is not None: pcl_path = self._fix_point_cloud_path(seq_annotation.point_cloud.path) @@ -128,45 +123,28 @@ def load_( ) frame_data.sequence_point_cloud_path = pcl_path + if self.box_crop: + frame_data.crop_by_bbox(bbox_xywh, self.box_crop_context, ) + + return frame_data + def _load_crop_fg_probability( self, entry: types.FrameAnnotation - ) -> Tuple[ - Optional[torch.Tensor], - Optional[str], - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], - ]: + ) -> Tuple[Optional[torch.Tensor],Optional[str],Optional[torch.Tensor]]: fg_probability = None full_path = None bbox_xywh = None - clamp_bbox_xyxy = None - crop_box_xywh = None - if (self.load_masks or self.box_crop) and entry.mask is not None: + if (self.load_masks) and entry.mask is not None: full_path = os.path.join(self.dataset_root, entry.mask.path) mask = _load_mask(self._local_path(full_path)) + bbox_xywh = torch.tensor(_get_bbox_from_mask(self.mask, self.box_crop_mask_thr)) if mask.shape[-2:] != entry.image.size: raise ValueError( f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!" 
) - bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr)) - - if self.box_crop: - clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round( - _get_clamp_bbox( - bbox_xywh, - image_path=entry.image.path, - box_crop_context=self.box_crop_context, - ), - image_size_hw=tuple(mask.shape[-2:]), - ) - crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy) - - mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path) - fg_probability, _, _ = _resize_image( mask, image_height=self.image_height, @@ -174,13 +152,12 @@ def _load_crop_fg_probability( mode="nearest", ) - return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh + return fg_probability, full_path, bbox_xywh def _load_crop_images( self, entry: types.FrameAnnotation, fg_probability: Optional[torch.Tensor], - clamp_bbox_xyxy: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, str, torch.Tensor, float]: assert self.dataset_root is not None and entry.image is not None path = os.path.join(self.dataset_root, entry.image.path) @@ -191,10 +168,6 @@ def _load_crop_images( f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!" ) - if self.box_crop: - assert clamp_bbox_xyxy is not None - image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path) - image_rgb, scale, mask_crop = _resize_image( image_rgb, image_height=self.image_height, image_width=self.image_width ) @@ -208,7 +181,6 @@ def _load_crop_images( def _load_mask_depth( self, entry: types.FrameAnnotation, - clamp_bbox_xyxy: Optional[torch.Tensor], fg_probability: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, str, torch.Tensor]: entry_depth = entry.depth @@ -216,13 +188,6 @@ def _load_mask_depth( path = os.path.join(self.dataset_root, entry_depth.path) depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment) - if self.box_crop: - assert clamp_bbox_xyxy is not None - depth_bbox_xyxy = _rescale_bbox( - clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:] - ) - depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path) - depth_map, _, _ = _resize_image( depth_map, image_height=self.image_height, @@ -239,15 +204,6 @@ def _load_mask_depth( mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) depth_mask = _load_depth_mask(self._local_path(mask_path)) - if self.box_crop: - assert clamp_bbox_xyxy is not None - depth_mask_bbox_xyxy = _rescale_bbox( - clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:] - ) - depth_mask = _crop_around_box( - depth_mask, depth_mask_bbox_xyxy, mask_path - ) - depth_mask, _, _ = _resize_image( depth_mask, image_height=self.image_height, @@ -263,7 +219,6 @@ def _get_pytorch3d_camera( self, entry: types.FrameAnnotation, scale: float, - clamp_bbox_xyxy: Optional[torch.Tensor], ) -> PerspectiveCameras: entry_viewpoint = entry.viewpoint assert entry_viewpoint is not None @@ -290,9 +245,6 @@ def _get_pytorch3d_camera( # principal point and focal length in pixels principal_point_px = half_image_size_wh_orig - principal_point * rescale focal_length_px = focal_length * rescale - if self.box_crop: - assert clamp_bbox_xyxy is not None - principal_point_px -= clamp_bbox_xyxy[:2] # now, convert from pixels to PyTorch3D v0.5+ NDC convention if self.image_height is None or self.image_width is None: @@ -375,84 +327,6 @@ def _load_mask(path) -> np.ndarray: return mask[None] # fake feature channel -def _get_bbox_from_mask( - mask, thr, decrease_quant: float = 0.05 -) -> Tuple[int, int, int, int]: - # bbox in xywh - masks_for_box = np.zeros_like(mask) - while masks_for_box.sum() <= 1.0: - 
masks_for_box = (mask > thr).astype(np.float32) - thr -= decrease_quant - if thr <= 0.0: - warnings.warn( - f"Empty masks_for_bbox (thr={thr}) => using full image.", stacklevel=1 - ) - - x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2)) - y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1)) - - return x0, y0, x1 - x0, y1 - y0 - - -def _crop_around_box(tensor, bbox, impath: str = ""): - # bbox is xyxy, where the upper bound is corrected with +1 - bbox = _clamp_box_to_image_bounds_and_round( - bbox, - image_size_hw=tensor.shape[-2:], - ) - tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]] - assert all(c > 0 for c in tensor.shape), f"squashed image {impath}" - return tensor - - -def _clamp_box_to_image_bounds_and_round( - bbox_xyxy: torch.Tensor, - image_size_hw: Tuple[int, int], -) -> torch.LongTensor: - bbox_xyxy = bbox_xyxy.clone() - bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1]) - bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2]) - if not isinstance(bbox_xyxy, torch.LongTensor): - bbox_xyxy = bbox_xyxy.round().long() - return bbox_xyxy # pyre-ignore [7] - - -def _get_clamp_bbox( - bbox: torch.Tensor, - box_crop_context: float = 0.0, - image_path: str = "", -) -> torch.Tensor: - # box_crop_context: rate of expansion for bbox - # returns possibly expanded bbox xyxy as float - - bbox = bbox.clone() # do not edit bbox in place - - # increase box size - if box_crop_context > 0.0: - c = box_crop_context - bbox = bbox.float() - bbox[0] -= bbox[2] * c / 2 - bbox[1] -= bbox[3] * c / 2 - bbox[2] += bbox[2] * c - bbox[3] += bbox[3] * c - - if (bbox[2:] <= 1.0).any(): - raise ValueError( - f"squashed image {image_path}!! The bounding box contains no pixels." - ) - - bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes - bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2) - - return bbox_xyxy - - -def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor: - wh = xyxy[2:] - xyxy[:2] - xywh = torch.cat([xyxy[:2], wh]) - return xywh - - def _load_depth(path, scale_adjustment) -> np.ndarray: if not path.lower().endswith(".png"): raise ValueError('unsupported depth file name "%s"' % path) @@ -474,14 +348,6 @@ def _load_16big_png_depth(depth_png) -> np.ndarray: return depth -def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor: - assert bbox is not None - assert np.prod(orig_res) > 1e-8 - # average ratio of dimensions - rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0 - return bbox * rel_size - - def _load_1bit_png_mask(file: str) -> np.ndarray: with Image.open(file) as pil_im: mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32) @@ -495,21 +361,6 @@ def _load_depth_mask(path: str) -> np.ndarray: return m[None] # fake feature channel -def _get_1d_bounds(arr) -> Tuple[int, int]: - nz = np.flatnonzero(arr) - return nz[0], nz[-1] + 1 - - -def _bbox_xywh_to_xyxy( - xywh: torch.Tensor, clamp_size: Optional[int] = None -) -> torch.Tensor: - xyxy = xywh.clone() - if clamp_size is not None: - xyxy[2:] = torch.clamp(xyxy[2:], clamp_size) - xyxy[2:] += xyxy[:2] - return xyxy - - def _safe_as_tensor(data, dtype): return torch.tensor(data, dtype=dtype) if data is not None else None diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index 283ef3dcd..322d1889b 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -26,6 +26,13 @@ from pytorch3d.renderer.camera_utils 
import join_cameras_as_batch from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds +from pytorch3d.implicitron.dataset.utils import ( + _crop_around_box, + _clamp_box_to_image_bounds_and_round, + _bbox_xyxy_to_xywh, + _get_clamp_bbox, + _rescale_bbox, +) @dataclass @@ -144,6 +151,31 @@ def __getitem__(self, key): def __len__(self): return len(fields(self)) + def crop_by_bbox(self, bbox_xywh, box_crop_context): + clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round( + _get_clamp_bbox( + bbox_xywh, + image_path=self.image.path, + box_crop_context=box_crop_context, + ), + image_size_hw=tuple(self.fg_probability.shape[-2:]), + ) + self.crop_bbox_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy) + + self.fg_probability = _crop_around_box( + self.fg_probability, clamp_bbox_xyxy, self.mask_path + ) + self.image_rgb = _crop_around_box(self.image_rgb, clamp_bbox_xyxy, self.image.path) + + depth_bbox_xyxy = _rescale_bbox(clamp_bbox_xyxy, entry.image.size, self.depth_map.shape[-2:]) + self.depth_map = _crop_around_box(self.depth_map, depth_bbox_xyxy, self.depth_path) + + depth_mask_bbox_xyxy = _rescale_bbox(clamp_bbox_xyxy, entry.image.size, self.depth_mask.shape[-2:]) + self.depth_mask = _crop_around_box(self.depth_mask, depth_mask_bbox_xyxy, self.mask_path) + + + principal_point_px -= clamp_bbox_xyxy[:2] + @classmethod def collate(cls, batch): """ diff --git a/pytorch3d/implicitron/dataset/utils.py b/pytorch3d/implicitron/dataset/utils.py index 05252aff1..b2ac99f36 100644 --- a/pytorch3d/implicitron/dataset/utils.py +++ b/pytorch3d/implicitron/dataset/utils.py @@ -52,3 +52,104 @@ def is_train_frame( dtype=torch.bool, device=device, ) + + +def _get_bbox_from_mask( + mask, thr, decrease_quant: float = 0.05 + ) -> Tuple[int, int, int, int]: + # bbox in xywh + masks_for_box = np.zeros_like(mask) + while masks_for_box.sum() <= 1.0: + masks_for_box = (mask > thr).astype(np.float32) + thr -= decrease_quant + if thr <= 0.0: + warnings.warn( + f"Empty masks_for_bbox (thr={thr}) => using full image.", stacklevel=1 + ) + + x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2)) + y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1)) + + return x0, y0, x1 - x0, y1 - y0 + + +def _crop_around_box(tensor, bbox, impath: str = ""): + # bbox is xyxy, where the upper bound is corrected with +1 + bbox = _clamp_box_to_image_bounds_and_round( + bbox, + image_size_hw=tensor.shape[-2:], + ) + tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]] + assert all(c > 0 for c in tensor.shape), f"squashed image {impath}" + return tensor + + +def _clamp_box_to_image_bounds_and_round( + bbox_xyxy: torch.Tensor, + image_size_hw: Tuple[int, int], +) -> torch.LongTensor: + bbox_xyxy = bbox_xyxy.clone() + bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1]) + bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2]) + if not isinstance(bbox_xyxy, torch.LongTensor): + bbox_xyxy = bbox_xyxy.round().long() + return bbox_xyxy # pyre-ignore [7] + + +def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor: + wh = xyxy[2:] - xyxy[:2] + xywh = torch.cat([xyxy[:2], wh]) + return xywh + + +def _get_clamp_bbox( + bbox: torch.Tensor, + box_crop_context: float = 0.0, + image_path: str = "", +) -> torch.Tensor: + # box_crop_context: rate of expansion for bbox + # returns possibly expanded bbox xyxy as float + + bbox = bbox.clone() # do not edit bbox in place + + # increase box size + if box_crop_context > 0.0: + c = 
box_crop_context + bbox = bbox.float() + bbox[0] -= bbox[2] * c / 2 + bbox[1] -= bbox[3] * c / 2 + bbox[2] += bbox[2] * c + bbox[3] += bbox[3] * c + + if (bbox[2:] <= 1.0).any(): + raise ValueError( + f"squashed image {image_path}!! The bounding box contains no pixels." + ) + + bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes + bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2) + + return bbox_xyxy + + +def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor: + assert bbox is not None + assert np.prod(orig_res) > 1e-8 + # average ratio of dimensions + rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0 + return bbox * rel_size + + +def _bbox_xywh_to_xyxy( + xywh: torch.Tensor, clamp_size: Optional[int] = None +) -> torch.Tensor: + xyxy = xywh.clone() + if clamp_size is not None: + xyxy[2:] = torch.clamp(xyxy[2:], clamp_size) + xyxy[2:] += xyxy[:2] + return xyxy + + +def _get_1d_bounds(arr) -> Tuple[int, int]: + nz = np.flatnonzero(arr) + return nz[0], nz[-1] + 1 diff --git a/tests/implicitron/test_bbox.py b/tests/implicitron/test_bbox.py index 48a8421bb..16199ad1e 100644 --- a/tests/implicitron/test_bbox.py +++ b/tests/implicitron/test_bbox.py @@ -9,7 +9,9 @@ import numpy as np import torch -from pytorch3d.implicitron.dataset.blob_loader import ( +from pytorch3d.implicitron.dataset.blob_loader import _resize_image + +from pytorch3d.implicitron.dataset.utils import ( _bbox_xywh_to_xyxy, _bbox_xyxy_to_xywh, _clamp_box_to_image_bounds_and_round, @@ -18,7 +20,6 @@ _get_bbox_from_mask, _get_clamp_bbox, _rescale_bbox, - _resize_image, ) from tests.common_testing import TestCaseMixin From 664d35d66de59e815f3feec581d6ad80bc0bdea0 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 14 Mar 2023 18:16:40 +0000 Subject: [PATCH 2/9] tests fix, typos, linter --- pytorch3d/implicitron/dataset/blob_loader.py | 11 +++--- pytorch3d/implicitron/dataset/utils.py | 37 +++++++++++--------- tests/implicitron/test_blob_loader.py | 17 ++++----- 3 files changed, 30 insertions(+), 35 deletions(-) diff --git a/pytorch3d/implicitron/dataset/blob_loader.py b/pytorch3d/implicitron/dataset/blob_loader.py index ce59c542d..fa3a5ac29 100644 --- a/pytorch3d/implicitron/dataset/blob_loader.py +++ b/pytorch3d/implicitron/dataset/blob_loader.py @@ -6,7 +6,6 @@ import functools import os -import warnings from dataclasses import dataclass from pathlib import Path from typing import Any, Optional, Tuple, Union @@ -17,12 +16,10 @@ from pytorch3d.implicitron.dataset import types from pytorch3d.implicitron.dataset.dataset_base import FrameData +from pytorch3d.implicitron.dataset.utils import _get_bbox_from_mask from pytorch3d.io import IO from pytorch3d.renderer.cameras import PerspectiveCameras from pytorch3d.structures.pointclouds import Pointclouds -from pytorch3d.implicitron.dataset.utils import ( - _get_bbox_from_mask, -) @dataclass @@ -124,13 +121,13 @@ def load_( frame_data.sequence_point_cloud_path = pcl_path if self.box_crop: - frame_data.crop_by_bbox(bbox_xywh, self.box_crop_context, ) + frame_data.crop_by_bbox(bbox_xywh, self.box_crop_context) return frame_data def _load_crop_fg_probability( self, entry: types.FrameAnnotation - ) -> Tuple[Optional[torch.Tensor],Optional[str],Optional[torch.Tensor]]: + ) -> Tuple[Optional[torch.Tensor], Optional[str], Optional[torch.Tensor]]: fg_probability = None full_path = None bbox_xywh = None @@ -138,7 +135,7 @@ def _load_crop_fg_probability( if (self.load_masks) and entry.mask is not None: full_path = 
os.path.join(self.dataset_root, entry.mask.path) mask = _load_mask(self._local_path(full_path)) - bbox_xywh = torch.tensor(_get_bbox_from_mask(self.mask, self.box_crop_mask_thr)) + bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr)) if mask.shape[-2:] != entry.image.size: raise ValueError( diff --git a/pytorch3d/implicitron/dataset/utils.py b/pytorch3d/implicitron/dataset/utils.py index b2ac99f36..6e9af933d 100644 --- a/pytorch3d/implicitron/dataset/utils.py +++ b/pytorch3d/implicitron/dataset/utils.py @@ -5,7 +5,10 @@ # LICENSE file in the root directory of this source tree. -from typing import List, Optional +import warnings +from typing import List, Optional, Tuple + +import numpy as np import torch @@ -55,22 +58,22 @@ def is_train_frame( def _get_bbox_from_mask( - mask, thr, decrease_quant: float = 0.05 - ) -> Tuple[int, int, int, int]: - # bbox in xywh - masks_for_box = np.zeros_like(mask) - while masks_for_box.sum() <= 1.0: - masks_for_box = (mask > thr).astype(np.float32) - thr -= decrease_quant - if thr <= 0.0: - warnings.warn( - f"Empty masks_for_bbox (thr={thr}) => using full image.", stacklevel=1 - ) - - x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2)) - y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1)) - - return x0, y0, x1 - x0, y1 - y0 + mask, thr, decrease_quant: float = 0.05 +) -> Tuple[int, int, int, int]: + # bbox in xywh + masks_for_box = np.zeros_like(mask) + while masks_for_box.sum() <= 1.0: + masks_for_box = (mask > thr).astype(np.float32) + thr -= decrease_quant + if thr <= 0.0: + warnings.warn( + f"Empty masks_for_bbox (thr={thr}) => using full image.", stacklevel=1 + ) + + x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2)) + y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1)) + + return x0, y0, x1 - x0, y1 - y0 def _crop_around_box(tensor, bbox, impath: str = ""): diff --git a/tests/implicitron/test_blob_loader.py b/tests/implicitron/test_blob_loader.py index fd8d8fd81..d2a612d48 100644 --- a/tests/implicitron/test_blob_loader.py +++ b/tests/implicitron/test_blob_loader.py @@ -69,25 +69,22 @@ def test_load_(self): fg_probability, mask_path, bbox_xywh, - clamp_bbox_xyxy, - crop_bbox_xywh, ) = self.blob_loader._load_crop_fg_probability(self.frame_annotation) assert mask_path assert torch.is_tensor(fg_probability) assert torch.is_tensor(bbox_xywh) - assert torch.is_tensor(clamp_bbox_xyxy) - assert torch.is_tensor(crop_bbox_xywh) # assert bboxes shape self.assertEqual( fg_probability.shape, torch.Size([1, self.image_height, self.image_width]) ) self.assertEqual(bbox_xywh.shape, torch.Size([4])) - self.assertEqual(clamp_bbox_xyxy.shape, torch.Size([4])) - self.assertEqual(crop_bbox_xywh.shape, torch.Size([4])) - (image_rgb, image_path, mask_crop, scale,) = self.blob_loader._load_crop_images( - self.frame_annotation, fg_probability, clamp_bbox_xyxy - ) + ( + image_rgb, + image_path, + mask_crop, + scale, + ) = self.blob_loader._load_crop_images(self.frame_annotation, fg_probability) assert torch.is_tensor(image_rgb) assert image_path assert torch.is_tensor(mask_crop) @@ -102,7 +99,6 @@ def test_load_(self): (depth_map, depth_path, depth_mask,) = self.blob_loader._load_mask_depth( self.frame_annotation, - clamp_bbox_xyxy, fg_probability, ) assert torch.is_tensor(depth_map) @@ -119,7 +115,6 @@ def test_load_(self): camera = self.blob_loader._get_pytorch3d_camera( self.frame_annotation, scale, - clamp_bbox_xyxy, ) self.assertEqual(type(camera), PerspectiveCameras) From 5c249db0a0160cf9c1b4043634a4e0a495cff6e1 Mon Sep 17 00:00:00 2001 From: 
Ildar Salakhiev Date: Tue, 14 Mar 2023 18:25:50 +0000 Subject: [PATCH 3/9] renamed crop to crop_ to show inplace modification --- pytorch3d/implicitron/dataset/blob_loader.py | 4 ++-- pytorch3d/implicitron/dataset/dataset_base.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch3d/implicitron/dataset/blob_loader.py b/pytorch3d/implicitron/dataset/blob_loader.py index fa3a5ac29..19417a639 100644 --- a/pytorch3d/implicitron/dataset/blob_loader.py +++ b/pytorch3d/implicitron/dataset/blob_loader.py @@ -85,7 +85,7 @@ def load_( ( frame_data.fg_probability, frame_data.mask_path, - bbox_xywh, + frame_data.bbox_xywh, ) = self._load_crop_fg_probability(entry) scale = min( @@ -121,7 +121,7 @@ def load_( frame_data.sequence_point_cloud_path = pcl_path if self.box_crop: - frame_data.crop_by_bbox(bbox_xywh, self.box_crop_context) + frame_data.crop_by_bbox_(self.box_crop_context) return frame_data diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index 322d1889b..7ddc9e122 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -151,10 +151,10 @@ def __getitem__(self, key): def __len__(self): return len(fields(self)) - def crop_by_bbox(self, bbox_xywh, box_crop_context): + def crop_by_bbox_(self, box_crop_context): clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round( _get_clamp_bbox( - bbox_xywh, + self.bbox_xywh, image_path=self.image.path, box_crop_context=box_crop_context, ), From 530b9a42d1ebfde8afa92dc3bded73f18d6e0a25 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 14 Mar 2023 18:33:44 +0000 Subject: [PATCH 4/9] shifting camera according to bbox --- pytorch3d/implicitron/dataset/dataset_base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index 7ddc9e122..2c1bb7527 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -173,8 +173,7 @@ def crop_by_bbox_(self, box_crop_context): depth_mask_bbox_xyxy = _rescale_bbox(clamp_bbox_xyxy, entry.image.size, self.depth_mask.shape[-2:]) self.depth_mask = _crop_around_box(self.depth_mask, depth_mask_bbox_xyxy, self.mask_path) - - principal_point_px -= clamp_bbox_xyxy[:2] + self.camera.principal_point_px -= clamp_bbox_xyxy[:2] @classmethod def collate(cls, batch): From e5500f329d3016740af50f4dc420b07d1383942a Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 15 Mar 2023 18:22:18 +0000 Subject: [PATCH 5/9] delegated reize_image to FrameData, made bbox_xywh optinal external parameter for load_, linter, fbcode tests --- pytorch3d/implicitron/dataset/blob_loader.py | 130 ++++++----------- pytorch3d/implicitron/dataset/dataset_base.py | 118 +++++++++++++--- .../implicitron/dataset/json_index_dataset.py | 44 ++++-- pytorch3d/implicitron/dataset/utils.py | 29 ++++ pytorch3d/implicitron/dataset/visualize.py | 1 + tests/implicitron/test_bbox.py | 2 +- tests/implicitron/test_blob_loader.py | 131 +++++++++++++----- 7 files changed, 307 insertions(+), 148 deletions(-) diff --git a/pytorch3d/implicitron/dataset/blob_loader.py b/pytorch3d/implicitron/dataset/blob_loader.py index 19417a639..9ccf53b2f 100644 --- a/pytorch3d/implicitron/dataset/blob_loader.py +++ b/pytorch3d/implicitron/dataset/blob_loader.py @@ -38,23 +38,23 @@ class BlobLoader: load_masks: Enable loading frame foreground masks. 
load_point_clouds: Enable loading sequence-level point clouds. max_points: Cap on the number of loaded points in the point cloud; - if reached, they are randomly sampled without replacement. + if reached, they are randomly sampled without replacement. mask_images: Whether to mask the images with the loaded foreground masks; - 0 value is used for background. + 0 value is used for background. mask_depths: Whether to mask the depth maps with the loaded foreground masks; 0 value is used for background. image_height: The height of the returned images, masks, and depth maps; - aspect ratio is preserved during cropping/resizing. + aspect ratio is preserved during cropping/resizing. image_width: The width of the returned images, masks, and depth maps; aspect ratio is preserved during cropping/resizing. box_crop: Enable cropping of the image around the bounding box inferred - from the foreground region of the loaded segmentation mask; masks - and depth maps are cropped accordingly; cameras are corrected. + from the foreground region of the loaded segmentation mask; masks + and depth maps are cropped accordingly; cameras are corrected. box_crop_mask_thr: The threshold used to separate pixels into foreground - and background based on the foreground_probability mask; if no value - is greater than this threshold, the loader lowers it and repeats. + and background based on the foreground_probability mask; if no value + is greater than this threshold, the loader lowers it and repeats. box_crop_context: The amount of additional padding added to each - dimension of the cropping bounding box, relative to box size. + dimension of the cropping bounding box, relative to box size. """ dataset_root: str = "" @@ -78,20 +78,18 @@ def load_( frame_data: FrameData, entry: types.FrameAnnotation, seq_annotation: types.SequenceAnnotation, + bbox_xywh: Optional[torch.Tensor] = None, ) -> FrameData: """Main method for loader. 
FrameData modification done inplace + if bbox_xywh not provided bbox will be calculated from mask """ ( frame_data.fg_probability, frame_data.mask_path, frame_data.bbox_xywh, - ) = self._load_crop_fg_probability(entry) + ) = self._load_fg_probability(entry, bbox_xywh) - scale = min( - self.image_height / entry.image.size[0], - self.image_width / entry.image.size[1], - ) if self.load_images and entry.image is not None: # original image size frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long) @@ -99,9 +97,7 @@ def load_( ( frame_data.image_rgb, frame_data.image_path, - frame_data.mask_crop, - scale, - ) = self._load_crop_images(entry, frame_data.fg_probability) + ) = self._load_images(entry, frame_data.fg_probability) if self.load_depths and entry.depth is not None: ( @@ -110,9 +106,6 @@ def load_( frame_data.depth_mask, ) = self._load_mask_depth(entry, frame_data.fg_probability) - if entry.viewpoint is not None: - frame_data.camera = self._get_pytorch3d_camera(entry, scale) - if self.load_point_clouds and seq_annotation.point_cloud is not None: pcl_path = self._fix_point_cloud_path(seq_annotation.point_cloud.path) frame_data.sequence_point_cloud = _load_pointcloud( @@ -120,42 +113,50 @@ def load_( ) frame_data.sequence_point_cloud_path = pcl_path + clamp_bbox_xyxy = None if self.box_crop: - frame_data.crop_by_bbox_(self.box_crop_context) + clamp_bbox_xyxy = frame_data.crop_by_bbox_(self.box_crop_context) + + scale = 1.0 + + if self.image_height is not None and self.image_width is not None: + scale = frame_data.resize_frame_(self.image_height, self.image_width) + # creating camera taking to account bbox and resize scale + if entry.viewpoint is not None: + frame_data.camera = self._get_pytorch3d_camera( + entry, scale, clamp_bbox_xyxy + ) return frame_data - def _load_crop_fg_probability( - self, entry: types.FrameAnnotation + def _load_fg_probability( + self, + entry: types.FrameAnnotation, + bbox_xywh: Optional[torch.Tensor], ) -> Tuple[Optional[torch.Tensor], Optional[str], Optional[torch.Tensor]]: fg_probability = None full_path = None - bbox_xywh = None if (self.load_masks) and entry.mask is not None: full_path = os.path.join(self.dataset_root, entry.mask.path) - mask = _load_mask(self._local_path(full_path)) - bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr)) - - if mask.shape[-2:] != entry.image.size: + fg_probability = _load_mask(self._local_path(full_path)) + # we can use provided bbox_xywh or calculate it based on mask + if bbox_xywh is None: + bbox_xywh = torch.tensor( + _get_bbox_from_mask(fg_probability, self.box_crop_mask_thr) + ) + if fg_probability.shape[-2:] != entry.image.size: raise ValueError( - f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!" + f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!" 
) - fg_probability, _, _ = _resize_image( - mask, - image_height=self.image_height, - image_width=self.image_width, - mode="nearest", - ) - - return fg_probability, full_path, bbox_xywh + return torch.tensor(fg_probability), full_path, bbox_xywh - def _load_crop_images( + def _load_images( self, entry: types.FrameAnnotation, fg_probability: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, str, torch.Tensor, float]: + ) -> Tuple[torch.Tensor, str]: assert self.dataset_root is not None and entry.image is not None path = os.path.join(self.dataset_root, entry.image.path) image_rgb = _load_image(self._local_path(path)) @@ -165,15 +166,11 @@ def _load_crop_images( f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!" ) - image_rgb, scale, mask_crop = _resize_image( - image_rgb, image_height=self.image_height, image_width=self.image_width - ) - if self.mask_images: assert fg_probability is not None image_rgb *= fg_probability - return image_rgb, path, mask_crop, scale + return image_rgb, path def _load_mask_depth( self, @@ -185,13 +182,6 @@ def _load_mask_depth( path = os.path.join(self.dataset_root, entry_depth.path) depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment) - depth_map, _, _ = _resize_image( - depth_map, - image_height=self.image_height, - image_width=self.image_width, - mode="nearest", - ) - if self.mask_depths: assert fg_probability is not None depth_map *= fg_probability @@ -200,22 +190,16 @@ def _load_mask_depth( assert entry_depth.mask_path is not None mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) depth_mask = _load_depth_mask(self._local_path(mask_path)) - - depth_mask, _, _ = _resize_image( - depth_mask, - image_height=self.image_height, - image_width=self.image_width, - mode="nearest", - ) else: depth_mask = torch.ones_like(depth_map) - return depth_map, path, depth_mask + return torch.tensor(depth_map), path, torch.tensor(depth_mask) def _get_pytorch3d_camera( self, entry: types.FrameAnnotation, scale: float, + clamp_bbox_xyxy: Optional[torch.Tensor], ) -> PerspectiveCameras: entry_viewpoint = entry.viewpoint assert entry_viewpoint is not None @@ -243,6 +227,10 @@ def _get_pytorch3d_camera( principal_point_px = half_image_size_wh_orig - principal_point * rescale focal_length_px = focal_length * rescale + # changing principal_point according to bbox_crop + if clamp_bbox_xyxy is not None: + principal_point_px -= clamp_bbox_xyxy[:2] + # now, convert from pixels to PyTorch3D v0.5+ NDC convention if self.image_height is None or self.image_width is None: out_size = list(reversed(entry.image.size)) @@ -283,32 +271,6 @@ def _local_path(self, path: str) -> str: return self.path_manager.get_local_path(path) -def _resize_image( - image, image_height, image_width, mode="bilinear" -) -> Tuple[torch.Tensor, float, torch.Tensor]: - if image_height is None or image_width is None: - # skip the resizing - imre_ = torch.from_numpy(image) - return imre_, 1.0, torch.ones_like(imre_[:1]) - # takes numpy array, returns pytorch tensor - minscale = min( - image_height / image.shape[-2], - image_width / image.shape[-1], - ) - imre = torch.nn.functional.interpolate( - torch.from_numpy(image)[None], - scale_factor=minscale, - mode=mode, - align_corners=False if mode == "bilinear" else None, - recompute_scale_factor=True, - )[0] - imre_ = torch.zeros(image.shape[0], image_height, image_width) - imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre - mask = torch.zeros(1, image_height, image_width) - mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] 
= 1.0 - return imre_, minscale, mask - - def _load_image(path) -> np.ndarray: with Image.open(path) as pil_im: im = np.array(pil_im.convert("RGB")) diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index 2c1bb7527..cbc871a1e 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import warnings from collections import defaultdict from dataclasses import dataclass, field, fields from typing import ( @@ -23,16 +24,17 @@ import numpy as np import torch -from pytorch3d.renderer.camera_utils import join_cameras_as_batch -from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras -from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds from pytorch3d.implicitron.dataset.utils import ( - _crop_around_box, - _clamp_box_to_image_bounds_and_round, _bbox_xyxy_to_xywh, + _clamp_box_to_image_bounds_and_round, + _crop_around_box, _get_clamp_bbox, _rescale_bbox, + _resize_image, ) +from pytorch3d.renderer.camera_utils import join_cameras_as_batch +from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras +from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds @dataclass @@ -97,6 +99,7 @@ class FrameData(Mapping[str, Any]): frame_type: The type of the loaded frame specified in `subset_lists_file`, if provided. meta: A dict for storing additional frame information. + cropped: Bool to avoid cropping FrameData twice """ frame_number: Optional[torch.LongTensor] @@ -123,6 +126,7 @@ class FrameData(Mapping[str, Any]): sequence_point_cloud_idx: Optional[torch.Tensor] = None frame_type: Union[str, List[str], None] = None # known | unseen meta: dict = field(default_factory=lambda: {}) + cropped: bool = False def to(self, *args, **kwargs): new_params = {} @@ -151,29 +155,105 @@ def __getitem__(self, key): def __len__(self): return len(fields(self)) - def crop_by_bbox_(self, box_crop_context): + def crop_by_bbox_(self, box_crop_context) -> Optional[torch.Tensor]: + if self.cropped: + warnings.warn( + f"You called cropping on same frame twice " + f"sequence_name: {self.sequence_name}, skipping cropping" + ) + return None + + if ( + self.bbox_xywh is None + or self.fg_probability is None + or self.mask_path is None + or self.image_path is None + ): + warnings.warn( + "You called cropping without loading frame data" + "please call blob_loader.load_ first, skipping cropping" + ) + return None + + bbox_xyxy = _get_clamp_bbox( + self.bbox_xywh, + # pyre-ignore + image_path=self.image_path, + box_crop_context=box_crop_context, + ) clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round( - _get_clamp_bbox( - self.bbox_xywh, - image_path=self.image.path, - box_crop_context=box_crop_context, - ), - image_size_hw=tuple(self.fg_probability.shape[-2:]), + bbox_xyxy, + # pyre-ignore + image_size_hw=tuple(self.image_size_hw), ) self.crop_bbox_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy) self.fg_probability = _crop_around_box( - self.fg_probability, clamp_bbox_xyxy, self.mask_path + self.fg_probability, + clamp_bbox_xyxy, + # pyre-ignore + self.mask_path, + ) + self.image_rgb = _crop_around_box( + self.image_rgb, + clamp_bbox_xyxy, + # pyre-ignore + self.image_path, ) - self.image_rgb = _crop_around_box(self.image_rgb, clamp_bbox_xyxy, self.image.path) - depth_bbox_xyxy = 
_rescale_bbox(clamp_bbox_xyxy, entry.image.size, self.depth_map.shape[-2:]) - self.depth_map = _crop_around_box(self.depth_map, depth_bbox_xyxy, self.depth_path) + if self.depth_map is not None: + self.depth_map = _crop_around_box( + self.depth_map, + clamp_bbox_xyxy, + # pyre-ignore + self.depth_path, + ) + if self.depth_mask is not None: + self.depth_mask = _crop_around_box( + self.depth_mask, + clamp_bbox_xyxy, + # pyre-ignore + self.mask_path, + ) + self.cropped = True + return clamp_bbox_xyxy + + def resize_frame_(self, image_height, image_width) -> float: + if self.bbox_xywh is not None: + self.bbox_xywh = _rescale_bbox( + self.bbox_xywh, + np.array(self.image_size_hw), + # pyre-ignore + self.image_rgb.shape[-2:], + ) + + self.image_rgb, scale, self.mask_crop = _resize_image( + self.image_rgb, image_height=image_height, image_width=image_width + ) - depth_mask_bbox_xyxy = _rescale_bbox(clamp_bbox_xyxy, entry.image.size, self.depth_mask.shape[-2:]) - self.depth_mask = _crop_around_box(self.depth_mask, depth_mask_bbox_xyxy, self.mask_path) + self.fg_probability, _, _ = _resize_image( + self.fg_probability, + image_height=image_height, + image_width=image_width, + mode="nearest", + ) - self.camera.principal_point_px -= clamp_bbox_xyxy[:2] + if self.depth_map is not None: + self.depth_map, _, _ = _resize_image( + self.depth_map, + image_height=image_height, + image_width=image_width, + mode="nearest", + ) + + if self.depth_mask is not None: + self.depth_mask, _, _ = _resize_image( + self.depth_mask, + image_height=image_height, + image_width=image_width, + mode="nearest", + ) + return scale @classmethod def collate(cls, batch): diff --git a/pytorch3d/implicitron/dataset/json_index_dataset.py b/pytorch3d/implicitron/dataset/json_index_dataset.py index 636630680..5f9b2685a 100644 --- a/pytorch3d/implicitron/dataset/json_index_dataset.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset.py @@ -14,7 +14,6 @@ import random import warnings from collections import defaultdict -from dataclasses import field from itertools import islice from typing import ( Any, @@ -161,12 +160,12 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): sort_frames: bool = False eval_batches: Any = None eval_batch_index: Any = None - subset_to_image_path: Any = None # initialised in __post_init__ - blob_loader: BlobLoader = field(init=False) - frame_annots: List[FrameAnnotsEntry] = field(init=False) - seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False) - _seq_to_idx: Dict[str, List[int]] = field(init=False) + # commented because of OmegaConf (for tests to pass) + # blob_loader: BlobLoader = field(init=False) + # frame_annots: List[FrameAnnotsEntry] = field(init=False) + # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False) + # _seq_to_idx: Dict[str, List[int]] = field(init=False) def __post_init__(self) -> None: self._load_frames() @@ -177,6 +176,7 @@ def __post_init__(self) -> None: self._filter_db() # also computes sequence indices self._extract_and_set_eval_batches() + # pyre-ignore self.blob_loader = BlobLoader( dataset_root=self.dataset_root, load_images=self.load_images, @@ -219,7 +219,9 @@ def join(self, other_datasets: Iterable["JsonIndexDataset"]) -> None: """ if not all(isinstance(d, JsonIndexDataset) for d in other_datasets): raise ValueError("This function can only join a list of JsonIndexDataset") + # pyre-ignore self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots]) + # pyre-ignore self.seq_annots.update( # 
https://gist.github.com/treyhunner/f35292e676efa0be1728 functools.reduce( @@ -295,9 +297,11 @@ def seq_frame_index_to_dataset_index( """ _dataset_seq_frame_n_index = { seq: { + # pyre-ignore self.frame_annots[idx]["frame_annotation"].frame_number: idx for idx in seq_idx } + # pyre-ignore for seq, seq_idx in self._seq_to_idx.items() } @@ -320,6 +324,7 @@ def _get_dataset_idx( # Check that the loaded frame path is consistent # with the one stored in self.frame_annots. assert os.path.normpath( + # pyre-ignore self.frame_annots[idx]["frame_annotation"].image.path ) == os.path.normpath( path @@ -369,6 +374,7 @@ def subset_from_frame_index( # Deep copy the whole dataset except frame_annots, which are large so we # deep copy only the requested subset of frame_annots. + # pyre-ignore memo = {id(self.frame_annots): None} dataset_new = copy.deepcopy(self, memo) dataset_new.frame_annots = copy.deepcopy( @@ -397,9 +403,11 @@ def subset_from_frame_index( return dataset_new def __str__(self) -> str: + # pyre-ignore return f"JsonIndexDataset #frames={len(self.frame_annots)}" def __len__(self) -> int: + # pyre-ignore return len(self.frame_annots) def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]: @@ -411,6 +419,7 @@ def get_all_train_cameras(self) -> CamerasBase: """ logger.info("Loading all train cameras.") cameras = [] + # pyre-ignore for frame_idx, frame_annot in enumerate(tqdm(self.frame_annots)): frame_type = self._get_frame_type(frame_annot) if frame_type is None: @@ -420,10 +429,12 @@ def get_all_train_cameras(self) -> CamerasBase: return join_cameras_as_batch(cameras) def __getitem__(self, index) -> FrameData: + # pyre-ignore if index >= len(self.frame_annots): raise IndexError(f"index {index} out of range {len(self.frame_annots)}") entry = self.frame_annots[index]["frame_annotation"] + # pyre-ignore point_cloud = self.seq_annots[entry.sequence_name].point_cloud frame_data = FrameData( frame_number=_safe_as_tensor(entry.frame_number, torch.long), @@ -443,9 +454,8 @@ def __getitem__(self, index) -> FrameData: # Optional field frame_data.frame_type = self._get_frame_type(self.frame_annots[index]) - self.blob_loader.load_( - frame_data, entry, self.seq_annots[entry.sequence_name] - ) + # pyre-ignore + self.blob_loader.load_(frame_data, entry, self.seq_annots[entry.sequence_name]) return frame_data def _load_frames(self) -> None: @@ -457,6 +467,7 @@ def _load_frames(self) -> None: ) if not frame_annots_list: raise ValueError("Empty dataset!") + # pyre-ignore self.frame_annots = [ FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list ] @@ -468,6 +479,7 @@ def _load_sequences(self) -> None: seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation]) if not seq_annots: raise ValueError("Empty sequences file!") + # pyre-ignore self.seq_annots = {entry.sequence_name: entry for entry in seq_annots} def _load_subset_lists(self) -> None: @@ -483,6 +495,7 @@ def _load_subset_lists(self) -> None: for subset, frames in subset_to_seq_frame.items() for _, _, path in frames } + # pyre-ignore for frame in self.frame_annots: frame["subset"] = frame_path_to_subset.get( frame["frame_annotation"].image.path, None @@ -495,6 +508,7 @@ def _load_subset_lists(self) -> None: def _sort_frames(self) -> None: # Sort frames to have them grouped by sequence, ordered by timestamp + # pyre-ignore self.frame_annots = sorted( self.frame_annots, key=lambda f: ( @@ -506,6 +520,7 @@ def _sort_frames(self) -> None: def _filter_db(self) -> None: if self.remove_empty_masks: 
logger.info("Removing images with empty masks.") + # pyre-ignore old_len = len(self.frame_annots) msg = "remove_empty_masks needs every MaskAnnotation.mass to be set." @@ -546,6 +561,7 @@ def positive_mass(frame_annot: types.FrameAnnotation) -> bool: if len(self.limit_category_to) > 0: logger.info(f"Limiting dataset to categories: {self.limit_category_to}") + # pyre-ignore self.seq_annots = { name: entry for name, entry in self.seq_annots.items() @@ -583,6 +599,7 @@ def positive_mass(frame_annot: types.FrameAnnotation) -> bool: if self.n_frames_per_sequence > 0: logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.") keep_idx = [] + # pyre-ignore for seq, seq_indices in self._seq_to_idx.items(): # infer the seed from the sequence name, this is reproducible # and makes the selection differ for different sequences @@ -612,14 +629,19 @@ def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None: self._invalidate_seq_to_idx() if filter_seq_annots: + # pyre-ignore self.seq_annots = { - k: v for k, v in self.seq_annots.items() if k in self._seq_to_idx + k: v + for k, v in self.seq_annots.items() + if k in self._seq_to_idx # pyre-ignore } def _invalidate_seq_to_idx(self) -> None: seq_to_idx = defaultdict(list) + # pyre-ignore for idx, entry in enumerate(self.frame_annots): seq_to_idx[entry["frame_annotation"].sequence_name].append(idx) + # pyre-ignore self._seq_to_idx = seq_to_idx def _local_path(self, path: str) -> str: @@ -634,6 +656,7 @@ def get_frame_numbers_and_timestamps( for idx in idxs: if ( subset_filter is not None + # pyre-ignore and self.frame_annots[idx]["subset"] not in subset_filter ): continue @@ -646,6 +669,7 @@ def get_frame_numbers_and_timestamps( def category_to_sequence_names(self) -> Dict[str, List[str]]: c2seq = defaultdict(list) + # pyre-ignore for sequence_name, sa in self.seq_annots.items(): c2seq[sa.category].append(sequence_name) return dict(c2seq) diff --git a/pytorch3d/implicitron/dataset/utils.py b/pytorch3d/implicitron/dataset/utils.py index 6e9af933d..aca0507dd 100644 --- a/pytorch3d/implicitron/dataset/utils.py +++ b/pytorch3d/implicitron/dataset/utils.py @@ -156,3 +156,32 @@ def _bbox_xywh_to_xyxy( def _get_1d_bounds(arr) -> Tuple[int, int]: nz = np.flatnonzero(arr) return nz[0], nz[-1] + 1 + + +def _resize_image( + image, image_height, image_width, mode="bilinear" +) -> Tuple[torch.Tensor, float, torch.Tensor]: + + if type(image) == np.ndarray: + image = torch.from_numpy(image) + + if image_height is None or image_width is None: + # skip the resizing + return image, 1.0, torch.ones_like(image[:1]) + # takes numpy array or tensor, returns pytorch tensor + minscale = min( + image_height / image.shape[-2], + image_width / image.shape[-1], + ) + imre = torch.nn.functional.interpolate( + image[None], + scale_factor=minscale, + mode=mode, + align_corners=False if mode == "bilinear" else None, + recompute_scale_factor=True, + )[0] + imre_ = torch.zeros(image.shape[0], image_height, image_width) + imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre + mask = torch.zeros(1, image_height, image_width) + mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0 + return imre_, minscale, mask diff --git a/pytorch3d/implicitron/dataset/visualize.py b/pytorch3d/implicitron/dataset/visualize.py index 284e903a0..6d0be0362 100644 --- a/pytorch3d/implicitron/dataset/visualize.py +++ b/pytorch3d/implicitron/dataset/visualize.py @@ -44,6 +44,7 @@ def get_implicitron_sequence_pointcloud( sequence_entries = [ ei for ei in sequence_entries + # pyre-ignore[16] if 
dataset.frame_annots[ei]["frame_annotation"].sequence_name == sequence_name ] diff --git a/tests/implicitron/test_bbox.py b/tests/implicitron/test_bbox.py index 16199ad1e..3c45ee793 100644 --- a/tests/implicitron/test_bbox.py +++ b/tests/implicitron/test_bbox.py @@ -9,7 +9,6 @@ import numpy as np import torch -from pytorch3d.implicitron.dataset.blob_loader import _resize_image from pytorch3d.implicitron.dataset.utils import ( _bbox_xywh_to_xyxy, @@ -20,6 +19,7 @@ _get_bbox_from_mask, _get_clamp_bbox, _rescale_bbox, + _resize_image, ) from tests.common_testing import TestCaseMixin diff --git a/tests/implicitron/test_blob_loader.py b/tests/implicitron/test_blob_loader.py index d2a612d48..ef18d6258 100644 --- a/tests/implicitron/test_blob_loader.py +++ b/tests/implicitron/test_blob_loader.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import contextlib import gzip import os @@ -15,8 +21,10 @@ _load_depth_mask, _load_image, _load_mask, + _safe_as_tensor, BlobLoader, ) +from pytorch3d.implicitron.dataset.dataset_base import FrameData from pytorch3d.implicitron.tools.config import get_default_args from pytorch3d.renderer.cameras import PerspectiveCameras @@ -53,6 +61,37 @@ def setUp(self): ) self.frame_annotation = frame_annots_list[0] + sequence_annotations_file = os.path.join( + self.dataset_root, category, "sequence_annotations.jgz" + ) + local_file = self.path_manager.get_local_path(sequence_annotations_file) + with gzip.open(local_file, "rt", encoding="utf8") as zipfile: + seq_annots_list = types.load_dataclass( + zipfile, List[types.SequenceAnnotation] + ) + seq_annots = {entry.sequence_name: entry for entry in seq_annots_list} + self.seq_annotation = seq_annots[self.frame_annotation.sequence_name] + + point_cloud = self.seq_annotation.point_cloud + self.frame_data = FrameData( + frame_number=_safe_as_tensor( + self.frame_annotation.frame_number, torch.long + ), + frame_timestamp=_safe_as_tensor( + self.frame_annotation.frame_timestamp, torch.float + ), + sequence_name=self.frame_annotation.sequence_name, + sequence_category=self.seq_annotation.category, + camera_quality_score=_safe_as_tensor( + self.seq_annotation.viewpoint_quality_score, torch.float + ), + point_cloud_quality_score=_safe_as_tensor( + point_cloud.quality_score, torch.float + ) + if point_cloud is not None + else None, + ) + def test_BlobLoader_args(self): # test that BlobLoader works with get_default_args get_default_args(BlobLoader) @@ -65,58 +104,82 @@ def test_fix_point_cloud_path(self): assert self.blob_loader.dataset_root in modified_path def test_load_(self): + bbox_xywh = None + self.frame_data.image_size_hw = _safe_as_tensor( + self.frame_annotation.image.size, torch.long + ) ( - fg_probability, - mask_path, - bbox_xywh, - ) = self.blob_loader._load_crop_fg_probability(self.frame_annotation) - - assert mask_path - assert torch.is_tensor(fg_probability) - assert torch.is_tensor(bbox_xywh) + self.frame_data.fg_probability, + self.frame_data.mask_path, + self.frame_data.bbox_xywh, + ) = self.blob_loader._load_fg_probability(self.frame_annotation, bbox_xywh) + + assert self.frame_data.mask_path + assert torch.is_tensor(self.frame_data.fg_probability) + assert torch.is_tensor(self.frame_data.bbox_xywh) # assert bboxes shape - self.assertEqual( - fg_probability.shape, torch.Size([1, self.image_height, self.image_width]) + 
self.assertEqual(self.frame_data.bbox_xywh.shape, torch.Size([4])) + ( + self.frame_data.image_rgb, + self.frame_data.image_path, + ) = self.blob_loader._load_images( + self.frame_annotation, self.frame_data.fg_probability ) - self.assertEqual(bbox_xywh.shape, torch.Size([4])) + self.assertEqual(type(self.frame_data.image_rgb), np.ndarray) + assert self.frame_data.image_path + ( - image_rgb, - image_path, - mask_crop, - scale, - ) = self.blob_loader._load_crop_images(self.frame_annotation, fg_probability) - assert torch.is_tensor(image_rgb) - assert image_path - assert torch.is_tensor(mask_crop) + self.frame_data.depth_map, + depth_path, + self.frame_data.depth_mask, + ) = self.blob_loader._load_mask_depth( + self.frame_annotation, + self.frame_data.fg_probability, + ) + assert torch.is_tensor(self.frame_data.depth_map) + assert depth_path + assert torch.is_tensor(self.frame_data.depth_mask) + + clamp_bbox_xyxy = None + if self.blob_loader.box_crop: + clamp_bbox_xyxy = self.frame_data.crop_by_bbox_( + self.blob_loader.box_crop_context + ) + + # assert image and mask shapes after resize + scale = self.frame_data.resize_frame_(self.image_height, self.image_width) assert scale - # assert image and mask shapes self.assertEqual( - image_rgb.shape, torch.Size([3, self.image_height, self.image_width]) + self.frame_data.mask_crop.shape, + torch.Size([1, self.image_height, self.image_width]), ) self.assertEqual( - mask_crop.shape, torch.Size([1, self.image_height, self.image_width]) + self.frame_data.image_rgb.shape, + torch.Size([3, self.image_height, self.image_width]), ) - - (depth_map, depth_path, depth_mask,) = self.blob_loader._load_mask_depth( - self.frame_annotation, - fg_probability, + self.assertEqual( + self.frame_data.mask_crop.shape, + torch.Size([1, self.image_height, self.image_width]), + ) + self.assertEqual( + self.frame_data.fg_probability.shape, + torch.Size([1, self.image_height, self.image_width]), ) - assert torch.is_tensor(depth_map) - assert depth_path - assert torch.is_tensor(depth_mask) - # assert image and mask shapes self.assertEqual( - depth_map.shape, torch.Size([1, self.image_height, self.image_width]) + self.frame_data.depth_map.shape, + torch.Size([1, self.image_height, self.image_width]), ) self.assertEqual( - depth_mask.shape, torch.Size([1, self.image_height, self.image_width]) + self.frame_data.depth_mask.shape, + torch.Size([1, self.image_height, self.image_width]), ) - camera = self.blob_loader._get_pytorch3d_camera( + self.frame_data.camera = self.blob_loader._get_pytorch3d_camera( self.frame_annotation, scale, + clamp_bbox_xyxy, ) - self.assertEqual(type(camera), PerspectiveCameras) + self.assertEqual(type(self.frame_data.camera), PerspectiveCameras) def test_load_image(self): path = os.path.join(self.dataset_root, self.frame_annotation.image.path) From 0fc3253d029ccf1551f0439b7c787fbb4d76f8bd Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 15 Mar 2023 18:52:32 +0000 Subject: [PATCH 6/9] using safe_as_tensor for fg_probability --- pytorch3d/implicitron/dataset/blob_loader.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch3d/implicitron/dataset/blob_loader.py b/pytorch3d/implicitron/dataset/blob_loader.py index 9ccf53b2f..13eecdf79 100644 --- a/pytorch3d/implicitron/dataset/blob_loader.py +++ b/pytorch3d/implicitron/dataset/blob_loader.py @@ -142,15 +142,17 @@ def _load_fg_probability( fg_probability = _load_mask(self._local_path(full_path)) # we can use provided bbox_xywh or calculate it based on mask if 
bbox_xywh is None: - bbox_xywh = torch.tensor( - _get_bbox_from_mask(fg_probability, self.box_crop_mask_thr) - ) + bbox_xywh = _get_bbox_from_mask(fg_probability, self.box_crop_mask_thr) if fg_probability.shape[-2:] != entry.image.size: raise ValueError( f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!" ) - return torch.tensor(fg_probability), full_path, bbox_xywh + return ( + _safe_as_tensor(fg_probability, torch.float), + full_path, + _safe_as_tensor(bbox_xywh, torch.long), + ) def _load_images( self, From 7c8d89daa2b3908a72f847f73edee704780d3f63 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 15 Mar 2023 18:58:34 +0000 Subject: [PATCH 7/9] made resizing only for loaded objects --- pytorch3d/implicitron/dataset/dataset_base.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index cbc871a1e..d567fb0b3 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -226,17 +226,18 @@ def resize_frame_(self, image_height, image_width) -> float: # pyre-ignore self.image_rgb.shape[-2:], ) + if self.image_rgb is not None: + self.image_rgb, scale, self.mask_crop = _resize_image( + self.image_rgb, image_height=image_height, image_width=image_width + ) - self.image_rgb, scale, self.mask_crop = _resize_image( - self.image_rgb, image_height=image_height, image_width=image_width - ) - - self.fg_probability, _, _ = _resize_image( - self.fg_probability, - image_height=image_height, - image_width=image_width, - mode="nearest", - ) + if self.fg_probability is not None: + self.fg_probability, _, _ = _resize_image( + self.fg_probability, + image_height=image_height, + image_width=image_width, + mode="nearest", + ) if self.depth_map is not None: self.depth_map, _, _ = _resize_image( From 3027cd7e5f2b615fead37f3338cfe587e84cc9db Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 15 Mar 2023 19:05:26 +0000 Subject: [PATCH 8/9] fixing scale --- pytorch3d/implicitron/dataset/dataset_base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index d567fb0b3..1684251fb 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -226,6 +226,8 @@ def resize_frame_(self, image_height, image_width) -> float: # pyre-ignore self.image_rgb.shape[-2:], ) + + scale = 1.0 if self.image_rgb is not None: self.image_rgb, scale, self.mask_crop = _resize_image( self.image_rgb, image_height=image_height, image_width=image_width @@ -237,6 +239,7 @@ def resize_frame_(self, image_height, image_width) -> float: image_height=image_height, image_width=image_width, mode="nearest", + ) if self.depth_map is not None: From 7d570c179d94b28c00e3c0c749da6d7150e8d7e7 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 15 Mar 2023 19:21:46 +0000 Subject: [PATCH 9/9] fixing scale again.. 
--- pytorch3d/implicitron/dataset/blob_loader.py | 15 +++++++++++++-- pytorch3d/implicitron/dataset/dataset_base.py | 5 ++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/pytorch3d/implicitron/dataset/blob_loader.py b/pytorch3d/implicitron/dataset/blob_loader.py index 13eecdf79..83f39c78e 100644 --- a/pytorch3d/implicitron/dataset/blob_loader.py +++ b/pytorch3d/implicitron/dataset/blob_loader.py @@ -117,10 +117,21 @@ def load_( if self.box_crop: clamp_bbox_xyxy = frame_data.crop_by_bbox_(self.box_crop_context) - scale = 1.0 + scale = ( + min( + self.image_height / entry.image.size[0], + # pyre-ignore + self.image_width / entry.image.size[1], + ) + if self.image_height is not None and self.image_width is not None + else 1.0 + ) if self.image_height is not None and self.image_width is not None: - scale = frame_data.resize_frame_(self.image_height, self.image_width) + optional_scale = frame_data.resize_frame_( + self.image_height, self.image_width + ) + scale = optional_scale or scale # creating camera taking to account bbox and resize scale if entry.viewpoint is not None: diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index 1684251fb..7c4268fb9 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -218,7 +218,7 @@ def crop_by_bbox_(self, box_crop_context) -> Optional[torch.Tensor]: self.cropped = True return clamp_bbox_xyxy - def resize_frame_(self, image_height, image_width) -> float: + def resize_frame_(self, image_height, image_width) -> Optional[float]: if self.bbox_xywh is not None: self.bbox_xywh = _rescale_bbox( self.bbox_xywh, @@ -227,7 +227,7 @@ def resize_frame_(self, image_height, image_width) -> float: self.image_rgb.shape[-2:], ) - scale = 1.0 + scale = None if self.image_rgb is not None: self.image_rgb, scale, self.mask_crop = _resize_image( self.image_rgb, image_height=image_height, image_width=image_width @@ -239,7 +239,6 @@ def resize_frame_(self, image_height, image_width) -> float: image_height=image_height, image_width=image_width, mode="nearest", - ) if self.depth_map is not None:
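
After the series, the bbox and resize helpers live in pytorch3d/implicitron/dataset/utils.py and the crop/resize steps are applied by FrameData (crop_by_bbox_ before resize_frame_), with the camera built last from the resulting scale and crop box. Below is a minimal sketch of that order using the helpers these patches introduce; it is not part of the series, and the synthetic mask, image, threshold, box_crop_context and target size are illustrative assumptions only.

    # Sketch of the crop -> resize order established by this series, on synthetic data.
    import numpy as np
    import torch

    from pytorch3d.implicitron.dataset.utils import (
        _clamp_box_to_image_bounds_and_round,
        _crop_around_box,
        _get_bbox_from_mask,
        _get_clamp_bbox,
        _resize_image,
    )

    # Synthetic 1xHxW foreground mask and 3xHxW image (illustrative values).
    mask = np.zeros((1, 120, 160), dtype=np.float32)
    mask[:, 40:80, 60:120] = 1.0
    image = torch.rand(3, 120, 160)

    # 1) bbox (xywh) from the mask, as in BlobLoader._load_fg_probability.
    bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, thr=0.4), dtype=torch.long)

    # 2) expand by the crop context and clamp to image bounds, then crop,
    #    as in FrameData.crop_by_bbox_.
    clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
        _get_clamp_bbox(bbox_xywh, box_crop_context=0.3),
        image_size_hw=tuple(image.shape[-2:]),
    )
    image_cropped = _crop_around_box(image, clamp_bbox_xyxy)

    # 3) resize to the target resolution, as in FrameData.resize_frame_;
    #    the returned scale is what the loader feeds to _get_pytorch3d_camera.
    image_resized, scale, mask_crop = _resize_image(
        image_cropped, image_height=128, image_width=128
    )
    print(clamp_bbox_xyxy.tolist(), image_resized.shape, scale)
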