
Making listing lazy in DatasetQuery #976

Merged · 9 commits · Mar 26, 2025
src/datachain/catalog/catalog.py (1 addition, 1 deletion)
@@ -588,7 +588,7 @@ def enlist_source(

from_storage(
source, session=self.session, update=update, object_name=object_name
)
).exec()

list_ds_name, list_uri, list_path, _ = get_listing(
source, self.session, update=update
src/datachain/lib/dc/storage.py (20 additions, 17 deletions)
@@ -6,7 +6,6 @@
)

from datachain.lib.file import (
File,
FileType,
get_file_type,
)
@@ -95,24 +94,28 @@ def from_storage(
dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
return dc

dc = from_dataset(list_ds_name, session=session, settings=settings)
Contributor: Calling from_dataset when list_ds_exists is false also doesn't seem right.

Contributor (author): Lower-level code (DatasetQuery) is aware of the listing being lazy, so this is OK. We will start the chain with the listing dataset, and the fact that it doesn't exist yet is just the nature of its "laziness".

Contributor: I mean we could get a dataset-not-found error when list_ds_name doesn't exist.
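For readers following this thread, a rough usage sketch of the intended behaviour. The URI is purely illustrative, and dc refers to the same module the tests in this PR import; this is a sketch of the laziness being discussed, not authoritative API documentation.

    # Building the chain is lazy: from_dataset() is pointed at the listing
    # dataset name even though that dataset may not exist yet.
    chain = dc.from_storage("s3://example-bucket/images/")  # hypothetical URI

    # Only when steps are applied (e.g. exec/collect/save) does the registered
    # before-step run the listing, save it as the listing dataset, and let the
    # query resolve its starting step, so constructing the chain alone should
    # not look the dataset up.
    chain.exec()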

dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})

if update or not list_ds_exists:
# disable prefetch for listing, as it pre-downloads all files
(
from_records(
DataChain.DEFAULT_FILE_RECORD,
session=session,
settings=settings,
in_memory=in_memory,
)
.settings(prefetch=0)
.gen(
list_bucket(list_uri, cache, client_config=client_config),
output={f"{object_name}": File},

def lst_fn():
# disable prefetch for listing, as it pre-downloads all files
(
from_records(
DataChain.DEFAULT_FILE_RECORD,
session=session,
settings=settings,
in_memory=in_memory,
)
.settings(prefetch=0)
.gen(
list_bucket(list_uri, cache, client_config=client_config),
Contributor: This seems to be called every time I use the datachain to apply steps. Shouldn't this be applied only once?

ilongin (author), Mar 27, 2025: It should be called every time you apply steps. The whole idea is for the user to apply steps only once anyway, as it's a very expensive operation.
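To make that trade-off concrete, a hedged sketch (again with an illustrative URI and the dc import used by the tests below); it only illustrates the re-run behaviour discussed here, not the exact conditions under which the before-step is registered.

    cats = dc.from_storage("s3://example-bucket/cats/", update=True)  # hypothetical URI

    cats.count()    # applies steps, so the registered listing before-step runs
    cats.collect()  # applies steps again, so the listing runs again; the
                    # before-steps list is not cleared between executions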

output={f"{object_name}": file_type},
)
.save(list_ds_name, listing=True)
)
.save(list_ds_name, listing=True)
)

dc = from_dataset(list_ds_name, session=session, settings=settings)
dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
dc._query.add_before_steps(lst_fn)

return ls(dc, list_path, recursive=recursive, object_name=object_name)
src/datachain/query/dataset.py (39 additions, 16 deletions)
@@ -47,6 +47,7 @@
QueryScriptCancelError,
)
from datachain.func.base import Function
from datachain.lib.listing import is_listing_dataset
from datachain.lib.udf import UDFAdapter, _get_cache
from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
from datachain.query.schema import C, UDFParamSpec, normalize_param
@@ -151,13 +152,6 @@ def step_result(
)


class StartingStep(ABC):
"""An initial query processing step, referencing a data source."""

@abstractmethod
def apply(self) -> "StepResult": ...


@frozen
class Step(ABC):
"""A query processing step (filtering, mutation, etc.)"""
@@ -170,7 +164,7 @@ def apply(


@frozen
class QueryStep(StartingStep):
class QueryStep:
catalog: "Catalog"
dataset_name: str
dataset_version: int
@@ -1097,26 +1091,42 @@ def __init__(
self.temp_table_names: list[str] = []
self.dependencies: set[DatasetDependencyType] = set()
self.table = self.get_table()
self.starting_step: StartingStep
self.starting_step: Optional[QueryStep] = None
self.name: Optional[str] = None
self.version: Optional[int] = None
self.feature_schema: Optional[dict] = None
self.column_types: Optional[dict[str, Any]] = None
self.before_steps: list[Callable] = []

self.name = name
self.list_ds_name: Optional[str] = None

if fallback_to_studio and is_token_set():
ds = self.catalog.get_dataset_with_remote_fallback(name, version)
self.name = name
self.dialect = self.catalog.warehouse.db.dialect
if version:
self.version = version

if is_listing_dataset(name):
# not setting query step yet as listing dataset might not exist at
# this point
self.list_ds_name = name
elif fallback_to_studio and is_token_set():
Contributor: Not related to this PR, but is_token_set here looks odd and raises questions. We may want to import it as:

from datachain.remote.studio import is_token_set as is_studio_token_set

above, for example, just for better readability of the code here.

Member: Yes, agreed, cc @amritghimire ... it is still not a good idea to have Studio exposed this way. Ideally it should be just get_dataset, and it should decide on the fallback internally.

Member: Let's push really hard to keep Studio contained; it is important ... in the same way as, for example, using DC itself for the implementations (e.g. I wonder if from_storage can be done via map or gen and thus in a lazy way).

Contributor (author): Listing is already done with gen, but we cannot just append the rest of the chain to that part, because we want to cache the listing at some point, i.e. call save() on it, and if we call it in the middle of the chain it's not lazy any more. It needs to happen in the save() of the dataset when we apply the other steps.
So we could do

def from_storage():
    return (
        cls.from_records(DEFAULT_FILE_RECORD)
           .gen(list_bucket(...))
           .save(list_ds_name, listing=True)
    )

ds = DataChain.from_storage("s3://ldb-public").filter(...).map(...).save("my_dataset")

This is similar to how it was before this PR, but it's not lazy, and to make it lazy we need to add a step in DatasetQuery, since that's where we start to apply steps.
The ideal solution would be to move all those steps and the apply_steps function from DatasetQuery to DataChain, as there is no point for the main logic to be there IMO, and maybe even remove DatasetQuery altogether, but that's a whole other topic.

self._set_starting_step(
self.catalog.get_dataset_with_remote_fallback(name, version)
)
else:
ds = self.catalog.get_dataset(name)
self._set_starting_step(self.catalog.get_dataset(name))

def _set_starting_step(self, ds: "DatasetRecord") -> None:
if not self.version:
self.version = ds.latest_version

self.version = version or ds.latest_version
self.starting_step = QueryStep(self.catalog, ds.name, self.version)

# at this point we know our starting dataset so setting up schemas
self.feature_schema = ds.get_version(self.version).feature_schema
self.column_types = copy(ds.schema)
if "sys__id" in self.column_types:
self.column_types.pop("sys__id")
self.starting_step = QueryStep(self.catalog, name, self.version)
self.dialect = self.catalog.warehouse.db.dialect

def __iter__(self):
return iter(self.db_results())
@@ -1180,11 +1190,23 @@ def c(self, column: Union[C, str]) -> "ColumnClause[Any]":
col.table = self.table
return col

def add_before_steps(self, fn: Callable) -> None:
"""
Setting custom function to be run before applying steps
"""
self.before_steps.append(fn)

def apply_steps(self) -> QueryGenerator:
"""
Apply the steps in the query and return the resulting
sqlalchemy.SelectBase.
"""
for fn in self.before_steps:
fn()
Contributor: With this I saw a caveat: fn seems to be called every time a step is performed, since we don't clear the before steps at any time. So whenever I use collect or the chain, the query refetches the table instead.

ilongin (author), Mar 27, 2025: This is expected, and it's how it was before, when listing was lazy (before we refactored it using DataChain higher-level functions). Listing was always done when someone applied steps if the update flag was used.

Contributor: Yes, but it seems to run every time I run chain.collect() or chain.count(). As you can see in the test test_from_storage_multiple_uris_cache in #994, it is called every time for chains.

Contributor (author): chain.collect() applies steps every time it's called.

Contributor: Yes, my question was: should we rerun the listing every time collect is called when update is passed? Or should once suffice?


if self.list_ds_name:
# at this point we know what is our starting listing dataset name
self._set_starting_step(self.catalog.get_dataset(self.list_ds_name)) # type: ignore [arg-type]
query = self.clone()

index = os.getenv("DATACHAIN_QUERY_CHUNK_INDEX", self._chunk_index)
Expand All @@ -1203,6 +1225,7 @@ def apply_steps(self) -> QueryGenerator:
query = query.filter(C.sys__rand % total == index)
query.steps = query.steps[-1:] + query.steps[:-1]

assert query.starting_step
result = query.starting_step.apply()
self.dependencies.update(result.dependencies)

tests/func/test_datachain.py (5 additions, 5 deletions)
@@ -152,7 +152,7 @@ def _list_dataset_name(uri: str) -> str:
return name

dogs_uri = f"{src_uri}/dogs"
dc.from_storage(dogs_uri, session=session)
dc.from_storage(dogs_uri, session=session).exec()
assert _get_listing_datasets(session) == [
f"{_list_dataset_name(dogs_uri)}@v1",
]
@@ -162,15 +162,15 @@ def _list_dataset_name(uri: str) -> str:
f"{_list_dataset_name(dogs_uri)}@v1",
]

dc.from_storage(src_uri, session=session)
dc.from_storage(src_uri, session=session).exec()
assert _get_listing_datasets(session) == sorted(
[
f"{_list_dataset_name(dogs_uri)}@v1",
f"{_list_dataset_name(src_uri)}@v1",
]
)

dc.from_storage(f"{src_uri}/cats", session=session)
dc.from_storage(f"{src_uri}/cats", session=session).exec()
assert _get_listing_datasets(session) == sorted(
[
f"{_list_dataset_name(dogs_uri)}@v1",
@@ -196,14 +196,14 @@ def _list_dataset_name(uri: str) -> str:
return name

uri = f"{src_uri}/cats"
dc.from_storage(uri, session=session)
dc.from_storage(uri, session=session).exec()
assert _get_listing_datasets(session) == sorted(
[
f"{_list_dataset_name(uri)}@v1",
]
)

dc.from_storage(uri, session=session, update=True)
dc.from_storage(uri, session=session, update=True).exec()
assert _get_listing_datasets(session) == sorted(
[
f"{_list_dataset_name(uri)}@v1",
tests/func/test_ls.py (1 addition, 1 deletion)
@@ -32,7 +32,7 @@ def test_ls_no_args(cloud_test_catalog, cloud_type, capsys):
catalog = session.catalog
src = cloud_test_catalog.src_uri

dc.from_storage(src, session=session).collect()
dc.from_storage(src, session=session).exec()
ls([], catalog=catalog)
captured = capsys.readouterr()
assert captured.out == f"{src}/@v1\n"
tests/unit/lib/test_datachain.py (4 additions, 4 deletions)
@@ -339,7 +339,7 @@ def test_listings(test_session, tmp_dir):
df.to_parquet(tmp_dir / "df.parquet")

uri = tmp_dir.as_uri()
dc.from_storage(uri, session=test_session)
dc.from_storage(uri, session=test_session).exec()

# check that listing is not returned as normal dataset
assert not any(
@@ -370,13 +370,13 @@ def test_listings_reindex(test_session, tmp_dir):

uri = tmp_dir.as_uri()

dc.from_storage(uri, session=test_session)
dc.from_storage(uri, session=test_session).exec()
assert len(list(dc.listings(session=test_session).collect("listing"))) == 1

dc.from_storage(uri, session=test_session)
dc.from_storage(uri, session=test_session).exec()
assert len(list(dc.listings(session=test_session).collect("listing"))) == 1

dc.from_storage(uri, session=test_session, update=True)
dc.from_storage(uri, session=test_session, update=True).exec()
listings = list(dc.listings(session=test_session).collect("listing"))
assert len(listings) == 2
listings.sort(key=lambda lst: lst.version)