Skip to content

Commit

Permalink
Merge pull request #21 from rkansal47/main
Browse files Browse the repository at this point in the history
Dataset normalisation bug fix
  • Loading branch information
rkansal47 authored Sep 10, 2022
2 parents 1ba2ff8 + 8767301 commit 9953d1a
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
2 changes: 1 addition & 1 deletion jetnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
import jetnet.datasets.utils
import jetnet.datasets.normalisations

__version__ = "0.2.1.post1"
__version__ = "0.2.1.post2"
6 changes: 4 additions & 2 deletions jetnet/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,14 @@ def __init__(

if self.use_particle_features:
if self.particle_normalisation is not None:
self.particle_normalisation.derive_dataset_features(self.particle_data)
if self.particle_normalisation.features_need_deriving():
self.particle_normalisation.derive_dataset_features(self.particle_data)
self.particle_data = self.particle_normalisation(self.particle_data)

if self.use_jet_features:
if self.jet_normalisation is not None:
self.jet_normalisation.derive_dataset_features(self.jet_data)
if self.jet_normalisation.features_need_deriving():
self.jet_normalisation.derive_dataset_features(self.jet_data)
self.jet_data = self.jet_normalisation(self.jet_data)

self.particle_transform = particle_transform
Expand Down
21 changes: 17 additions & 4 deletions jetnet/datasets/normalisations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class NormaliseABC(ABC):
ABC for generalised normalisation class.
"""

def features_need_deriving(self) -> bool:
"""Checks if any dataset values or features need to be derived"""
return False

def derive_dataset_features(self, x: ArrayLike):
"""Derive features from dataset needed for normalisation if needed"""
pass
Expand Down Expand Up @@ -86,7 +90,15 @@ def derive_dataset_features(self, x: ArrayLike) -> Optional[Tuple[np.ndarray, np
self.feature_scales = 1.0 / np.std(x.reshape(-1, num_features), axis=0)
return self.feature_shifts, self.feature_scales

def features_need_deriving(self) -> bool:
"""Checks if any dataset values or features need to be derived"""
return (self.feature_shifts is None) or (self.feature_scales is None)

def __call__(self, x: ArrayLike, inverse: bool = False, inplace: bool = False) -> ArrayLike:
assert (
not self.features_need_deriving()
), "Feature means and stds have not been specified, you need to either set or derive them first"

num_features = x.shape[-1]

if isinstance(self.feature_shifts, float):
Expand Down Expand Up @@ -175,9 +187,6 @@ class FeaturewiseLinearBounded(NormaliseABC):
"""

# stores pre-scaled max absolute value of each feature for normalising and inverting
feature_maxes = None

def __init__(
self,
feature_norms: Union[float, List[float]] = 1.0,
Expand Down Expand Up @@ -207,9 +216,13 @@ def derive_dataset_features(self, x: ArrayLike) -> np.ndarray:
self.feature_maxes = np.max(np.abs(x.reshape(-1, num_features)), axis=0)
return self.feature_maxes

def features_need_deriving(self) -> bool:
"""Checks if any dataset values or features need to be derived"""
return self.feature_maxes is None

def __call__(self, x: ArrayLike, inverse: bool = False, inplace: bool = False) -> ArrayLike:
assert (
self.feature_maxes is not None
not self.features_need_deriving()
), "Feature maxes have not been specified, you need to either set or derive them first"

num_features = x.shape[-1]
Expand Down

0 comments on commit 9953d1a

Please sign in to comment.