
Commit 43cd681

Antoine Toisoul authored and facebook-github-bot committed
Updates to Implicitron dataset, metrics and tools
Summary: Update Pytorch3D to be able to run assetgen (see later diffs in the stack)

Reviewed By: shapovalov

Differential Revision: D65942513

fbshipit-source-id: 1d01141c9f7e106608fa591be6e0d3262cb5944f
1 parent 42a4a7d commit 43cd681

7 files changed, +240 −57 lines changed


pytorch3d/implicitron/dataset/frame_data.py

+6 −3
@@ -48,6 +48,7 @@
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
+from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes
 from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds

 FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
@@ -158,7 +159,7 @@ def to(self, *args, **kwargs):
         new_params = {}
         for field_name in iter(self):
             value = getattr(self, field_name)
-            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
+            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase, Meshes)):
                 new_params[field_name] = value.to(*args, **kwargs)
             else:
                 new_params[field_name] = value
@@ -420,7 +421,6 @@ def collate(cls, batch):
             for f in fields(elem):
                 if not f.init:
                     continue
-
                 list_values = override_fields.get(
                     f.name, [getattr(d, f.name) for d in batch]
                 )
@@ -429,14 +429,16 @@ def collate(cls, batch):
                     if all(list_value is not None for list_value in list_values)
                     else None
                 )
-            return cls(**collated)
+            return type(elem)(**collated)

         elif isinstance(elem, Pointclouds):
             return join_pointclouds_as_batch(batch)

         elif isinstance(elem, CamerasBase):
             # TODO: don't store K; enforce working in NDC space
             return join_cameras_as_batch(batch)
+        elif isinstance(elem, Meshes):
+            return join_meshes_as_batch(batch)
         else:
             return torch.utils.data.dataloader.default_collate(batch)

@@ -592,6 +594,7 @@ def build(
         fg_mask_np: np.ndarray | None = None
         bbox_xywh: tuple[float, float, float, float] | None = None
         mask_annotation = frame_annotation.mask
+
         if mask_annotation is not None:
             if load_blobs and self.load_masks:
                 fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
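The frame_data.py changes above teach FrameData to carry meshes: `to()` now moves Meshes fields across devices, and `collate()` joins a list of Meshes with `join_meshes_as_batch` (and returns `type(elem)(**collated)` so FrameData subclasses survive collation). A minimal sketch of the new collation path, using a toy two-triangle mesh; this example is illustrative and not part of the commit:

```python
import torch

from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.structures import Meshes

# Two tiny placeholder meshes. FrameData.collate dispatches on element type,
# so a list of Meshes is now joined into one batched Meshes rather than being
# handed to torch's default collate.
verts = torch.rand(4, 3)
faces = torch.tensor([[0, 1, 2], [1, 2, 3]])
meshes = [Meshes(verts=[verts], faces=[faces]) for _ in range(2)]

batched = FrameData.collate(meshes)
print(len(batched))  # 2 meshes in one batched Meshes object
```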

pytorch3d/implicitron/dataset/sql_dataset.py

+88 −6
@@ -10,6 +10,7 @@
 import json
 import logging
 import os
+
 import urllib
 from dataclasses import dataclass, Field, field
 from typing import (
@@ -37,12 +38,13 @@
     FrameDataBuilder, # noqa
     FrameDataBuilderBase,
 )
+
 from pytorch3d.implicitron.tools.config import (
     registry,
     ReplaceableBase,
     run_auto_creation,
 )
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import scoped_session, Session, sessionmaker

 from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation

@@ -91,6 +93,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             engine verbatim. Don’t expose it to end users of your application!
         pick_categories: Restrict the dataset to the given list of categories.
         pick_sequences: A Sequence of sequence names to restrict the dataset to.
+        pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
         exclude_sequences: A Sequence of the names of the sequences to exclude.
         limit_sequences_per_category_to: Limit the dataset to the first up to N
             sequences within each category (applies after all other sequence filters
@@ -105,6 +108,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             more frames than that; applied after other frame-level filters.
         seed: The seed of the random generator sampling `n_frames_per_sequence`
             random frames per sequence.
+        preload_metadata: If True, the metadata is preloaded into memory.
+        precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
+        scoped_session: If True, allows different parts of the code to share
+            a global session to access the database.
     """

     frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
@@ -123,13 +130,16 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
     pick_categories: Tuple[str, ...] = ()

     pick_sequences: Tuple[str, ...] = ()
+    pick_sequences_sql_clause: Optional[str] = None
     exclude_sequences: Tuple[str, ...] = ()
     limit_sequences_per_category_to: int = 0
     limit_sequences_to: int = 0
     limit_to: int = 0
     n_frames_per_sequence: int = -1
     seed: int = 0
     remove_empty_masks_poll_whole_table_threshold: int = 300_000
+    preload_metadata: bool = False
+    precompute_seq_to_idx: bool = False
     # we set it manually in the constructor
     _index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
     _sql_engine: sa.engine.Engine = field(
@@ -142,6 +152,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
     frame_data_builder: FrameDataBuilderBase # pyre-ignore[13]
     frame_data_builder_class_type: str = "FrameDataBuilder"

+    scoped_session: bool = False
+
     def __post_init__(self) -> None:
         if sa.__version__ < "2.0":
             raise ImportError("This class requires SQL Alchemy 2.0 or later")
@@ -169,6 +181,9 @@ def __post_init__(self) -> None:
             f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
         )

+        if self.preload_metadata:
+            self._sql_engine = self._preload_database(self._sql_engine)
+
         sequences = self._get_filtered_sequences_if_any()

         if self.subsets:
@@ -192,6 +207,20 @@ def __post_init__(self) -> None:

         logger.info(str(self))

+        if self.scoped_session:
+            self._session_factory = sessionmaker(bind=self._sql_engine) # pyre-ignore
+
+        if self.precompute_seq_to_idx:
+            # This is deprecated and will be removed in the future.
+            # After we backport https://github.com/facebookresearch/uco3d/pull/3
+            logger.warning(
+                "Using precompute_seq_to_idx is deprecated and will be removed in the future."
+            )
+            self._index["rowid"] = np.arange(len(self._index))
+            groupby = self._index.groupby("sequence_name", sort=False)["rowid"]
+            self._seq_to_indices = dict(groupby.apply(list)) # pyre-ignore
+            del self._index["rowid"]
+
     def __len__(self) -> int:
         return len(self._index)

@@ -252,9 +281,15 @@ def _get_item(
         seq_stmt = sa.select(self.sequence_annotations_type).where(
             self.sequence_annotations_type.sequence_name == seq
         )
-        with Session(self._sql_engine) as session:
-            entry = session.scalars(stmt).one()
-            seq_metadata = session.scalars(seq_stmt).one()
+        if self.scoped_session:
+            # pyre-ignore
+            with scoped_session(self._session_factory)() as session:
+                entry = session.scalars(stmt).one()
+                seq_metadata = session.scalars(seq_stmt).one()
+        else:
+            with Session(self._sql_engine) as session:
+                entry = session.scalars(stmt).one()
+                seq_metadata = session.scalars(seq_stmt).one()

         assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]

@@ -363,6 +398,20 @@ def sequence_frames_in_order(

         yield from index_slice.itertuples(index=False)

+    # override
+    def sequence_indices_in_order(
+        self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
+    ) -> Iterator[int]:
+        """Same as `sequence_frames_in_order` but returns the iterator over
+        only dataset indices.
+        """
+        if self.precompute_seq_to_idx and subset_filter is None:
+            # pyre-ignore
+            yield from self._seq_to_indices[seq_name]
+        else:
+            for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter):
+                yield idx
+
     # override
     def get_eval_batches(self) -> Optional[List[Any]]:
         """
@@ -396,11 +445,35 @@ def is_filtered(self) -> bool:
             or self.limit_sequences_to > 0
             or self.limit_sequences_per_category_to > 0
             or len(self.pick_sequences) > 0
+            or self.pick_sequences_sql_clause is not None
             or len(self.exclude_sequences) > 0
             or len(self.pick_categories) > 0
             or self.n_frames_per_sequence > 0
         )

+    def _preload_database(
+        self, source_engine: sa.engine.base.Engine
+    ) -> sa.engine.base.Engine:
+        destination_engine = sa.create_engine("sqlite:///:memory:")
+        metadata = sa.MetaData()
+        metadata.reflect(bind=source_engine)
+        metadata.create_all(bind=destination_engine)
+
+        with source_engine.connect() as source_conn:
+            with destination_engine.connect() as destination_conn:
+                for table_obj in metadata.tables.values():
+                    # Select all rows from the source table
+                    source_rows = source_conn.execute(table_obj.select())
+
+                    # Insert rows into the destination table
+                    for row in source_rows:
+                        destination_conn.execute(table_obj.insert().values(row))
+
+                    # Commit the changes for each table
+                    destination_conn.commit()
+
+        return destination_engine
+
     def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
         # maximum possible filter (if limit_sequences_per_category_to == 0):
         # WHERE category IN 'self.pick_categories'
@@ -413,6 +486,9 @@ def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
             *self._get_pick_filters(),
             *self._get_exclude_filters(),
         ]
+        if self.pick_sequences_sql_clause:
+            print("Applying the custom SQL clause.")
+            where_conditions.append(sa.text(self.pick_sequences_sql_clause))

         def add_where(stmt):
             return stmt.where(*where_conditions) if where_conditions else stmt
@@ -749,9 +825,15 @@ def _get_frame_no_coalesced_ts_by_row_indices(
             self.frame_annotations_type.sequence_name == seq_name,
             self.frame_annotations_type.frame_number.in_(frames),
         )
+        frame_no_ts = None

-        with self._sql_engine.connect() as connection:
-            frame_no_ts = pd.read_sql_query(stmt, connection)
+        if self.scoped_session:
+            stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True}))
+            with scoped_session(self._session_factory)() as session: # pyre-ignore
+                frame_no_ts = pd.read_sql_query(stmt_text, session.connection())
+        else:
+            with self._sql_engine.connect() as connection:
+                frame_no_ts = pd.read_sql_query(stmt, connection)

         if len(frame_no_ts) != len(index_slice):
             raise ValueError(
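The overridden `sequence_indices_in_order` above lets callers iterate dataset indices for a single sequence without materializing frame tuples; with `precompute_seq_to_idx` enabled and no subset filter it reads the cached mapping, otherwise it falls back to `sequence_frames_in_order`. A small usage sketch, assuming `dataset` was built as in the earlier example and `"sequence_0001"` is a placeholder sequence name:

```python
# Iterate all dataset indices belonging to one sequence, in frame order.
for idx in dataset.sequence_indices_in_order("sequence_0001"):
    frame = dataset[idx]
    print(idx, frame.frame_number)
```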

pytorch3d/implicitron/dataset/sql_dataset_provider.py

+8 −2
@@ -284,8 +284,14 @@ def __post_init__(self):
             logger.info(f"Val dataset: {str(val_dataset)}")

         logger.debug("Extracting test dataset.")
-        eval_batches_file = self._get_lists_file("eval_batches")
-        del common_dataset_kwargs["eval_batches_file"]
+        if self.eval_batches_path is None:
+            eval_batches_file = None
+        else:
+            eval_batches_file = self._get_lists_file("eval_batches")
+
+        if "eval_batches_file" in common_dataset_kwargs:
+            common_dataset_kwargs.pop("eval_batches_file", None)
+
         test_dataset = dataset_type(
             **common_dataset_kwargs,
             subsets=self._get_subsets(self.test_subsets, True),

pytorch3d/implicitron/dataset/utils.py

+14 −5
@@ -211,21 +211,29 @@ def resize_image(
     if isinstance(image, np.ndarray):
         image = torch.from_numpy(image)

-    if image_height is None or image_width is None:
+    if (
+        image_height is None
+        or image_width is None
+        or image.shape[-2] == 0
+        or image.shape[-1] == 0
+    ):
         # skip the resizing
         return image, 1.0, torch.ones_like(image[:1])
+
     # takes numpy array or tensor, returns pytorch tensor
     minscale = min(
         image_height / image.shape[-2],
         image_width / image.shape[-1],
     )
+
     imre = torch.nn.functional.interpolate(
         image[None],
         scale_factor=minscale,
         mode=mode,
         align_corners=False if mode == "bilinear" else None,
         recompute_scale_factor=True,
     )[0]
+
     imre_ = torch.zeros(image.shape[0], image_height, image_width)
     imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
     mask = torch.zeros(1, image_height, image_width)
@@ -238,20 +246,21 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
     return im.astype(np.float32) / 255.0


-def load_image(path: str, try_read_alpha: bool = False) -> np.ndarray:
+def load_image(
+    path: str, try_read_alpha: bool = False, pil_format: str = "RGB"
+) -> np.ndarray:
     """
     Load an image from a path and return it as a numpy array.
     If try_read_alpha is True, the image is read as RGBA and the alpha channel is
     returned as the fourth channel.
     Otherwise, the image is read as RGB and a three-channel image is returned.
     """
-
     with Image.open(path) as pil_im:
         # Check if the image has an alpha channel
         if try_read_alpha and pil_im.mode == "RGBA":
             im = np.array(pil_im)
         else:
-            im = np.array(pil_im.convert("RGB"))
+            im = np.array(pil_im.convert(pil_format))

     return transpose_normalize_image(im)

@@ -389,7 +398,7 @@ def adjust_camera_to_image_scale_(
     )
     camera.focal_length = focal_length_scaled[None]
     # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
-    camera.principal_point = principal_point_scaled[None]
+    camera.principal_point = principal_point_scaled[None] # pyre-ignore[16]


 # NOTE this cache is per-worker; they are implemented as processes.
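The utils.py changes make `resize_image` a no-op for degenerate (zero-sized) inputs and let `load_image` forward an arbitrary PIL conversion mode. A brief usage sketch, assuming `"img.png"` is a placeholder path; for example `pil_format="L"` requests a grayscale conversion, with the result still normalized to [0, 1] by `transpose_normalize_image`:

```python
from pytorch3d.implicitron.dataset.utils import load_image

rgb = load_image("img.png")                         # default: RGB
gray = load_image("img.png", pil_format="L")        # grayscale conversion
rgba = load_image("img.png", try_read_alpha=True)   # keep alpha if the file has it
```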
