Skip to content

Commit a921117

Browse files
committed
Last few mypy errors knocked out
1 parent 3dd7753 commit a921117

File tree

3 files changed

+258
-407
lines changed

3 files changed

+258
-407
lines changed

autosklearn/automl.py

+12-36
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,7 @@
9898
from autosklearn.smbo import AutoMLSMBO
9999
from autosklearn.util import RE_PATTERN, pipeline
100100
from autosklearn.util.dask import Dask, LocalDask, UserDask
101-
from autosklearn.util.data import (
102-
DatasetCompressionSpec,
103-
default_dataset_compression_arg,
104-
reduce_dataset_size_if_too_large,
105-
supported_precision_reductions,
106-
validate_dataset_compression_arg,
107-
)
101+
from autosklearn.util.data import DatasetCompression
108102
from autosklearn.util.logging_ import (
109103
PicklableClientLogger,
110104
get_named_client_logger,
@@ -252,15 +246,14 @@ def __init__(
252246
)
253247

254248
# Validate dataset_compression and set its values
255-
self._dataset_compression: DatasetCompressionSpec | None = None
256-
if isinstance(dataset_compression, bool):
257-
if dataset_compression is True:
258-
self._dataset_compression = default_dataset_compression_arg
259-
else:
260-
self._dataset_compression = validate_dataset_compression_arg(
261-
dataset_compression,
262-
memory_limit=memory_limit,
263-
)
249+
self._dataset_compression: DatasetCompression | None = None
250+
if dataset_compression is not False:
251+
252+
if memory_limit is None:
253+
raise ValueError("Must provide a `memory_limit` for data compression")
254+
255+
spec = {} if dataset_compression is True else dataset_compression
256+
self._dataset_compression = DatasetCompression(**spec, limit=memory_limit)
264257

265258
# If we got something callable for `get_trials_callback`, wrap it so SMAC
266259
# will accept it.
@@ -667,30 +660,13 @@ def fit(
667660
X_test, y_test = input_validator.transform(X_test, y_test)
668661

669662
# We don't support size reduction on pandas type object yet
670-
if (
671-
self._dataset_compression is not None
672-
and not isinstance(X, pd.DataFrame)
673-
and not (isinstance(y, pd.Series) or isinstance(y, pd.DataFrame))
674-
):
675-
methods = self._dataset_compression["methods"]
676-
memory_allocation = self._dataset_compression["memory_allocation"]
677-
678-
# Remove precision reduction if we can't perform it
679-
if (
680-
"precision" in methods
681-
and X.dtype not in supported_precision_reductions
682-
):
683-
methods = [method for method in methods if method != "precision"]
684-
663+
if self._dataset_compression and self._dataset_compression.supports(X, y):
685664
with warnings_to(self.logger):
686-
X, y = reduce_dataset_size_if_too_large(
665+
X, y = self._dataset_compression.compress(
687666
X=X,
688667
y=y,
689-
memory_limit=self._memory_limit,
690-
is_classification=self.is_classification,
668+
stratify=self.is_classification,
691669
random_state=self._seed,
692-
operations=methods,
693-
memory_allocation=memory_allocation,
694670
)
695671

696672
# Check the re-sampling strategy

0 commit comments

Comments (0)