7
7
import numpy as np
8
8
import torch
9
9
from energyflow .emd import emds
10
- from numba import njit
10
+ from numba import njit , prange , set_num_threads
11
11
from numpy .typing import ArrayLike
12
12
from scipy import linalg
13
13
from scipy .optimize import curve_fit
@@ -747,13 +747,47 @@ def _mmd_poly_quadratic_unbiased(X: ArrayLike, Y: ArrayLike, degree: int = 4) ->
747
747
return _mmd_quadratic_unbiased (XX , YY , XY )
748
748
749
749
750
@njit(parallel=True)
def _kpd_batches_parallel(X, Y, num_batches, batch_size, seed):
    """Compute the degree-4 polynomial-kernel MMD on ``num_batches`` random batches using numba.

    The batch loop is a numba ``prange`` loop. Every iteration reseeds the random generator with
    ``seed + i * 1000`` before drawing, so the indices of batch ``i`` depend only on ``seed`` and
    ``i``.

    Args:
        X: first sample set; indexed along its first axis.
        Y: second sample set; indexed along its first axis.
        num_batches (int): number of batches to evaluate.
        batch_size (int): number of row indices drawn from each of ``X`` and ``Y`` per batch.
        seed (int): base random seed.

    Returns:
        np.ndarray: float64 array of length ``num_batches`` holding one MMD value per batch.
    """
    num_first = len(X)
    num_second = len(Y)
    batch_mmds = np.zeros(num_batches, dtype=np.float64)

    for batch in prange(num_batches):
        # reseed per iteration so each batch is reproducible regardless of thread scheduling
        np.random.seed(seed + batch * 1000)
        idx_first = np.random.choice(num_first, size=batch_size)
        idx_second = np.random.choice(num_second, size=batch_size)

        batch_mmds[batch] = _mmd_poly_quadratic_unbiased(X[idx_first], Y[idx_second], degree=4)

    return batch_mmds
765
+
766
+
767
def _kpd_batches(X, Y, num_batches, batch_size, seed):
    """Compute the polynomial-kernel MMD on ``num_batches`` random batches, serially.

    Each batch draws ``batch_size`` row indices (with replacement) from ``X`` and from ``Y`` and
    evaluates ``_mmd_poly_quadratic_unbiased`` on the selected rows.

    Batch ``i`` uses its own generator seeded with ``seed + i * 1_000``. A private
    ``np.random.RandomState`` is used rather than ``np.random.seed`` so that the caller's global
    NumPy random state is not modified; the legacy seeded stream, and therefore the drawn
    indices, is the same as with global seeding.

    Args:
        X: first sample set; indexed along its first axis.
        Y: second sample set; indexed along its first axis.
        num_batches (int): number of batches to evaluate.
        batch_size (int): number of row indices drawn from each of ``X`` and ``Y`` per batch.
        seed (int): base random seed.

    Returns:
        list: one MMD value per batch, in batch order.
    """
    num_first, num_second = len(X), len(Y)

    vals_point = []
    for i in range(num_batches):
        rng = np.random.RandomState(seed + i * 1_000)
        # Draw for X first, then Y: the order determines which part of the stream each one gets.
        rand1 = rng.choice(num_first, size=batch_size)
        rand2 = rng.choice(num_second, size=batch_size)

        vals_point.append(_mmd_poly_quadratic_unbiased(X[rand1], Y[rand2]))

    return vals_point
781
+
782
+
750
783
def kpd (
751
784
real_features : Union [Tensor , np .ndarray ],
752
785
gen_features : Union [Tensor , np .ndarray ],
753
786
num_batches : int = 10 ,
754
787
batch_size : int = 5_000 ,
755
788
normalise : bool = True ,
756
789
seed : int = 42 ,
790
+ num_threads : int = None ,
757
791
) -> Tuple [float , float ]:
758
792
"""Calculates the median and error of the kernel physics distance (KPD) between a set of real
759
793
and generated features, as defined in https://arxiv.org/abs/2211.10295.
@@ -768,12 +802,15 @@ def kpd(
768
802
real_features (Union[Tensor, np.ndarray]): set of real features of shape
769
803
``[num_samples, num_features]``.
770
804
gen_features (Union[Tensor, np.ndarray]): set of generated features of shape
771
- ``[num_samples, num_features]``.
805
+ ``[num_samples, num_features]``.
772
806
num_batches (int, optional): number of batches to average over. Defaults to 10.
773
807
batch_size (int, optional): size of each batch for which MMD is measured. Defaults to 5,000.
774
808
normalise (bool, optional): normalise the individual features over the full sample to have
775
809
the same scaling. Defaults to True.
776
810
seed (int, optional): random seed. Defaults to 42.
811
+ num_threads (int, optional): parallelize KPD through numba using this many threads. 0 means
812
+ numba's default number of threads, based on # of cores available. Defaults to None, i.e.
813
+ no parallelization.
777
814
778
815
Returns:
779
816
Tuple[float, float]: median and error of KPD.
@@ -783,17 +820,13 @@ def kpd(
783
820
if normalise :
784
821
X , Y = _normalise_features (real_features , gen_features )
785
822
786
- vals_point = []
787
- for i in range (num_batches ):
788
- np .random .seed (seed + i * 1_000 )
789
- rand1 = np .random .choice (len (X ), size = batch_size )
790
- rand2 = np .random .choice (len (Y ), size = batch_size )
791
-
792
- rand_sample1 = X [rand1 ]
793
- rand_sample2 = Y [rand2 ]
823
+ if num_threads is None :
824
+ vals_point = _kpd_batches (X , Y , num_batches , batch_size , seed )
825
+ else :
826
+ if num_threads > 0 :
827
+ set_num_threads (num_threads )
794
828
795
- val = _mmd_poly_quadratic_unbiased (rand_sample1 , rand_sample2 )
796
- vals_point .append (val )
829
+ vals_point = _kpd_batches_parallel (X , Y , num_batches , batch_size , seed )
797
830
798
831
# median, error = half of 16 - 84 IQR
799
832
return (np .median (vals_point ), iqr (vals_point , rng = (16.275 , 83.725 )) / 2 )
0 commit comments