Skip to content

Commit

Permalink
Merge pull request #69 from jet-net/norm_bug_fix
Browse files Browse the repository at this point in the history
FPD and KPD normalization bug fix
  • Loading branch information
rkansal47 authored Feb 9, 2024
2 parents 77effb6 + 56ea542 commit 32d07d8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
8 changes: 4 additions & 4 deletions jetnet/evaluation/gen_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,10 +707,10 @@ def fpd(
stacklevel=2,
)

real_features, gen_features = _check_get_ndarray(real_features, gen_features)
X, Y = _check_get_ndarray(real_features, gen_features)

if normalise:
X, Y = _normalise_features(real_features, gen_features)
X, Y = _normalise_features(X, Y)

# regular intervals in 1/N
batches = (1 / np.linspace(1.0 / min_samples, 1.0 / max_samples, num_points)).astype("int32")
Expand Down Expand Up @@ -836,10 +836,10 @@ def kpd(
Returns:
Tuple[float, float]: median and error of KPD.
"""
real_features, gen_features = _check_get_ndarray(real_features, gen_features)
X, Y = _check_get_ndarray(real_features, gen_features)

if normalise:
X, Y = _normalise_features(real_features, gen_features)
X, Y = _normalise_features(X, Y)

if num_threads is None:
vals_point = _kpd_batches(X, Y, num_batches, batch_size, seed)
Expand Down
24 changes: 21 additions & 3 deletions tests/evaluation/test_gen_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,37 @@

test_zeros = np.zeros((50_000, 2))
test_ones = np.ones((50_000, 2))
test_twos = np.ones((50_000, 2)) * 2


def test_fpd():
val, err = evaluation.fpd(test_zeros, test_zeros)
assert val == approx(0, abs=0.01)
assert err < 1e-3

val, err = evaluation.fpd(test_zeros, test_ones)
assert val == approx(2, rel=0.01)
val, err = evaluation.fpd(test_twos, test_zeros)
assert val == approx(2, rel=0.01) # 1^2 + 1^2
assert err < 1e-3

# test normalization
val, err = evaluation.fpd(test_zeros, test_zeros, normalise=False) # should have no effect
assert val == approx(0, abs=0.01)
assert err < 1e-3

val, err = evaluation.fpd(test_twos, test_zeros, normalise=False)
assert val == approx(8, rel=0.01) # 2^2 + 2^2
assert err < 1e-3


@pytest.mark.parametrize("num_threads", [None, 2]) # test numba parallelization
def test_kpd(num_threads):
assert evaluation.kpd(test_zeros, test_zeros, num_threads=num_threads) == approx([0, 0])
assert evaluation.kpd(test_zeros, test_ones, num_threads=num_threads) == approx([15, 0])
assert evaluation.kpd(test_twos, test_zeros, num_threads=num_threads) == approx([15, 0])

# test normalization
assert evaluation.kpd(
test_zeros, test_zeros, normalise=False, num_threads=num_threads
) == approx([0, 0])
assert evaluation.kpd(
test_twos, test_zeros, normalise=False, num_threads=num_threads
) == approx([624, 0])

0 comments on commit 32d07d8

Please sign in to comment.