Skip to content

Commit 64a5bfa

Browse files
shapovalov authored and facebook-github-bot committed
Adding SQL Dataset related files to the build script
Summary: Now that we have SQLAlchemy 2.0, we can fully use them. Reviewed By: bottler Differential Revision: D66920096 fbshipit-source-id: 25c0ea1c4f7361e66348035519627dc961b9e6e6
1 parent 055ab3a commit 64a5bfa

File tree

2 files changed

+57
-35
lines changed

2 files changed

+57
-35
lines changed

pytorch3d/implicitron/dataset/sql_dataset.py

+54-32
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import json
99
import logging
1010
import os
11-
from dataclasses import dataclass
11+
import urllib
12+
from dataclasses import dataclass, Field, field
1213
from typing import (
1314
Any,
1415
ClassVar,
@@ -29,9 +30,9 @@
2930
import torch
3031
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
3132

32-
from pytorch3d.implicitron.dataset.frame_data import ( # noqa
33+
from pytorch3d.implicitron.dataset.frame_data import (
3334
FrameData,
34-
FrameDataBuilder,
35+
FrameDataBuilder, # noqa
3536
FrameDataBuilderBase,
3637
)
3738
from pytorch3d.implicitron.tools.config import (
@@ -51,7 +52,7 @@
5152

5253

5354
@registry.register
54-
class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
55+
class SqlIndexDataset(DatasetBase, ReplaceableBase):
5556
"""
5657
A dataset with annotations stored as SQLite tables. This is an index-based dataset.
5758
The length is returned after all sequence and frame filters are applied (see param
@@ -125,9 +126,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
125126
seed: int = 0
126127
remove_empty_masks_poll_whole_table_threshold: int = 300_000
127128
# we set it manually in the constructor
128-
# _index: pd.DataFrame = field(init=False)
129-
130-
frame_data_builder: FrameDataBuilderBase
129+
_index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
130+
_sql_engine: sa.engine.Engine = field(
131+
init=False, metadata={"omegaconf_ignore": True}
132+
)
133+
eval_batches: Optional[List[Any]] = field(
134+
init=False, metadata={"omegaconf_ignore": True}
135+
)
136+
137+
frame_data_builder: FrameDataBuilderBase # pyre-ignore[13]
131138
frame_data_builder_class_type: str = "FrameDataBuilder"
132139

133140
def __post_init__(self) -> None:
@@ -138,17 +145,23 @@ def __post_init__(self) -> None:
138145
raise ValueError("sqlite_metadata_file must be set")
139146

140147
if self.dataset_root:
141-
frame_builder_type = self.frame_data_builder_class_type
142-
getattr(self, f"frame_data_builder_{frame_builder_type}_args")[
143-
"dataset_root"
144-
] = self.dataset_root
148+
frame_args = f"frame_data_builder_{self.frame_data_builder_class_type}_args"
149+
getattr(self, frame_args)["dataset_root"] = self.dataset_root
150+
getattr(self, frame_args)["path_manager"] = self.path_manager
145151

146152
run_auto_creation(self)
147-
self.frame_data_builder.path_manager = self.path_manager
148153

149-
# pyre-ignore # NOTE: sqlite-specific args (read-only mode).
154+
if self.path_manager is not None:
155+
self.sqlite_metadata_file = self.path_manager.get_local_path(
156+
self.sqlite_metadata_file
157+
)
158+
self.subset_lists_file = self.path_manager.get_local_path(
159+
self.subset_lists_file
160+
)
161+
162+
# NOTE: sqlite-specific args (read-only mode).
150163
self._sql_engine = sa.create_engine(
151-
f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
164+
f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
152165
)
153166

154167
sequences = self._get_filtered_sequences_if_any()
@@ -166,16 +179,15 @@ def __post_init__(self) -> None:
166179
if len(index) == 0:
167180
raise ValueError(f"There are no frames in the subsets: {self.subsets}!")
168181

169-
self._index = index.set_index(["sequence_name", "frame_number"]) # pyre-ignore
182+
self._index = index.set_index(["sequence_name", "frame_number"])
170183

171-
self.eval_batches = None # pyre-ignore
184+
self.eval_batches = None
172185
if self.eval_batches_file:
173186
self.eval_batches = self._load_filter_eval_batches()
174187

175188
logger.info(str(self))
176189

177190
def __len__(self) -> int:
178-
# pyre-ignore[16]
179191
return len(self._index)
180192

181193
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
@@ -250,7 +262,6 @@ def _get_item(
250262
return frame_data
251263

252264
def __str__(self) -> str:
253-
# pyre-ignore[16]
254265
return f"SqlIndexDataset #frames={len(self._index)}"
255266

256267
def sequence_names(self) -> Iterable[str]:
@@ -335,12 +346,12 @@ def sequence_frames_in_order(
335346
rows = self._index.index.get_loc(seq_name)
336347
if isinstance(rows, slice):
337348
assert rows.stop is not None, "Unexpected result from pandas"
338-
rows = range(rows.start or 0, rows.stop, rows.step or 1)
349+
rows_seq = range(rows.start or 0, rows.stop, rows.step or 1)
339350
else:
340-
rows = np.where(rows)[0]
351+
rows_seq = list(np.where(rows)[0])
341352

342353
index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
343-
rows, seq_name, subset_filter
354+
rows_seq, seq_name, subset_filter
344355
)
345356
index_slice["idx"] = idx
346357

@@ -461,14 +472,15 @@ def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
461472
return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
462473

463474
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
464-
assert self.subsets is not None
475+
subsets = self.subsets
476+
assert subsets is not None
465477
with open(subset_lists_path, "r") as f:
466478
subset_to_seq_frame = json.load(f)
467479

468480
seq_frame_list = sum(
469481
(
470482
[(*row, subset) for row in subset_to_seq_frame[subset]]
471-
for subset in self.subsets
483+
for subset in subsets
472484
),
473485
[],
474486
)
@@ -522,7 +534,7 @@ def _build_index_from_subset_lists(
522534
stmt = sa.select(
523535
self.frame_annotations_type.sequence_name,
524536
self.frame_annotations_type.frame_number,
525-
).where(self.frame_annotations_type._mask_mass == 0)
537+
).where(self.frame_annotations_type._mask_mass == 0) # pyre-ignore[16]
526538
with Session(self._sql_engine) as session:
527539
to_remove = session.execute(stmt).all()
528540

@@ -586,7 +598,7 @@ def _build_index_from_db(self, sequences: Optional[pd.Series]):
586598
stmt = sa.select(
587599
self.frame_annotations_type.sequence_name,
588600
self.frame_annotations_type.frame_number,
589-
self.frame_annotations_type._image_path,
601+
self.frame_annotations_type._image_path, # pyre-ignore[16]
590602
sa.null().label("subset"),
591603
)
592604
where_conditions = []
@@ -600,7 +612,7 @@ def _build_index_from_db(self, sequences: Optional[pd.Series]):
600612
logger.info(" excluding samples with empty masks")
601613
where_conditions.append(
602614
sa.or_(
603-
self.frame_annotations_type._mask_mass.is_(None),
615+
self.frame_annotations_type._mask_mass.is_(None), # pyre-ignore[16]
604616
self.frame_annotations_type._mask_mass != 0,
605617
)
606618
)
@@ -634,15 +646,18 @@ def _load_filter_eval_batches(self):
634646
assert self.eval_batches_file
635647
logger.info(f"Loading eval batches from {self.eval_batches_file}")
636648

637-
if not os.path.isfile(self.eval_batches_file):
649+
if (
650+
self.path_manager and not self.path_manager.isfile(self.eval_batches_file)
651+
) or (not self.path_manager and not os.path.isfile(self.eval_batches_file)):
638652
# The batch indices file does not exist.
639653
# Most probably the user has not specified the root folder.
640654
raise ValueError(
641655
f"Looking for dataset json file in {self.eval_batches_file}. "
642656
+ "Please specify a correct dataset_root folder."
643657
)
644658

645-
with open(self.eval_batches_file, "r") as f:
659+
eval_batches_file = self._local_path(self.eval_batches_file)
660+
with open(eval_batches_file, "r") as f:
646661
eval_batches = json.load(f)
647662

648663
# limit the dataset to sequences to allow multiple evaluations in one file
@@ -758,11 +773,18 @@ def _get_temp_index_table_instance(self, table_name: str = "__index"):
758773
prefixes=["TEMP"], # NOTE SQLite specific!
759774
)
760775

776+
@classmethod
777+
def pre_expand(cls) -> None:
778+
# remove dataclass annotations that are not meant to be init params
779+
# because they cause troubles for OmegaConf
780+
for attr, attr_value in list(cls.__dict__.items()): # need to copy as we mutate
781+
if isinstance(attr_value, Field) and attr_value.metadata.get(
782+
"omegaconf_ignore", False
783+
):
784+
delattr(cls, attr)
785+
del cls.__annotations__[attr]
786+
761787

762788
def _seq_name_to_seed(seq_name) -> int:
763789
"""Generates numbers in [0, 2 ** 28)"""
764790
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
765-
766-
767-
def _safe_as_tensor(data, dtype):
768-
return torch.tensor(data, dtype=dtype) if data is not None else None

pytorch3d/implicitron/dataset/sql_dataset_provider.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444

4545
@registry.register
46-
class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
46+
class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
4747
"""
4848
Generates the training, validation, and testing dataset objects for
4949
a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base.
@@ -193,9 +193,9 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
193193

194194
# this is a mould that is never constructed, used to build self._dataset_map values
195195
dataset_class_type: str = "SqlIndexDataset"
196-
dataset: SqlIndexDataset
196+
dataset: SqlIndexDataset # pyre-ignore [13]
197197

198-
path_manager_factory: PathManagerFactory
198+
path_manager_factory: PathManagerFactory # pyre-ignore [13]
199199
path_manager_factory_class_type: str = "PathManagerFactory"
200200

201201
def __post_init__(self):

0 commit comments

Comments (0)