47
47
QueryScriptCancelError ,
48
48
)
49
49
from datachain .func .base import Function
50
+ from datachain .lib .listing import is_listing_dataset
50
51
from datachain .lib .udf import UDFAdapter , _get_cache
51
52
from datachain .progress import CombinedDownloadCallback , TqdmCombinedDownloadCallback
52
53
from datachain .query .schema import C , UDFParamSpec , normalize_param
@@ -151,13 +152,6 @@ def step_result(
151
152
)
152
153
153
154
154
- class StartingStep (ABC ):
155
- """An initial query processing step, referencing a data source."""
156
-
157
- @abstractmethod
158
- def apply (self ) -> "StepResult" : ...
159
-
160
-
161
155
@frozen
162
156
class Step (ABC ):
163
157
"""A query processing step (filtering, mutation, etc.)"""
@@ -170,7 +164,7 @@ def apply(
170
164
171
165
172
166
@frozen
173
- class QueryStep ( StartingStep ) :
167
+ class QueryStep :
174
168
catalog : "Catalog"
175
169
dataset_name : str
176
170
dataset_version : int
@@ -1097,26 +1091,42 @@ def __init__(
1097
1091
self .temp_table_names : list [str ] = []
1098
1092
self .dependencies : set [DatasetDependencyType ] = set ()
1099
1093
self .table = self .get_table ()
1100
- self .starting_step : StartingStep
1094
+ self .starting_step : Optional [ QueryStep ] = None
1101
1095
self .name : Optional [str ] = None
1102
1096
self .version : Optional [int ] = None
1103
1097
self .feature_schema : Optional [dict ] = None
1104
1098
self .column_types : Optional [dict [str , Any ]] = None
1099
+ self .before_steps : list [Callable ] = []
1105
1100
1106
- self .name = name
1101
+ self .list_ds_name : Optional [ str ] = None
1107
1102
1108
- if fallback_to_studio and is_token_set ():
1109
- ds = self .catalog .get_dataset_with_remote_fallback (name , version )
1103
+ self .name = name
1104
+ self .dialect = self .catalog .warehouse .db .dialect
1105
+ if version :
1106
+ self .version = version
1107
+
1108
+ if is_listing_dataset (name ):
1109
+ # not setting query step yet as listing dataset might not exist at
1110
+ # this point
1111
+ self .list_ds_name = name
1112
+ elif fallback_to_studio and is_token_set ():
1113
+ self ._set_starting_step (
1114
+ self .catalog .get_dataset_with_remote_fallback (name , version )
1115
+ )
1110
1116
else :
1111
- ds = self .catalog .get_dataset (name )
1117
+ self ._set_starting_step (self .catalog .get_dataset (name ))
1118
+
1119
+ def _set_starting_step (self , ds : "DatasetRecord" ) -> None :
1120
+ if not self .version :
1121
+ self .version = ds .latest_version
1112
1122
1113
- self .version = version or ds .latest_version
1123
+ self .starting_step = QueryStep (self .catalog , ds .name , self .version )
1124
+
1125
+ # at this point we know our starting dataset so setting up schemas
1114
1126
self .feature_schema = ds .get_version (self .version ).feature_schema
1115
1127
self .column_types = copy (ds .schema )
1116
1128
if "sys__id" in self .column_types :
1117
1129
self .column_types .pop ("sys__id" )
1118
- self .starting_step = QueryStep (self .catalog , name , self .version )
1119
- self .dialect = self .catalog .warehouse .db .dialect
1120
1130
1121
1131
def __iter__ (self ):
1122
1132
return iter (self .db_results ())
@@ -1180,11 +1190,23 @@ def c(self, column: Union[C, str]) -> "ColumnClause[Any]":
1180
1190
col .table = self .table
1181
1191
return col
1182
1192
1193
+ def add_before_steps (self , fn : Callable ) -> None :
1194
+ """
1195
+ Setting custom function to be run before applying steps
1196
+ """
1197
+ self .before_steps .append (fn )
1198
+
1183
1199
def apply_steps (self ) -> QueryGenerator :
1184
1200
"""
1185
1201
Apply the steps in the query and return the resulting
1186
1202
sqlalchemy.SelectBase.
1187
1203
"""
1204
+ for fn in self .before_steps :
1205
+ fn ()
1206
+
1207
+ if self .list_ds_name :
1208
+ # at this point we know what is our starting listing dataset name
1209
+ self ._set_starting_step (self .catalog .get_dataset (self .list_ds_name )) # type: ignore [arg-type]
1188
1210
query = self .clone ()
1189
1211
1190
1212
index = os .getenv ("DATACHAIN_QUERY_CHUNK_INDEX" , self ._chunk_index )
@@ -1203,6 +1225,7 @@ def apply_steps(self) -> QueryGenerator:
1203
1225
query = query .filter (C .sys__rand % total == index )
1204
1226
query .steps = query .steps [- 1 :] + query .steps [:- 1 ]
1205
1227
1228
+ assert query .starting_step
1206
1229
result = query .starting_step .apply ()
1207
1230
self .dependencies .update (result .dependencies )
1208
1231
0 commit comments