Skip to content

Commit ebdbfde

Browse files
Ildar Salakhievfacebook-github-bot
Ildar Salakhiev
authored andcommitted
Extract BlobLoader class from JsonIndexDataset and moving crop_by_bbox to FrameData
Summary: extracted blob loader added documentation for blob_loader did some refactoring on fields for detailed steps and discussions see: #1463 fairinternal/pixar_replay#160 Reviewed By: bottler Differential Revision: D44061728 fbshipit-source-id: eefb21e9679003045d73729f96e6a93a1d4d2d51
1 parent c759fc5 commit ebdbfde

15 files changed

+1421
-694
lines changed

pytorch3d/implicitron/dataset/data_loader_map_provider.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
Sampler,
1919
)
2020

21-
from .dataset_base import DatasetBase, FrameData
21+
from .dataset_base import DatasetBase
2222
from .dataset_map_provider import DatasetMap
23+
from .frame_data import FrameData
2324
from .scene_batch_sampler import SceneBatchSampler
2425
from .utils import is_known_frame_scalar
2526

pytorch3d/implicitron/dataset/dataset_base.py

+4-194
Original file line numberDiff line numberDiff line change
@@ -5,217 +5,27 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from collections import defaultdict
8-
from dataclasses import dataclass, field, fields
8+
from dataclasses import dataclass
99
from typing import (
10-
Any,
1110
ClassVar,
1211
Dict,
1312
Iterable,
1413
Iterator,
1514
List,
16-
Mapping,
1715
Optional,
1816
Sequence,
1917
Tuple,
2018
Type,
21-
Union,
2219
)
2320

24-
import numpy as np
2521
import torch
26-
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
27-
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
28-
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
2922

