8
8
import json
9
9
import logging
10
10
import os
11
- from dataclasses import dataclass
11
+ import urllib
12
+ from dataclasses import dataclass , Field , field
12
13
from typing import (
13
14
Any ,
14
15
ClassVar ,
29
30
import torch
30
31
from pytorch3d .implicitron .dataset .dataset_base import DatasetBase
31
32
32
- from pytorch3d .implicitron .dataset .frame_data import ( # noqa
33
+ from pytorch3d .implicitron .dataset .frame_data import (
33
34
FrameData ,
34
- FrameDataBuilder ,
35
+ FrameDataBuilder , # noqa
35
36
FrameDataBuilderBase ,
36
37
)
37
38
from pytorch3d .implicitron .tools .config import (
51
52
52
53
53
54
@registry .register
54
- class SqlIndexDataset (DatasetBase , ReplaceableBase ): # pyre-ignore
55
+ class SqlIndexDataset (DatasetBase , ReplaceableBase ):
55
56
"""
56
57
A dataset with annotations stored as SQLite tables. This is an index-based dataset.
57
58
The length is returned after all sequence and frame filters are applied (see param
@@ -125,9 +126,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
125
126
seed : int = 0
126
127
remove_empty_masks_poll_whole_table_threshold : int = 300_000
127
128
# we set it manually in the constructor
128
- # _index: pd.DataFrame = field(init=False)
129
-
130
- frame_data_builder : FrameDataBuilderBase
129
+ _index : pd .DataFrame = field (init = False , metadata = {"omegaconf_ignore" : True })
130
+ _sql_engine : sa .engine .Engine = field (
131
+ init = False , metadata = {"omegaconf_ignore" : True }
132
+ )
133
+ eval_batches : Optional [List [Any ]] = field (
134
+ init = False , metadata = {"omegaconf_ignore" : True }
135
+ )
136
+
137
+ frame_data_builder : FrameDataBuilderBase # pyre-ignore[13]
131
138
frame_data_builder_class_type : str = "FrameDataBuilder"
132
139
133
140
def __post_init__ (self ) -> None :
@@ -138,17 +145,23 @@ def __post_init__(self) -> None:
138
145
raise ValueError ("sqlite_metadata_file must be set" )
139
146
140
147
if self .dataset_root :
141
- frame_builder_type = self .frame_data_builder_class_type
142
- getattr (self , f"frame_data_builder_{ frame_builder_type } _args" )[
143
- "dataset_root"
144
- ] = self .dataset_root
148
+ frame_args = f"frame_data_builder_{ self .frame_data_builder_class_type } _args"
149
+ getattr (self , frame_args )["dataset_root" ] = self .dataset_root
150
+ getattr (self , frame_args )["path_manager" ] = self .path_manager
145
151
146
152
run_auto_creation (self )
147
- self .frame_data_builder .path_manager = self .path_manager
148
153
149
- # pyre-ignore # NOTE: sqlite-specific args (read-only mode).
154
+ if self .path_manager is not None :
155
+ self .sqlite_metadata_file = self .path_manager .get_local_path (
156
+ self .sqlite_metadata_file
157
+ )
158
+ self .subset_lists_file = self .path_manager .get_local_path (
159
+ self .subset_lists_file
160
+ )
161
+
162
+ # NOTE: sqlite-specific args (read-only mode).
150
163
self ._sql_engine = sa .create_engine (
151
- f"sqlite:///file:{ self .sqlite_metadata_file } ?mode=ro&uri=true"
164
+ f"sqlite:///file:{ urllib . parse . quote ( self .sqlite_metadata_file ) } ?mode=ro&uri=true"
152
165
)
153
166
154
167
sequences = self ._get_filtered_sequences_if_any ()
@@ -166,16 +179,15 @@ def __post_init__(self) -> None:
166
179
if len (index ) == 0 :
167
180
raise ValueError (f"There are no frames in the subsets: { self .subsets } !" )
168
181
169
- self ._index = index .set_index (["sequence_name" , "frame_number" ]) # pyre-ignore
182
+ self ._index = index .set_index (["sequence_name" , "frame_number" ])
170
183
171
- self .eval_batches = None # pyre-ignore
184
+ self .eval_batches = None
172
185
if self .eval_batches_file :
173
186
self .eval_batches = self ._load_filter_eval_batches ()
174
187
175
188
logger .info (str (self ))
176
189
177
190
def __len__(self) -> int:
    """Number of frames remaining after all sequence/frame filters."""
    frame_count = len(self._index)
    return frame_count
180
192
181
193
def __getitem__ (self , frame_idx : Union [int , Tuple [str , int ]]) -> FrameData :
@@ -250,7 +262,6 @@ def _get_item(
250
262
return frame_data
251
263
252
264
def __str__(self) -> str:
    """Human-readable summary reporting the current frame count."""
    # self._index is the (sequence_name, frame_number)-indexed DataFrame
    # built in __post_init__; its length is the dataset length.
    return f"SqlIndexDataset #frames={len(self._index)}"
255
266
256
267
def sequence_names (self ) -> Iterable [str ]:
@@ -335,12 +346,12 @@ def sequence_frames_in_order(
335
346
rows = self ._index .index .get_loc (seq_name )
336
347
if isinstance (rows , slice ):
337
348
assert rows .stop is not None , "Unexpected result from pandas"
338
- rows = range (rows .start or 0 , rows .stop , rows .step or 1 )
349
+ rows_seq = range (rows .start or 0 , rows .stop , rows .step or 1 )
339
350
else :
340
- rows = np .where (rows )[0 ]
351
+ rows_seq = list ( np .where (rows )[0 ])
341
352
342
353
index_slice , idx = self ._get_frame_no_coalesced_ts_by_row_indices (
343
- rows , seq_name , subset_filter
354
+ rows_seq , seq_name , subset_filter
344
355
)
345
356
index_slice ["idx" ] = idx
346
357
@@ -461,14 +472,15 @@ def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
461
472
return [SqlSequenceAnnotation .sequence_name .notin_ (self .exclude_sequences )]
462
473
463
474
def _load_subsets_from_json (self , subset_lists_path : str ) -> pd .DataFrame :
464
- assert self .subsets is not None
475
+ subsets = self .subsets
476
+ assert subsets is not None
465
477
with open (subset_lists_path , "r" ) as f :
466
478
subset_to_seq_frame = json .load (f )
467
479
468
480
seq_frame_list = sum (
469
481
(
470
482
[(* row , subset ) for row in subset_to_seq_frame [subset ]]
471
- for subset in self . subsets
483
+ for subset in subsets
472
484
),
473
485
[],
474
486
)
@@ -522,7 +534,7 @@ def _build_index_from_subset_lists(
522
534
stmt = sa .select (
523
535
self .frame_annotations_type .sequence_name ,
524
536
self .frame_annotations_type .frame_number ,
525
- ).where (self .frame_annotations_type ._mask_mass == 0 )
537
+ ).where (self .frame_annotations_type ._mask_mass == 0 ) # pyre-ignore[16]
526
538
with Session (self ._sql_engine ) as session :
527
539
to_remove = session .execute (stmt ).all ()
528
540
@@ -586,7 +598,7 @@ def _build_index_from_db(self, sequences: Optional[pd.Series]):
586
598
stmt = sa .select (
587
599
self .frame_annotations_type .sequence_name ,
588
600
self .frame_annotations_type .frame_number ,
589
- self .frame_annotations_type ._image_path ,
601
+ self .frame_annotations_type ._image_path , # pyre-ignore[16]
590
602
sa .null ().label ("subset" ),
591
603
)
592
604
where_conditions = []
@@ -600,7 +612,7 @@ def _build_index_from_db(self, sequences: Optional[pd.Series]):
600
612
logger .info (" excluding samples with empty masks" )
601
613
where_conditions .append (
602
614
sa .or_ (
603
- self .frame_annotations_type ._mask_mass .is_ (None ),
615
+ self .frame_annotations_type ._mask_mass .is_ (None ), # pyre-ignore[16]
604
616
self .frame_annotations_type ._mask_mass != 0 ,
605
617
)
606
618
)
@@ -634,15 +646,18 @@ def _load_filter_eval_batches(self):
634
646
assert self .eval_batches_file
635
647
logger .info (f"Loading eval batches from { self .eval_batches_file } " )
636
648
637
- if not os .path .isfile (self .eval_batches_file ):
649
+ if (
650
+ self .path_manager and not self .path_manager .isfile (self .eval_batches_file )
651
+ ) or (not self .path_manager and not os .path .isfile (self .eval_batches_file )):
638
652
# The batch indices file does not exist.
639
653
# Most probably the user has not specified the root folder.
640
654
raise ValueError (
641
655
f"Looking for dataset json file in { self .eval_batches_file } . "
642
656
+ "Please specify a correct dataset_root folder."
643
657
)
644
658
645
- with open (self .eval_batches_file , "r" ) as f :
659
+ eval_batches_file = self ._local_path (self .eval_batches_file )
660
+ with open (eval_batches_file , "r" ) as f :
646
661
eval_batches = json .load (f )
647
662
648
663
# limit the dataset to sequences to allow multiple evaluations in one file
@@ -758,11 +773,18 @@ def _get_temp_index_table_instance(self, table_name: str = "__index"):
758
773
prefixes = ["TEMP" ], # NOTE SQLite specific!
759
774
)
760
775
776
@classmethod
def pre_expand(cls) -> None:
    """Remove dataclass fields that are not meant to be init/config params.

    Fields declared with ``metadata={"omegaconf_ignore": True}`` (e.g.
    ``_index``, ``_sql_engine``, ``eval_batches``) are set manually in the
    constructor and cause trouble for OmegaConf schema expansion, so both
    the class attribute and its annotation are dropped here.
    """
    # list() copies the items because we mutate the class dict while iterating
    for attr, attr_value in list(cls.__dict__.items()):
        if isinstance(attr_value, Field) and attr_value.metadata.get(
            "omegaconf_ignore", False
        ):
            delattr(cls, attr)
            # pop() instead of del: the annotation may be absent (or live
            # on a base class), and a KeyError here would be spurious
            cls.__annotations__.pop(attr, None)
+
761
787
762
788
def _seq_name_to_seed(seq_name) -> int:
    """Deterministically map a sequence name to a seed in [0, 2 ** 28)."""
    digest = hashlib.sha1(seq_name.encode("utf-8")).hexdigest()
    # 7 hex digits -> at most 28 bits
    return int(digest[:7], 16)
765
-
766
-
767
- def _safe_as_tensor (data , dtype ):
768
- return torch .tensor (data , dtype = dtype ) if data is not None else None
0 commit comments