Skip to content

Commit

Permalink
Fix dataset unit rescaling of per-species shifts (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
Linux-cpp-lisp authored Apr 30, 2024
1 parent c310ad6 commit 9b5b17c
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 69 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Most recent change on the bottom.
- [Breaking] `fixed_fields` machinery (`npz_fixed_field_keys` is still supported, but through a more straightforward implementation)
- Default run name/WandB project name of `NequIP`, they must now always be provided explicitly
- [Breaking] Removed `_params` as an allowable subconfiguration suffix (i.e. instead of `optimizer_params` now only `optimizer_kwargs` is valid, not both)
- [Breaking] Removed `per_species_rescale_arguments_in_dataset_units`

## [0.5.6] - 2022-12-19
### Added
Expand Down
13 changes: 4 additions & 9 deletions configs/example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ early_stopping_lower_bounds:
LR: 1.0e-5

early_stopping_upper_bounds: # stop early if the training appears to have exploded
validation_loss: 1.0e4
validation_loss: 1.0e+4

# loss function
loss_coeffs:
Expand Down Expand Up @@ -145,17 +145,12 @@ lr_scheduler_factor: 0.5
# the default is to scale the atomic energy and forces by scaling them by the force standard deviation and to shift the energy by the mean atomic energy
# in certain cases, it can be useful to have a trainable shift/scale and to also have species-dependent shifts/scales for each atom

# whether the shifts and scales are trainable. Defaults to False. Optional
per_species_rescale_shifts_trainable: false
per_species_rescale_scales_trainable: false

# initial atomic energy shift for each species. Defaults to the mean per-atom energy. Optional
# the value can be a constant float value, an array for each species, or a string that defines a statistic over the training dataset
# if numbers are explicitly provided, they must be in the same energy units as the training data
per_species_rescale_shifts: dataset_per_atom_total_energy_mean

# initial atomic energy scale for each species. Optional.
# the value can be a constant float value, an array for each species, or a string
per_species_rescale_scales: dataset_forces_rms

# if explicit numbers are given for the shifts/scales, this parameter must specify whether the given numbers are unitless shifts/scales or are in the units of the dataset. If ``True``, any global rescalings will correctly be applied to the per-species values.
# per_species_rescale_arguments_in_dataset_units: True
# if numbers are explicitly provided, they must be in the same energy units as the training data
per_species_rescale_scales: null
13 changes: 8 additions & 5 deletions configs/full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -307,29 +307,32 @@ per_species_rescale_scales_trainable: false
# whether the scales are trainable. Defaults to False. Optional
per_species_rescale_shifts_trainable: false
# whether the shifts are trainable. Defaults to False. Optional

per_species_rescale_shifts: dataset_per_atom_total_energy_mean
# initial atomic energy shift for each species. default to the mean of per atom energy. Optional
# the value can be a constant float value, an array for each species, or a string
# if numbers are explicitly provided, they must be in the same energy units as the training data
# string options include:
# * "dataset_per_atom_total_energy_mean", which computes the per atom average
# * "dataset_per_species_total_energy_mean", which automatically computes the per-atom energy mean using a GP model
per_species_rescale_scales: dataset_forces_rms

per_species_rescale_scales: null
# initial atomic energy scale for each species. Optional.
# the value can be a constant float value, an array for each species, or a string
# if numbers are explicitly provided, they must be in the same energy units as the training data
# string options include:
# * "dataset_forces_absmax", which computes the dataset maximum force component magnitude
# * "dataset_per_atom_total_energy_std", which computes the per atom energy std
# * "dataset_per_species_total_energy_std", which uses the GP model uncertainty
# * "dataset_per_species_forces_rms", which computes the force RMS for each species
# If not provided, defaults to dataset_per_species_force_rms or dataset_per_atom_total_energy_std, depending on whether forces are being trained.
# If not provided, defaults to null.

# per_species_rescale_kwargs:
# total_energy:
# alpha: 0.001
# max_iteration: 20
# stride: 100
# keywords for ridge regression decomposition of per species energy. Optional. Defaults to 0.001. The value should be in the range of 1e-3 to 1e-2
# per_species_rescale_arguments_in_dataset_units: True
# if explicit numbers are given for the shifts/scales, this parameter must specify whether the given numbers are unitless shifts/scales or are in the units of the dataset. If ``True``, any global rescalings will correctly be applied to the per-species values.
# keywords for ridge regression decomposition of per species energy. Optional. Defaults to 0.001. The value should be in the range of 1e-3 to 1e-2

# global energy shift and scale
# When "dataset_total_energy_mean", the mean energy of the dataset. When None, disables the global shift. When a number, used directly.
Expand Down
2 changes: 1 addition & 1 deletion configs/minimal_stress.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dataset_include_frames: !!python/object/apply:builtins.range

global_rescale_scale: dataset_total_energy_std
per_species_rescale_shifts: dataset_per_atom_total_energy_mean
per_species_rescale_scales: dataset_per_atom_total_energy_std
per_species_rescale_scales: null

# logging
wandb: false
Expand Down
109 changes: 58 additions & 51 deletions nequip/model/_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ def RescaleEnergyEtc(
dataset=dataset,
initialize=initialize,
module_prefix="global_rescale",
default_scale=f"dataset_{AtomicDataDict.FORCE_KEY}_rms"
if AtomicDataDict.FORCE_KEY in model.irreps_out
else f"dataset_{AtomicDataDict.TOTAL_ENERGY_KEY}_std",
default_scale=(
f"dataset_{AtomicDataDict.FORCE_KEY}_rms"
if AtomicDataDict.FORCE_KEY in model.irreps_out
else f"dataset_{AtomicDataDict.TOTAL_ENERGY_KEY}_std"
),
default_shift=None,
default_scale_keys=AtomicDataDict.ALL_ENERGY_KEYS,
default_shift_keys=[AtomicDataDict.TOTAL_ENERGY_KEY],
Expand Down Expand Up @@ -129,42 +131,60 @@ def PerSpeciesRescale(
initialize: bool,
dataset: Optional[AtomicDataset] = None,
):
"""Add per-atom rescaling (and shifting) for energy.
If ``initialize`` is false, doesn't compute statistics.
"""
"""Add per-atom rescaling (and shifting) for per-atom energies."""
module_prefix = "per_species_rescale"

# = Determine energy rescale type =
scales = config.get(
module_prefix + "_scales",
f"dataset_{AtomicDataDict.FORCE_KEY}_rms"
# if `train_on_keys` isn't provided, assume conservatively
# that we aren't "training" on anything (i.e. take the
# most general defaults)
if AtomicDataDict.FORCE_KEY in config.get("train_on_keys", [])
else f"dataset_per_atom_{AtomicDataDict.TOTAL_ENERGY_KEY}_std",
)
shifts = config.get(
module_prefix + "_shifts",
f"dataset_per_atom_{AtomicDataDict.TOTAL_ENERGY_KEY}_mean",
)

# Check for common double shift mistake with defaults
if "RescaleEnergyEtc" in config.get("model_builders", []):
# if the defaults are enabled, then we will get bad double shift
# THIS CHECK IS ONLY GOOD ENOUGH FOR EMITTING WARNINGS
has_global_shift = config.get("global_rescale_shift", None) is not None
if has_global_shift:
if shifts is not None:
if config.get(module_prefix + "_shifts", True) is not None:
# using default of per_atom shift
raise RuntimeError(
"A global_rescale_shift was provided, but the default per-atom energy shift was not disabled."
)
del has_global_shift

# = Determine what statistics need to be computed =
arguments_in_dataset_units = None
return _PerSpeciesRescale(
scales_default=None,
shifts_default=f"dataset_per_atom_{AtomicDataDict.TOTAL_ENERGY_KEY}_mean",
field=AtomicDataDict.PER_ATOM_ENERGY_KEY,
out_field=AtomicDataDict.PER_ATOM_ENERGY_KEY,
module_prefix=module_prefix,
insert_before="total_energy_sum",
model=model,
config=config,
initialize=initialize,
dataset=dataset,
)


def _PerSpeciesRescale(
scales_default,
shifts_default,
field: str,
out_field: str,
module_prefix: str,
insert_before: str,
model: GraphModuleMixin,
config,
initialize: bool,
dataset: Optional[AtomicDataset] = None,
):
"""Add per-atom rescaling (and shifting) for a field
If ``initialize`` is false, doesn't compute statistics.
"""
scales = config.get(module_prefix + "_scales", scales_default)
shifts = config.get(module_prefix + "_shifts", shifts_default)

# = Determine what statistics need to be computed =
assert config.get(
module_prefix + "_arguments_in_dataset_units", True
), f"The PerSpeciesRescale builder is only compatible with {module_prefix + '_arguments_in_dataset_units'} set to True"

if initialize:
str_names = []
for value in [scales, shifts]:
Expand All @@ -181,20 +201,6 @@ def PerSpeciesRescale(
else:
raise ValueError(f"Invalid value `{value}` of type {type(value)}")

if len(str_names) == 2:
# Both computed from dataset
arguments_in_dataset_units = True
elif len(str_names) == 1:
if None in [scales, shifts]:
# if the one that isn't a str is null, it's just disabled
# that has no units
# so it's ok to have just one and to be in dataset units
arguments_in_dataset_units = True
else:
assert config[
module_prefix + "_arguments_in_dataset_units"
], "Requested to set either the shifts or scales of the per_species_rescale using dataset values, but chose to provide the other in non-dataset units. Please give the explictly specified shifts/scales in dataset units and set per_species_rescale_arguments_in_dataset_units"

# = Compute shifts and scales =
if len(str_names) > 0:
computed_stats = _compute_stats(
Expand All @@ -206,21 +212,24 @@ def PerSpeciesRescale(

if isinstance(scales, str):
s = scales
scales = computed_stats[str_names.index(scales)].squeeze(-1) # energy is 1D
# energy or other property is 1D:
scales = computed_stats[str_names.index(scales)].squeeze(-1)
logging.info(f"Replace string {s} to {scales}")
elif isinstance(scales, (list, float)):
scales = torch.as_tensor(scales)

if isinstance(shifts, str):
s = shifts
shifts = computed_stats[str_names.index(shifts)].squeeze(-1) # energy is 1D
# energy or other property is 1D:
shifts = computed_stats[str_names.index(shifts)].squeeze(-1)
logging.info(f"Replace string {s} to {shifts}")
elif isinstance(shifts, (list, float)):
shifts = torch.as_tensor(shifts)

# TODO kind of weird error to check for here
if scales is not None and torch.min(scales) < RESCALE_THRESHOLD:
raise ValueError(
f"Per species energy scaling was very low: {scales}. Maybe try setting {module_prefix}_scales = 1."
f"Per species scaling was very low: {scales}. Maybe try setting {module_prefix}_scales = 1."
)

logging.info(
Expand All @@ -234,22 +243,20 @@ def PerSpeciesRescale(
# so this is fine regardless of whether its trainable.
scales = 1.0 if scales is not None else None
shifts = 0.0 if shifts is not None else None
# values correctly scaled according to where they come from
# will be brought from the state dict later,
# so what you set this to doesnt matter:
arguments_in_dataset_units = False
# values from the previously initialized model
# will be brought in from the state dict later,
# so these values (and rescaling them) doesn't matter

# insert in per species shift
params = dict(
field=AtomicDataDict.PER_ATOM_ENERGY_KEY,
out_field=AtomicDataDict.PER_ATOM_ENERGY_KEY,
field=field,
out_field=out_field,
shifts=shifts,
scales=scales,
arguments_in_dataset_units=True,
)

params["arguments_in_dataset_units"] = arguments_in_dataset_units
model.insert_from_parameters(
before="total_energy_sum",
before=insert_before,
name=module_prefix,
shared_params=config,
builder=PerSpeciesScaleShift,
Expand Down
9 changes: 6 additions & 3 deletions nequip/nn/_atomwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ def __init__(
self.out_field = f"{reduce}_{field}" if out_field is None else out_field
self._init_irreps(
irreps_in=irreps_in,
irreps_out={self.out_field: irreps_in[self.field]}
if self.field in irreps_in
else {},
irreps_out=(
{self.out_field: irreps_in[self.field]}
if self.field in irreps_in
else {}
),
)

def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
Expand Down Expand Up @@ -204,6 +206,7 @@ def __init__(
else:
self.register_buffer("scales", torch.Tensor())

assert isinstance(arguments_in_dataset_units, bool)
self.arguments_in_dataset_units = arguments_in_dataset_units

# we can use FMA for performance but its type promotion is broken until 1.13
Expand Down

0 comments on commit 9b5b17c

Please sign in to comment.