Skip to content

Commit 6bad3af

Browse files
authoredMar 28, 2023
Merge pull request #35 from rkansal47/main
Parallelize KPD
2 parents c7e249b + 447eb4b commit 6bad3af

File tree

3 files changed

+50
-18
lines changed

3 files changed

+50
-18
lines changed
 

‎jetnet/__init__.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
# __init__
22

3-
# IMPORTANT: evaluation has to be imported first since energyflow must be imported before torch
4-
# See https://github.com/pkomiske/EnergyFlow/issues/24
53
import jetnet.datasets # noqa: F401
64
import jetnet.datasets.normalisations
75
import jetnet.datasets.utils
86
import jetnet.evaluation
97
import jetnet.losses
108
import jetnet.utils # noqa: F401
119

12-
__version__ = "0.2.3.post1"
10+
__version__ = "0.2.3.post2"

‎jetnet/evaluation/gen_metrics.py

+45-12
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
import torch
99
from energyflow.emd import emds
10-
from numba import njit
10+
from numba import njit, prange, set_num_threads
1111
from numpy.typing import ArrayLike
1212
from scipy import linalg
1313
from scipy.optimize import curve_fit
@@ -747,13 +747,47 @@ def _mmd_poly_quadratic_unbiased(X: ArrayLike, Y: ArrayLike, degree: int = 4) ->
747747
return _mmd_quadratic_unbiased(XX, YY, XY)
748748

749749

750+
@njit(parallel=True)
751+
def _kpd_batches_parallel(X, Y, num_batches, batch_size, seed):
752+
vals_point = np.zeros(num_batches, dtype=np.float64)
753+
for i in prange(num_batches):
754+
np.random.seed(seed + i * 1000) # in case of multi-threading
755+
rand1 = np.random.choice(len(X), size=batch_size)
756+
rand2 = np.random.choice(len(Y), size=batch_size)
757+
758+
rand_sample1 = X[rand1]
759+
rand_sample2 = Y[rand2]
760+
761+
val = _mmd_poly_quadratic_unbiased(rand_sample1, rand_sample2, degree=4)
762+
vals_point[i] = val
763+
764+
return vals_point
765+
766+
767+
def _kpd_batches(X, Y, num_batches, batch_size, seed):
768+
vals_point = []
769+
for i in range(num_batches):
770+
np.random.seed(seed + i * 1_000)
771+
rand1 = np.random.choice(len(X), size=batch_size)
772+
rand2 = np.random.choice(len(Y), size=batch_size)
773+
774+
rand_sample1 = X[rand1]
775+
rand_sample2 = Y[rand2]
776+
777+
val = _mmd_poly_quadratic_unbiased(rand_sample1, rand_sample2)
778+
vals_point.append(val)
779+
780+
return vals_point
781+
782+
750783
def kpd(
751784
real_features: Union[Tensor, np.ndarray],
752785
gen_features: Union[Tensor, np.ndarray],
753786
num_batches: int = 10,
754787
batch_size: int = 5_000,
755788
normalise: bool = True,
756789
seed: int = 42,
790+
num_threads: int = None,
757791
) -> Tuple[float, float]:
758792
"""Calculates the median and error of the kernel physics distance (KPD) between a set of real
759793
and generated features, as defined in https://arxiv.org/abs/2211.10295.
@@ -768,12 +802,15 @@ def kpd(
768802
real_features (Union[Tensor, np.ndarray]): set of real features of shape
769803
``[num_samples, num_features]``.
770804
gen_features (Union[Tensor, np.ndarray]): set of generated features of shape
771-
``[num_samples, num_features]``.
805+
``[num_samples, num_features]``.
772806
num_batches (int, optional): number of batches to average over. Defaults to 10.
773807
batch_size (int, optional): size of each batch for which MMD is measured. Defaults to 5,000.
774808
normalise (bool, optional): normalise the individual features over the full sample to have
775809
the same scaling. Defaults to True.
776810
seed (int, optional): random seed. Defaults to 42.
811+
num_threads (int, optional): parallelize KPD through numba using this many threads. 0 means
812+
numba's default number of threads, based on # of cores available. Defaults to None, i.e.
813+
no parallelization.
777814
778815
Returns:
779816
Tuple[float, float]: median and error of KPD.
@@ -783,17 +820,13 @@ def kpd(
783820
if normalise:
784821
X, Y = _normalise_features(real_features, gen_features)
785822

786-
vals_point = []
787-
for i in range(num_batches):
788-
np.random.seed(seed + i * 1_000)
789-
rand1 = np.random.choice(len(X), size=batch_size)
790-
rand2 = np.random.choice(len(Y), size=batch_size)
791-
792-
rand_sample1 = X[rand1]
793-
rand_sample2 = Y[rand2]
823+
if num_threads is None:
824+
vals_point = _kpd_batches(X, Y, num_batches, batch_size, seed)
825+
else:
826+
if num_threads > 0:
827+
set_num_threads(num_threads)
794828

795-
val = _mmd_poly_quadratic_unbiased(rand_sample1, rand_sample2)
796-
vals_point.append(val)
829+
vals_point = _kpd_batches_parallel(X, Y, num_batches, batch_size, seed)
797830

798831
# median, error = half of 16 - 84 IQR
799832
return (np.median(vals_point), iqr(vals_point, rng=(16.275, 83.725)) / 2)

‎tests/evaluation/test_gen_metrics.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def test_fpd():
1818
assert err < 1e-3
1919

2020

21-
def test_kpd():
22-
assert evaluation.kpd(test_zeros, test_zeros) == approx([0, 0])
23-
assert evaluation.kpd(test_zeros, test_ones) == approx([15, 0])
21+
@pytest.mark.parametrize("num_threads", [None, 0, 2]) # test numba parallelization
22+
def test_kpd(num_threads):
23+
assert evaluation.kpd(test_zeros, test_zeros, num_threads=num_threads) == approx([0, 0])
24+
assert evaluation.kpd(test_zeros, test_ones, num_threads=num_threads) == approx([15, 0])

0 commit comments

Comments
 (0)
Please sign in to comment.