|
98 | 98 | from autosklearn.smbo import AutoMLSMBO
|
99 | 99 | from autosklearn.util import RE_PATTERN, pipeline
|
100 | 100 | from autosklearn.util.dask import Dask, LocalDask, UserDask
|
101 |
| -from autosklearn.util.data import ( |
102 |
| - DatasetCompressionSpec, |
103 |
| - default_dataset_compression_arg, |
104 |
| - reduce_dataset_size_if_too_large, |
105 |
| - supported_precision_reductions, |
106 |
| - validate_dataset_compression_arg, |
107 |
| -) |
| 101 | +from autosklearn.util.data import DatasetCompression |
108 | 102 | from autosklearn.util.logging_ import (
|
109 | 103 | PicklableClientLogger,
|
110 | 104 | get_named_client_logger,
|
@@ -252,15 +246,14 @@ def __init__(
|
252 | 246 | )
|
253 | 247 |
|
254 | 248 | # Validate dataset_compression and set its values
|
255 |
| - self._dataset_compression: DatasetCompressionSpec | None = None |
256 |
| - if isinstance(dataset_compression, bool): |
257 |
| - if dataset_compression is True: |
258 |
| - self._dataset_compression = default_dataset_compression_arg |
259 |
| - else: |
260 |
| - self._dataset_compression = validate_dataset_compression_arg( |
261 |
| - dataset_compression, |
262 |
| - memory_limit=memory_limit, |
263 |
| - ) |
| 249 | + self._dataset_compression: DatasetCompression | None = None |
| 250 | + if dataset_compression is not False: |
| 251 | + |
| 252 | + if memory_limit is None: |
| 253 | + raise ValueError("Must provide a `memory_limit` for data compression") |
| 254 | + |
| 255 | + spec = {} if dataset_compression is True else dataset_compression |
| 256 | + self._dataset_compression = DatasetCompression(**spec, limit=memory_limit) |
264 | 257 |
|
265 | 258 | # If we got something callable for `get_trials_callback`, wrap it so SMAC
|
266 | 259 | # will accept it.
|
@@ -667,30 +660,13 @@ def fit(
|
667 | 660 | X_test, y_test = input_validator.transform(X_test, y_test)
|
668 | 661 |
|
669 | 662 | # We don't support size reduction on pandas type object yet
|
670 |
| - if ( |
671 |
| - self._dataset_compression is not None |
672 |
| - and not isinstance(X, pd.DataFrame) |
673 |
| - and not (isinstance(y, pd.Series) or isinstance(y, pd.DataFrame)) |
674 |
| - ): |
675 |
| - methods = self._dataset_compression["methods"] |
676 |
| - memory_allocation = self._dataset_compression["memory_allocation"] |
677 |
| - |
678 |
| - # Remove precision reduction if we can't perform it |
679 |
| - if ( |
680 |
| - "precision" in methods |
681 |
| - and X.dtype not in supported_precision_reductions |
682 |
| - ): |
683 |
| - methods = [method for method in methods if method != "precision"] |
684 |
| - |
| 663 | + if self._dataset_compression and self._dataset_compression.supports(X, y): |
685 | 664 | with warnings_to(self.logger):
|
686 |
| - X, y = reduce_dataset_size_if_too_large( |
| 665 | + X, y = self._dataset_compression.compress( |
687 | 666 | X=X,
|
688 | 667 | y=y,
|
689 |
| - memory_limit=self._memory_limit, |
690 |
| - is_classification=self.is_classification, |
| 668 | + stratify=self.is_classification, |
691 | 669 | random_state=self._seed,
|
692 |
| - operations=methods, |
693 |
| - memory_allocation=memory_allocation, |
694 | 670 | )
|
695 | 671 |
|
696 | 672 | # Check the re-sampling strategy
|
|
0 commit comments