Skip to content

Commit eed7148

Browse files
ilongin and dreadatour authored
Making listing lazy in DatasetQuery (#976)
* adding listing as pre-step * Update src/datachain/lib/dc.py Co-authored-by: Vladimir Rudnykh <[email protected]> * Update src/datachain/query/dataset.py Co-authored-by: Vladimir Rudnykh <[email protected]> * returned to starting step --------- Co-authored-by: Vladimir Rudnykh <[email protected]>
1 parent 71d87f2 commit eed7148

File tree

6 files changed

+70
-44
lines changed

6 files changed

+70
-44
lines changed

src/datachain/catalog/catalog.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ def enlist_source(
588588

589589
from_storage(
590590
source, session=self.session, update=update, object_name=object_name
591-
)
591+
).exec()
592592

593593
list_ds_name, list_uri, list_path, _ = get_listing(
594594
source, self.session, update=update

src/datachain/lib/dc/storage.py

+20-17
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
)
77

88
from datachain.lib.file import (
9-
File,
109
FileType,
1110
get_file_type,
1211
)
@@ -95,24 +94,28 @@ def from_storage(
9594
dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
9695
return dc
9796

97+
dc = from_dataset(list_ds_name, session=session, settings=settings)
98+
dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
99+
98100
if update or not list_ds_exists:
99-
# disable prefetch for listing, as it pre-downloads all files
100-
(
101-
from_records(
102-
DataChain.DEFAULT_FILE_RECORD,
103-
session=session,
104-
settings=settings,
105-
in_memory=in_memory,
106-
)
107-
.settings(prefetch=0)
108-
.gen(
109-
list_bucket(list_uri, cache, client_config=client_config),
110-
output={f"{object_name}": File},
101+
102+
def lst_fn():
103+
# disable prefetch for listing, as it pre-downloads all files
104+
(
105+
from_records(
106+
DataChain.DEFAULT_FILE_RECORD,
107+
session=session,
108+
settings=settings,
109+
in_memory=in_memory,
110+
)
111+
.settings(prefetch=0)
112+
.gen(
113+
list_bucket(list_uri, cache, client_config=client_config),
114+
output={f"{object_name}": file_type},
115+
)
116+
.save(list_ds_name, listing=True)
111117
)
112-
.save(list_ds_name, listing=True)
113-
)
114118

115-
dc = from_dataset(list_ds_name, session=session, settings=settings)
116-
dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
119+
dc._query.add_before_steps(lst_fn)
117120

118121
return ls(dc, list_path, recursive=recursive, object_name=object_name)

src/datachain/query/dataset.py

+39-16
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
QueryScriptCancelError,
4848
)
4949
from datachain.func.base import Function
50+
from datachain.lib.listing import is_listing_dataset
5051
from datachain.lib.udf import UDFAdapter, _get_cache
5152
from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
5253
from datachain.query.schema import C, UDFParamSpec, normalize_param
@@ -151,13 +152,6 @@ def step_result(
151152
)
152153

153154

154-
class StartingStep(ABC):
155-
"""An initial query processing step, referencing a data source."""
156-
157-
@abstractmethod
158-
def apply(self) -> "StepResult": ...
159-
160-
161155
@frozen
162156
class Step(ABC):
163157
"""A query processing step (filtering, mutation, etc.)"""
@@ -170,7 +164,7 @@ def apply(
170164

171165

172166
@frozen
173-
class QueryStep(StartingStep):
167+
class QueryStep:
174168
catalog: "Catalog"
175169
dataset_name: str
176170
dataset_version: int
@@ -1097,26 +1091,42 @@ def __init__(
10971091
self.temp_table_names: list[str] = []
10981092
self.dependencies: set[DatasetDependencyType] = set()
10991093
self.table = self.get_table()
1100-
self.starting_step: StartingStep
1094+
self.starting_step: Optional[QueryStep] = None
11011095
self.name: Optional[str] = None
11021096
self.version: Optional[int] = None
11031097
self.feature_schema: Optional[dict] = None
11041098
self.column_types: Optional[dict[str, Any]] = None
1099+
self.before_steps: list[Callable] = []
11051100

1106-
self.name = name
1101+
self.list_ds_name: Optional[str] = None
11071102

1108-
if fallback_to_studio and is_token_set():
1109-
ds = self.catalog.get_dataset_with_remote_fallback(name, version)
1103+
self.name = name
1104+
self.dialect = self.catalog.warehouse.db.dialect
1105+
if version:
1106+
self.version = version
1107+
1108+
if is_listing_dataset(name):
1109+
# not setting query step yet as listing dataset might not exist at
1110+
# this point
1111+
self.list_ds_name = name
1112+
elif fallback_to_studio and is_token_set():
1113+
self._set_starting_step(
1114+
self.catalog.get_dataset_with_remote_fallback(name, version)
1115+
)
11101116
else:
1111-
ds = self.catalog.get_dataset(name)
1117+
self._set_starting_step(self.catalog.get_dataset(name))
1118+
1119+
def _set_starting_step(self, ds: "DatasetRecord") -> None:
1120+
if not self.version:
1121+
self.version = ds.latest_version
11121122

1113-
self.version = version or ds.latest_version
1123+
self.starting_step = QueryStep(self.catalog, ds.name, self.version)
1124+
1125+
# at this point we know our starting dataset so setting up schemas
11141126
self.feature_schema = ds.get_version(self.version).feature_schema
11151127
self.column_types = copy(ds.schema)
11161128
if "sys__id" in self.column_types:
11171129
self.column_types.pop("sys__id")
1118-
self.starting_step = QueryStep(self.catalog, name, self.version)
1119-
self.dialect = self.catalog.warehouse.db.dialect
11201130

11211131
def __iter__(self):
11221132
return iter(self.db_results())
@@ -1180,11 +1190,23 @@ def c(self, column: Union[C, str]) -> "ColumnClause[Any]":
11801190
col.table = self.table
11811191
return col
11821192

1193+
def add_before_steps(self, fn: Callable) -> None:
1194+
"""
1195+
Setting custom function to be run before applying steps
1196+
"""
1197+
self.before_steps.append(fn)
1198+
11831199
def apply_steps(self) -> QueryGenerator:
11841200
"""
11851201
Apply the steps in the query and return the resulting
11861202
sqlalchemy.SelectBase.
11871203
"""
1204+
for fn in self.before_steps:
1205+
fn()
1206+
1207+
if self.list_ds_name:
1208+
# at this point we know what is our starting listing dataset name
1209+
self._set_starting_step(self.catalog.get_dataset(self.list_ds_name)) # type: ignore [arg-type]
11881210
query = self.clone()
11891211

11901212
index = os.getenv("DATACHAIN_QUERY_CHUNK_INDEX", self._chunk_index)
@@ -1203,6 +1225,7 @@ def apply_steps(self) -> QueryGenerator:
12031225
query = query.filter(C.sys__rand % total == index)
12041226
query.steps = query.steps[-1:] + query.steps[:-1]
12051227

1228+
assert query.starting_step
12061229
result = query.starting_step.apply()
12071230
self.dependencies.update(result.dependencies)
12081231

tests/func/test_datachain.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def _list_dataset_name(uri: str) -> str:
152152
return name
153153

154154
dogs_uri = f"{src_uri}/dogs"
155-
dc.from_storage(dogs_uri, session=session)
155+
dc.from_storage(dogs_uri, session=session).exec()
156156
assert _get_listing_datasets(session) == [
157157
f"{_list_dataset_name(dogs_uri)}@v1",
158158
]
@@ -162,15 +162,15 @@ def _list_dataset_name(uri: str) -> str:
162162
f"{_list_dataset_name(dogs_uri)}@v1",
163163
]
164164

165-
dc.from_storage(src_uri, session=session)
165+
dc.from_storage(src_uri, session=session).exec()
166166
assert _get_listing_datasets(session) == sorted(
167167
[
168168
f"{_list_dataset_name(dogs_uri)}@v1",
169169
f"{_list_dataset_name(src_uri)}@v1",
170170
]
171171
)
172172

173-
dc.from_storage(f"{src_uri}/cats", session=session)
173+
dc.from_storage(f"{src_uri}/cats", session=session).exec()
174174
assert _get_listing_datasets(session) == sorted(
175175
[
176176
f"{_list_dataset_name(dogs_uri)}@v1",
@@ -196,14 +196,14 @@ def _list_dataset_name(uri: str) -> str:
196196
return name
197197

198198
uri = f"{src_uri}/cats"
199-
dc.from_storage(uri, session=session)
199+
dc.from_storage(uri, session=session).exec()
200200
assert _get_listing_datasets(session) == sorted(
201201
[
202202
f"{_list_dataset_name(uri)}@v1",
203203
]
204204
)
205205

206-
dc.from_storage(uri, session=session, update=True)
206+
dc.from_storage(uri, session=session, update=True).exec()
207207
assert _get_listing_datasets(session) == sorted(
208208
[
209209
f"{_list_dataset_name(uri)}@v1",

tests/func/test_ls.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_ls_no_args(cloud_test_catalog, cloud_type, capsys):
3232
catalog = session.catalog
3333
src = cloud_test_catalog.src_uri
3434

35-
dc.from_storage(src, session=session).collect()
35+
dc.from_storage(src, session=session).exec()
3636
ls([], catalog=catalog)
3737
captured = capsys.readouterr()
3838
assert captured.out == f"{src}/@v1\n"

tests/unit/lib/test_datachain.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def test_listings(test_session, tmp_dir):
339339
df.to_parquet(tmp_dir / "df.parquet")
340340

341341
uri = tmp_dir.as_uri()
342-
dc.from_storage(uri, session=test_session)
342+
dc.from_storage(uri, session=test_session).exec()
343343

344344
# check that listing is not returned as normal dataset
345345
assert not any(
@@ -370,13 +370,13 @@ def test_listings_reindex(test_session, tmp_dir):
370370

371371
uri = tmp_dir.as_uri()
372372

373-
dc.from_storage(uri, session=test_session)
373+
dc.from_storage(uri, session=test_session).exec()
374374
assert len(list(dc.listings(session=test_session).collect("listing"))) == 1
375375

376-
dc.from_storage(uri, session=test_session)
376+
dc.from_storage(uri, session=test_session).exec()
377377
assert len(list(dc.listings(session=test_session).collect("listing"))) == 1
378378

379-
dc.from_storage(uri, session=test_session, update=True)
379+
dc.from_storage(uri, session=test_session, update=True).exec()
380380
listings = list(dc.listings(session=test_session).collect("listing"))
381381
assert len(listings) == 2
382382
listings.sort(key=lambda lst: lst.version)

0 commit comments

Comments (0)