Skip to content

Commit

Permalink
add new type of model, which is a sequence of other models, allowing …
Browse files Browse the repository at this point in the history
…for much more flexibility
  • Loading branch information
wilsonmr committed Apr 28, 2021
1 parent c42a451 commit 1b487f0
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 4 deletions.
20 changes: 19 additions & 1 deletion anvil/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from anvil.geometry import Geometry2D
from anvil.checkpoint import TrainingOutput
from anvil.models import MODEL_OPTIONS
from anvil.models import MODEL_OPTIONS, LOADED_MODEL_OPTIONS
from anvil.distributions import BASE_OPTIONS, TARGET_OPTIONS

from random import randint
Expand Down Expand Up @@ -89,6 +89,24 @@ def produce_model_action(self, model: str):
except KeyError:
raise ConfigError(f"Invalid model {model}", model, MODEL_OPTIONS.keys())

@explicit_node
def produce_model_to_load(self, model: str, model_params):
"""Decides whether to load sequential model or a preset combination"""
if isinstance(model_params, list):
inner_models = {inner.get("model") for inner in model_params}
if ("sequential_model" in inner_models) or (None in inner_models):
raise ConfigError(
"Inner models cannot be undefined or `sequential_model`",
inner_models,
MODEL_OPTIONS.keys()
)
if model != "sequential_model":
raise ConfigError(
"model_params can only be a list when the model is `sequential_model`"
)
return LOADED_MODEL_OPTIONS["sequential_model"]
return LOADED_MODEL_OPTIONS["preset_model"]

def parse_n_batch(self, nb: int):
"""Batch size for training."""
return nb
Expand Down
7 changes: 5 additions & 2 deletions anvil/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
r"""
coupling.py
"""
core.py
Module containing project specific extensions to pytorch base classes.
"""
import torch
import torch.nn as nn
Expand Down
16 changes: 15 additions & 1 deletion anvil/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,28 @@ def affine_spline(real_nvp, rational_quadratic_spline):

_normalising_flow = collect("model_action", ("model_params",))

def model_to_load(_normalising_flow):
def preset_model(_normalising_flow):
return _normalising_flow[0]


def sequential_model(_normalising_flow):
"""action which wraps a list of affine models in
:py:class:`anvil.core.Sequential`. This allows the user to specify an
arbitrary combination of layers as the model
"""
return Sequential(*_normalising_flow)

MODEL_OPTIONS = {
"nice": nice,
"real_nvp": real_nvp,
"rational_quadratic_spline": rational_quadratic_spline,
"spline_affine": spline_affine,
"affine_spline": affine_spline,
}


LOADED_MODEL_OPTIONS = {
"preset_model": preset_model,
"sequential_model": sequential_model
}
46 changes: 46 additions & 0 deletions examples/runcards/train_sequential_model.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Example of how to specify a custom sequential model explicitly.

# Lattice
lattice_length: 6
lattice_dimension: 2

# Target
target: phi_four
parameterisation: albergo2019
couplings:
m_sq: -4
lam: 6.975

# Model
base: gaussian

model: sequential_model

model_params:
- model: real_nvp
n_affine: 2
z2_equivar: true
activation: tanh
hidden_shape: [72]
- model: rational_quadratic_spline
n_spline: 1
n_segments: 8
z2_equivar_spline: false
activation: tanh
hidden_shape: [72]

# Training
n_batch: 1000
epochs: 2000
save_interval: 1000

# Optimizer
optimizer: Adam
optimizer_params:
lr: 0.005

# Scheduler
scheduler: CosineAnnealingLR
scheduler_params:
T_max: 2000

0 comments on commit 1b487f0

Please sign in to comment.