30-
31-
@dataclass
32-
class FrameData(Mapping[str, Any]):
33-
"""
34-
A type of the elements returned by indexing the dataset object.
35-
It can represent both individual frames and batches of thereof;
36-
in this documentation, the sizes of tensors refer to single frames;
37-
add the first batch dimension for the collation result.
38-
39-
Args:
40-
frame_number: The number of the frame within its sequence.
41-
0-based continuous integers.
42-
sequence_name: The unique name of the frame's sequence.
43-
sequence_category: The object category of the sequence.
44-
frame_timestamp: The time elapsed since the start of a sequence in sec.
45-
image_size_hw: The size of the image in pixels; (height, width) tensor
46-
of shape (2,).
47-
image_path: The qualified path to the loaded image (with dataset_root).
48-
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
49-
of the frame; elements are floats in [0, 1].
50-
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
51-
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
52-
are a result of zero-padding of the image after cropping around
53-
the object bounding box; elements are floats in {0.0, 1.0}.
54-
depth_path: The qualified path to the frame's depth map.
55-
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
56-
of the frame; values correspond to distances from the camera;
57-
use `depth_mask` and `mask_crop` to filter for valid pixels.
58-
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
59-
depth map that are valid for evaluation, they have been checked for
60-
consistency across views; elements are floats in {0.0, 1.0}.
61-
mask_path: A qualified path to the foreground probability mask.
62-
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
63-
pixels belonging to the captured object; elements are floats
64-
in [0, 1].
65-
bbox_xywh: The bounding box tightly enclosing the foreground object in the
66-
format (x0, y0, width, height). The convention assumes that
67-
`x0+width` and `y0+height` includes the boundary of the box.
68-
I.e., to slice out the corresponding crop from an image tensor `I`
69-
we execute `crop = I[..., y0:y0+height, x0:x0+width]`
70-
crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
71-
in the original image coordinates in the format (x0, y0, width, height).
72-
The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
73-
from `bbox_xywh` due to padding (which can happen e.g. due to
74-
setting `JsonIndexDataset.box_crop_context > 0`)
75-
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
76-
corrected for cropping if it happened.
77-
camera_quality_score: The score proportional to the confidence of the
78-
frame's camera estimation (the higher the more accurate).
79-
point_cloud_quality_score: The score proportional to the accuracy of the
80-
frame's sequence point cloud (the higher the more accurate).
81-
sequence_point_cloud_path: The path to the sequence's point cloud.
82-
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
83-
point cloud corresponding to the frame's sequence. When the object
84-
represents a batch of frames, point clouds may be deduplicated;
85-
see `sequence_point_cloud_idx`.
86-
sequence_point_cloud_idx: Integer indices mapping frame indices to the
87-
corresponding point clouds in `sequence_point_cloud`; to get the
88-
corresponding point cloud to `image_rgb[i]`, use
89-
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
90-
frame_type: The type of the loaded frame specified in
91-
`subset_lists_file`, if provided.
92-
meta: A dict for storing additional frame information.
93-
"""
94-
95-
frame_number: Optional[torch.LongTensor]
96-
sequence_name: Union[str, List[str]]
97-
sequence_category: Union[str, List[str]]
98-
frame_timestamp: Optional[torch.Tensor] = None
99-
image_size_hw: Optional[torch.Tensor] = None
100-
image_path: Union[str, List[str], None] = None
101-
image_rgb: Optional[torch.Tensor] = None
102-
# masks out padding added due to cropping the square bit
103-
mask_crop: Optional[torch.Tensor] = None
104-
depth_path: Union[str, List[str], None] = None
105-
depth_map: Optional[torch.Tensor] = None
106-
depth_mask: Optional[torch.Tensor] = None
107-
mask_path: Union[str, List[str], None] = None
108-
fg_probability: Optional[torch.Tensor] = None
109-
bbox_xywh: Optional[torch.Tensor] = None
110-
crop_bbox_xywh: Optional[torch.Tensor] = None
111-
camera: Optional[PerspectiveCameras] = None
112-
camera_quality_score: Optional[torch.Tensor] = None
113-
point_cloud_quality_score: Optional[torch.Tensor] = None
114-
sequence_point_cloud_path: Union[str, List[str], None] = None
115-
sequence_point_cloud: Optional[Pointclouds] = None
116-
sequence_point_cloud_idx: Optional[torch.Tensor] = None
117-
frame_type: Union[str, List[str], None] = None # known | unseen
118-
meta: dict = field(default_factory=lambda: {})
119-
120-
def to(self, *args, **kwargs):
121-
new_params = {}
122-
for f in fields(self):
123-
value = getattr(self, f.name)
124-
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
125-
new_params[f.name] = value.to(*args, **kwargs)
126-
else:
127-
new_params[f.name] = value
128-
return type(self)(**new_params)
129-
130-
def cpu(self):
131-
return self.to(device=torch.device("cpu"))
132-
133-
def cuda(self):
134-
return self.to(device=torch.device("cuda"))
135-
136-
# the following functions make sure **frame_data can be passed to functions
137-
def __iter__(self):
138-
for f in fields(self):
139-
yield f.name
140-
141-
def __getitem__(self, key):
142-
return getattr(self, key)
143-
144-
def __len__(self):
145-
return len(fields(self))
146-
147-
@classmethod
148-
def collate(cls, batch):
149-
"""
150-
Given a list objects `batch` of class `cls`, collates them into a batched
151-
representation suitable for processing with deep networks.
152-
"""
153-
154-
elem = batch[0]
155-
156-
if isinstance(elem, cls):
157-
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
158-
id_to_idx = defaultdict(list)
159-
for i, pc_id in enumerate(pointcloud_ids):
160-
id_to_idx[pc_id].append(i)
161-
162-
sequence_point_cloud = []
163-
sequence_point_cloud_idx = -np.ones((len(batch),))
164-
for i, ind in enumerate(id_to_idx.values()):
165-
sequence_point_cloud_idx[ind] = i
166-
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
167-
assert (sequence_point_cloud_idx >= 0).all()
168-
169-
override_fields = {
170-
"sequence_point_cloud": sequence_point_cloud,
171-
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
172-
}
173-
# note that the pre-collate value of sequence_point_cloud_idx is unused
174-
175-
collated = {}
176-
for f in fields(elem):
177-
list_values = override_fields.get(
178-
f.name, [getattr(d, f.name) for d in batch]
179-
)
180-
collated[f.name] = (
181-
cls.collate(list_values)
182-
if all(list_value is not None for list_value in list_values)
183-
else None
184-
)
185-
return cls(**collated)
186-
187-
elif isinstance(elem, Pointclouds):
188-
return join_pointclouds_as_batch(batch)
189-
190-
elif isinstance(elem, CamerasBase):
191-
# TODO: don't store K; enforce working in NDC space
192-
return join_cameras_as_batch(batch)
193-
else:
194-
return torch.utils.data._utils.collate.default_collate(batch)
195-
196-
197-
class _GenericWorkaround:
198-
"""
199-
OmegaConf.structured has a weirdness when you try to apply
200-
it to a dataclass whose first base class is a Generic which is not
201-
Dict. The issue is with a function called get_dict_key_value_types
202-
in omegaconf/_utils.py.
203-
For example this fails:
204-
205-
@dataclass(eq=False)
206-
class D(torch.utils.data.Dataset[int]):
207-
a: int = 3
208-
209-
OmegaConf.structured(D)
210-
211-
We avoid the problem by adding this class as an extra base class.
212-
"""
213-
214-
pass
23+
from pytorch3d.implicitron.dataset.frame_data import FrameData
24+
from pytorch3d.implicitron.dataset.utils import GenericWorkaround
21525

21626

21727
@dataclass(eq=False)
218-
class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
28+
class DatasetBase(GenericWorkaround, torch.utils.data.Dataset[FrameData]):
21929
"""
22030
Base class to describe a dataset to be used with Implicitron.
22131

0 commit comments

Comments
 (0)