@@ -868,6 +868,55 @@ def _evaluate_true(self, X: Tensor) -> Tensor:
868
868
)
869
869
870
870
871
+ class AckleyMixed (SyntheticTestFunction ):
872
+ r"""Mixed search space version of the Ackley problem.
873
+
874
+ This problem has dim-3 binary parameters and 3 continuous parameters in the
875
+ range [0, 1]. This means dim > 3 is required. To make the problem a bit more
876
+ interesting, the optimal value is not at the origin, but rather at `x_opt`
877
+ which is a randomly generated point.
878
+
879
+ The goal is to minimize f(x) = Ackley(x - x_opt).
880
+ """
881
+
882
+ _optimal_value = 0.0
883
+
884
+ def __init__ (
885
+ self ,
886
+ dim = 53 ,
887
+ noise_std : float | None = None ,
888
+ negate : bool = False ,
889
+ randomize_optimum : bool = False ,
890
+ dtype : torch .dtype = torch .double ,
891
+ ) -> None :
892
+ r"""
893
+ Args:
894
+ dim: The (input) dimension. Must be > 3.
895
+ noise_std: Standard deviation of the observation noise.
896
+ negate: If True, negate the function.
897
+ randomize_optimum: If True, the optimum is a random point in the domain.
898
+ dtype: The dtype that is used for the bounds of the function.
899
+ """
900
+ if dim <= 3 :
901
+ raise ValueError (f"Expected dim > 3. Got { dim = } ." )
902
+ if randomize_optimum :
903
+ x_opt = torch .rand (dim , dtype = dtype )
904
+ x_opt [: dim - 3 ] = x_opt [: dim - 3 ].round ()
905
+ else :
906
+ x_opt = torch .zeros (dim , dtype = dtype )
907
+ self ._optimizers = [tuple (x .item () for x in x_opt )]
908
+ self .dim = dim
909
+ self .discrete_inds = list (range (0 , dim - 3 ))
910
+ self .continuous_inds = list (range (dim - 3 , dim ))
911
+ bounds = [(0.0 , 1.0 ) for _ in range (self .dim )]
912
+ super ().__init__ (bounds = bounds , dtype = dtype , noise_std = noise_std , negate = negate )
913
+ self .register_buffer ("x_opt" , x_opt )
914
+ self ._ackley = Ackley (dim = dim , dtype = dtype )
915
+
916
+ def _evaluate_true (self , X : Tensor ) -> Tensor :
917
+ return self ._ackley .evaluate_true ((X - self .x_opt ).abs ())
918
+
919
+
871
920
# ------------ Constrained synthetic test functions ----------- #
872
921
873
922
0 commit comments