|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
7 | 7 | from collections import defaultdict
|
8 |
| -from dataclasses import dataclass, field, fields |
| 8 | +from dataclasses import dataclass |
9 | 9 | from typing import (
|
10 |
| - Any, |
11 | 10 | ClassVar,
|
12 | 11 | Dict,
|
13 | 12 | Iterable,
|
14 | 13 | Iterator,
|
15 | 14 | List,
|
16 |
| - Mapping, |
17 | 15 | Optional,
|
18 | 16 | Sequence,
|
19 | 17 | Tuple,
|
20 | 18 | Type,
|
21 |
| - Union, |
22 | 19 | )
|
23 | 20 |
|
24 |
| -import numpy as np |
25 | 21 | import torch
|
26 |
| -from pytorch3d.renderer.camera_utils import join_cameras_as_batch |
27 |
| -from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras |
28 |
| -from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds |
29 | 22 |
|
30 |
| - |
31 |
| -@dataclass |
32 |
| -class FrameData(Mapping[str, Any]): |
33 |
| - """ |
34 |
| - A type of the elements returned by indexing the dataset object. |
35 |
| - It can represent both individual frames and batches thereof; |
36 |
| - in this documentation, the sizes of tensors refer to single frames; |
37 |
| - add the first batch dimension for the collation result. |
38 |
| -
|
39 |
| - Args: |
40 |
| - frame_number: The number of the frame within its sequence. |
41 |
| - 0-based continuous integers. |
42 |
| - sequence_name: The unique name of the frame's sequence. |
43 |
| - sequence_category: The object category of the sequence. |
44 |
| - frame_timestamp: The time elapsed since the start of a sequence in sec. |
45 |
| - image_size_hw: The size of the image in pixels; (height, width) tensor |
46 |
| - of shape (2,). |
47 |
| - image_path: The qualified path to the loaded image (with dataset_root). |
48 |
| - image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image |
49 |
| - of the frame; elements are floats in [0, 1]. |
50 |
| - mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image |
51 |
| - regions. Regions can be invalid (mask_crop[i,j]=0) in case they |
52 |
| - are a result of zero-padding of the image after cropping around |
53 |
| - the object bounding box; elements are floats in {0.0, 1.0}. |
54 |
| - depth_path: The qualified path to the frame's depth map. |
55 |
| - depth_map: A float Tensor of shape `(1, H, W)` holding the depth map |
56 |
| - of the frame; values correspond to distances from the camera; |
57 |
| - use `depth_mask` and `mask_crop` to filter for valid pixels. |
58 |
| - depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the |
59 |
| - depth map that are valid for evaluation, they have been checked for |
60 |
| - consistency across views; elements are floats in {0.0, 1.0}. |
61 |
| - mask_path: A qualified path to the foreground probability mask. |
62 |
| - fg_probability: A Tensor of shape `(1, H, W)` denoting the probability of the |
63 |
| - pixels belonging to the captured object; elements are floats |
64 |
| - in [0, 1]. |
65 |
| - bbox_xywh: The bounding box tightly enclosing the foreground object in the |
66 |
| - format (x0, y0, width, height). The convention assumes that |
67 |
| - `x0+width` and `y0+height` includes the boundary of the box. |
68 |
| - I.e., to slice out the corresponding crop from an image tensor `I` |
69 |
| - we execute `crop = I[..., y0:y0+height, x0:x0+width]` |
70 |
| - crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb` |
71 |
| - in the original image coordinates in the format (x0, y0, width, height). |
72 |
| - The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs |
73 |
| - from `bbox_xywh` due to padding (which can happen e.g. due to |
74 |
| - setting `JsonIndexDataset.box_crop_context > 0`) |
75 |
| - camera: A PyTorch3D camera object corresponding to the frame's viewpoint, |
76 |
| - corrected for cropping if it happened. |
77 |
| - camera_quality_score: The score proportional to the confidence of the |
78 |
| - frame's camera estimation (the higher the more accurate). |
79 |
| - point_cloud_quality_score: The score proportional to the accuracy of the |
80 |
| - frame's sequence point cloud (the higher the more accurate). |
81 |
| - sequence_point_cloud_path: The path to the sequence's point cloud. |
82 |
| - sequence_point_cloud: A PyTorch3D Pointclouds object holding the |
83 |
| - point cloud corresponding to the frame's sequence. When the object |
84 |
| - represents a batch of frames, point clouds may be deduplicated; |
85 |
| - see `sequence_point_cloud_idx`. |
86 |
| - sequence_point_cloud_idx: Integer indices mapping frame indices to the |
87 |
| - corresponding point clouds in `sequence_point_cloud`; to get the |
88 |
| - corresponding point cloud to `image_rgb[i]`, use |
89 |
| - `sequence_point_cloud[sequence_point_cloud_idx[i]]`. |
90 |
| - frame_type: The type of the loaded frame specified in |
91 |
| - `subset_lists_file`, if provided. |
92 |
| - meta: A dict for storing additional frame information. |
93 |
| - """ |
94 |
| - |
95 |
| - frame_number: Optional[torch.LongTensor] |
96 |
| - sequence_name: Union[str, List[str]] |
97 |
| - sequence_category: Union[str, List[str]] |
98 |
| - frame_timestamp: Optional[torch.Tensor] = None |
99 |
| - image_size_hw: Optional[torch.Tensor] = None |
100 |
| - image_path: Union[str, List[str], None] = None |
101 |
| - image_rgb: Optional[torch.Tensor] = None |
102 |
| - # masks out padding added due to cropping the square bit |
103 |
| - mask_crop: Optional[torch.Tensor] = None |
104 |
| - depth_path: Union[str, List[str], None] = None |
105 |
| - depth_map: Optional[torch.Tensor] = None |
106 |
| - depth_mask: Optional[torch.Tensor] = None |
107 |
| - mask_path: Union[str, List[str], None] = None |
108 |
| - fg_probability: Optional[torch.Tensor] = None |
109 |
| - bbox_xywh: Optional[torch.Tensor] = None |
110 |
| - crop_bbox_xywh: Optional[torch.Tensor] = None |
111 |
| - camera: Optional[PerspectiveCameras] = None |
112 |
| - camera_quality_score: Optional[torch.Tensor] = None |
113 |
| - point_cloud_quality_score: Optional[torch.Tensor] = None |
114 |
| - sequence_point_cloud_path: Union[str, List[str], None] = None |
115 |
| - sequence_point_cloud: Optional[Pointclouds] = None |
116 |
| - sequence_point_cloud_idx: Optional[torch.Tensor] = None |
117 |
| - frame_type: Union[str, List[str], None] = None # known | unseen |
118 |
| - meta: dict = field(default_factory=lambda: {}) |
119 |
| - |
120 |
| - def to(self, *args, **kwargs): |
121 |
| - new_params = {} |
122 |
| - for f in fields(self): |
123 |
| - value = getattr(self, f.name) |
124 |
| - if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)): |
125 |
| - new_params[f.name] = value.to(*args, **kwargs) |
126 |
| - else: |
127 |
| - new_params[f.name] = value |
128 |
| - return type(self)(**new_params) |
129 |
| - |
130 |
| - def cpu(self): |
131 |
| - return self.to(device=torch.device("cpu")) |
132 |
| - |
133 |
| - def cuda(self): |
134 |
| - return self.to(device=torch.device("cuda")) |
135 |
| - |
136 |
| - # the following functions make sure **frame_data can be passed to functions |
137 |
| - def __iter__(self): |
138 |
| - for f in fields(self): |
139 |
| - yield f.name |
140 |
| - |
141 |
| - def __getitem__(self, key): |
142 |
| - return getattr(self, key) |
143 |
| - |
144 |
| - def __len__(self): |
145 |
| - return len(fields(self)) |
146 |
| - |
147 |
| - @classmethod |
148 |
| - def collate(cls, batch): |
149 |
| - """ |
150 |
| - Given a list of objects `batch` of class `cls`, collates them into a batched |
151 |
| - representation suitable for processing with deep networks. |
152 |
| - """ |
153 |
| - |
154 |
| - elem = batch[0] |
155 |
| - |
156 |
| - if isinstance(elem, cls): |
157 |
| - pointcloud_ids = [id(el.sequence_point_cloud) for el in batch] |
158 |
| - id_to_idx = defaultdict(list) |
159 |
| - for i, pc_id in enumerate(pointcloud_ids): |
160 |
| - id_to_idx[pc_id].append(i) |
161 |
| - |
162 |
| - sequence_point_cloud = [] |
163 |
| - sequence_point_cloud_idx = -np.ones((len(batch),)) |
164 |
| - for i, ind in enumerate(id_to_idx.values()): |
165 |
| - sequence_point_cloud_idx[ind] = i |
166 |
| - sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud) |
167 |
| - assert (sequence_point_cloud_idx >= 0).all() |
168 |
| - |
169 |
| - override_fields = { |
170 |
| - "sequence_point_cloud": sequence_point_cloud, |
171 |
| - "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(), |
172 |
| - } |
173 |
| - # note that the pre-collate value of sequence_point_cloud_idx is unused |
174 |
| - |
175 |
| - collated = {} |
176 |
| - for f in fields(elem): |
177 |
| - list_values = override_fields.get( |
178 |
| - f.name, [getattr(d, f.name) for d in batch] |
179 |
| - ) |
180 |
| - collated[f.name] = ( |
181 |
| - cls.collate(list_values) |
182 |
| - if all(list_value is not None for list_value in list_values) |
183 |
| - else None |
184 |
| - ) |
185 |
| - return cls(**collated) |
186 |
| - |
187 |
| - elif isinstance(elem, Pointclouds): |
188 |
| - return join_pointclouds_as_batch(batch) |
189 |
| - |
190 |
| - elif isinstance(elem, CamerasBase): |
191 |
| - # TODO: don't store K; enforce working in NDC space |
192 |
| - return join_cameras_as_batch(batch) |
193 |
| - else: |
194 |
| - return torch.utils.data._utils.collate.default_collate(batch) |
195 |
| - |
196 |
| - |
197 |
| -class _GenericWorkaround: |
198 |
| - """ |
199 |
| - OmegaConf.structured has a weirdness when you try to apply |
200 |
| - it to a dataclass whose first base class is a Generic which is not |
201 |
| - Dict. The issue is with a function called get_dict_key_value_types |
202 |
| - in omegaconf/_utils.py. |
203 |
| - For example this fails: |
204 |
| -
|
205 |
| - @dataclass(eq=False) |
206 |
| - class D(torch.utils.data.Dataset[int]): |
207 |
| - a: int = 3 |
208 |
| -
|
209 |
| - OmegaConf.structured(D) |
210 |
| -
|
211 |
| - We avoid the problem by adding this class as an extra base class. |
212 |
| - """ |
213 |
| - |
214 |
| - pass |
| 23 | +from pytorch3d.implicitron.dataset.frame_data import FrameData |
| 24 | +from pytorch3d.implicitron.dataset.utils import GenericWorkaround |
215 | 25 |
|
216 | 26 |
|
217 | 27 | @dataclass(eq=False)
|
218 |
| -class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]): |
| 28 | +class DatasetBase(GenericWorkaround, torch.utils.data.Dataset[FrameData]): |
219 | 29 | """
|
220 | 30 | Base class to describe a dataset to be used with Implicitron.
|
221 | 31 |
|
|
0 commit comments