10
10
import json
11
11
import logging
12
12
import os
13
+
13
14
import urllib
14
15
from dataclasses import dataclass , Field , field
15
16
from typing import (
37
38
FrameDataBuilder , # noqa
38
39
FrameDataBuilderBase ,
39
40
)
41
+
40
42
from pytorch3d .implicitron .tools .config import (
41
43
registry ,
42
44
ReplaceableBase ,
43
45
run_auto_creation ,
44
46
)
45
- from sqlalchemy .orm import Session
47
+ from sqlalchemy .orm import scoped_session , Session , sessionmaker
46
48
47
49
from .orm_types import SqlFrameAnnotation , SqlSequenceAnnotation
48
50
@@ -91,6 +93,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
91
93
engine verbatim. Don’t expose it to end users of your application!
92
94
pick_categories: Restrict the dataset to the given list of categories.
93
95
pick_sequences: A Sequence of sequence names to restrict the dataset to.
96
+ pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
94
97
exclude_sequences: A Sequence of the names of the sequences to exclude.
95
98
limit_sequences_per_category_to: Limit the dataset to the first up to N
96
99
sequences within each category (applies after all other sequence filters
@@ -105,6 +108,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
105
108
more frames than that; applied after other frame-level filters.
106
109
seed: The seed of the random generator sampling `n_frames_per_sequence`
107
110
random frames per sequence.
111
+ preload_metadata: If True, the metadata is preloaded into memory.
112
+ precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
113
+ scoped_session: If True, allows different parts of the code to share
114
+ a global session to access the database.
108
115
"""
109
116
110
117
frame_annotations_type : ClassVar [Type [SqlFrameAnnotation ]] = SqlFrameAnnotation
@@ -123,13 +130,16 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
123
130
pick_categories : Tuple [str , ...] = ()
124
131
125
132
pick_sequences : Tuple [str , ...] = ()
133
+ pick_sequences_sql_clause : Optional [str ] = None
126
134
exclude_sequences : Tuple [str , ...] = ()
127
135
limit_sequences_per_category_to : int = 0
128
136
limit_sequences_to : int = 0
129
137
limit_to : int = 0
130
138
n_frames_per_sequence : int = - 1
131
139
seed : int = 0
132
140
remove_empty_masks_poll_whole_table_threshold : int = 300_000
141
+ preload_metadata : bool = False
142
+ precompute_seq_to_idx : bool = False
133
143
# we set it manually in the constructor
134
144
_index : pd .DataFrame = field (init = False , metadata = {"omegaconf_ignore" : True })
135
145
_sql_engine : sa .engine .Engine = field (
@@ -142,6 +152,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
142
152
frame_data_builder : FrameDataBuilderBase # pyre-ignore[13]
143
153
frame_data_builder_class_type : str = "FrameDataBuilder"
144
154
155
+ scoped_session : bool = False
156
+
145
157
def __post_init__ (self ) -> None :
146
158
if sa .__version__ < "2.0" :
147
159
raise ImportError ("This class requires SQL Alchemy 2.0 or later" )
@@ -169,6 +181,9 @@ def __post_init__(self) -> None:
169
181
f"sqlite:///file:{ urllib .parse .quote (self .sqlite_metadata_file )} ?mode=ro&uri=true"
170
182
)
171
183
184
+ if self .preload_metadata :
185
+ self ._sql_engine = self ._preload_database (self ._sql_engine )
186
+
172
187
sequences = self ._get_filtered_sequences_if_any ()
173
188
174
189
if self .subsets :
@@ -192,6 +207,20 @@ def __post_init__(self) -> None:
192
207
193
208
logger .info (str (self ))
194
209
210
+ if self .scoped_session :
211
+ self ._session_factory = sessionmaker (bind = self ._sql_engine ) # pyre-ignore
212
+
213
+ if self .precompute_seq_to_idx :
214
+ # This is deprecated and will be removed in the future.
215
+ # After we backport https://github.com/facebookresearch/uco3d/pull/3
216
+ logger .warning (
217
+ "Using precompute_seq_to_idx is deprecated and will be removed in the future."
218
+ )
219
+ self ._index ["rowid" ] = np .arange (len (self ._index ))
220
+ groupby = self ._index .groupby ("sequence_name" , sort = False )["rowid" ]
221
+ self ._seq_to_indices = dict (groupby .apply (list )) # pyre-ignore
222
+ del self ._index ["rowid" ]
223
+
195
224
def __len__(self) -> int:
    """Return the number of entries held in the frame index `self._index`."""
    index_size = len(self._index)
    return index_size
197
226
@@ -252,9 +281,15 @@ def _get_item(
252
281
seq_stmt = sa .select (self .sequence_annotations_type ).where (
253
282
self .sequence_annotations_type .sequence_name == seq
254
283
)
255
- with Session (self ._sql_engine ) as session :
256
- entry = session .scalars (stmt ).one ()
257
- seq_metadata = session .scalars (seq_stmt ).one ()
284
+ if self .scoped_session :
285
+ # pyre-ignore
286
+ with scoped_session (self ._session_factory )() as session :
287
+ entry = session .scalars (stmt ).one ()
288
+ seq_metadata = session .scalars (seq_stmt ).one ()
289
+ else :
290
+ with Session (self ._sql_engine ) as session :
291
+ entry = session .scalars (stmt ).one ()
292
+ seq_metadata = session .scalars (seq_stmt ).one ()
258
293
259
294
assert entry .image .path == self ._index .loc [(seq , frame ), "_image_path" ]
260
295
@@ -363,6 +398,20 @@ def sequence_frames_in_order(
363
398
364
399
yield from index_slice .itertuples (index = False )
365
400
401
# override
def sequence_indices_in_order(
    self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
) -> Iterator[int]:
    """Iterate over the dataset indices of the frames of one sequence.

    Counterpart of `sequence_frames_in_order` that yields only the dataset
    index of each frame instead of the full per-frame tuple.

    Args:
        seq_name: Name of the sequence whose frame indices are requested.
        subset_filter: Optional subset names; passed through unchanged to
            `sequence_frames_in_order`.

    Yields:
        Dataset indices, either from the `self._seq_to_indices` lookup or
        taken from the third element of each tuple produced by
        `sequence_frames_in_order`.
    """
    # The lookup table can only be used when no subset filter is requested.
    use_lookup = self.precompute_seq_to_idx and subset_filter is None
    if not use_lookup:
        frames = self.sequence_frames_in_order(seq_name, subset_filter)
        yield from (frame_idx for _, _, frame_idx in frames)
        return

    # pyre-ignore
    yield from self._seq_to_indices[seq_name]
414
+
366
415
# override
367
416
def get_eval_batches (self ) -> Optional [List [Any ]]:
368
417
"""
@@ -396,11 +445,35 @@ def is_filtered(self) -> bool:
396
445
or self .limit_sequences_to > 0
397
446
or self .limit_sequences_per_category_to > 0
398
447
or len (self .pick_sequences ) > 0
448
+ or self .pick_sequences_sql_clause is not None
399
449
or len (self .exclude_sequences ) > 0
400
450
or len (self .pick_categories ) > 0
401
451
or self .n_frames_per_sequence > 0
402
452
)
403
453
454
def _preload_database(
    self, source_engine: sa.engine.base.Engine
) -> sa.engine.base.Engine:
    """Copy every table of `source_engine` into a new in-memory SQLite engine.

    The schema is reflected from the source database, re-created in a
    `sqlite:///:memory:` database, and the rows of each table are copied over.

    Args:
        source_engine: Engine connected to the database to be copied.

    Returns:
        A new engine bound to the in-memory copy.
    """
    # NOTE(review): an in-memory SQLite database is private to the connection
    # that created it. If this engine is ever used from several threads,
    # confirm that the pool configuration hands out the same connection
    # (see the SQLAlchemy SQLite dialect notes on memory databases).
    destination_engine = sa.create_engine("sqlite:///:memory:")
    metadata = sa.MetaData()
    metadata.reflect(bind=source_engine)
    metadata.create_all(bind=destination_engine)

    # Rows are copied in batches so that each batch is sent as a single
    # executemany call rather than one statement execution per row.
    rows_per_batch = 10_000

    with source_engine.connect() as source_conn:
        with destination_engine.connect() as destination_conn:
            for table_obj in metadata.tables.values():
                insert_stmt = table_obj.insert()
                source_rows = source_conn.execute(table_obj.select()).mappings()

                # `partitions` yields nothing for an empty table, so no
                # INSERT is issued in that case.
                for batch in source_rows.partitions(rows_per_batch):
                    destination_conn.execute(
                        insert_stmt, [dict(row) for row in batch]
                    )

                # Commit the changes for each table
                destination_conn.commit()

    return destination_engine
476
+
404
477
def _get_filtered_sequences_if_any (self ) -> Optional [pd .Series ]:
405
478
# maximum possible filter (if limit_sequences_per_category_to == 0):
406
479
# WHERE category IN 'self.pick_categories'
@@ -413,6 +486,9 @@ def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
413
486
* self ._get_pick_filters (),
414
487
* self ._get_exclude_filters (),
415
488
]
489
+ if self .pick_sequences_sql_clause :
490
+ print ("Applying the custom SQL clause." )
491
+ where_conditions .append (sa .text (self .pick_sequences_sql_clause ))
416
492
417
493
def add_where (stmt ):
418
494
return stmt .where (* where_conditions ) if where_conditions else stmt
@@ -749,9 +825,15 @@ def _get_frame_no_coalesced_ts_by_row_indices(
749
825
self .frame_annotations_type .sequence_name == seq_name ,
750
826
self .frame_annotations_type .frame_number .in_ (frames ),
751
827
)
828
+ frame_no_ts = None
752
829
753
- with self ._sql_engine .connect () as connection :
754
- frame_no_ts = pd .read_sql_query (stmt , connection )
830
+ if self .scoped_session :
831
+ stmt_text = str (stmt .compile (compile_kwargs = {"literal_binds" : True }))
832
+ with scoped_session (self ._session_factory )() as session : # pyre-ignore
833
+ frame_no_ts = pd .read_sql_query (stmt_text , session .connection ())
834
+ else :
835
+ with self ._sql_engine .connect () as connection :
836
+ frame_no_ts = pd .read_sql_query (stmt , connection )
755
837
756
838
if len (frame_no_ts ) != len (index_slice ):
757
839
raise ValueError (
0 commit comments