Skip to content

Commit

Permalink
Feature/ppca (bystrogenomics#235)
Browse files Browse the repository at this point in the history
This implements the basic maximum likelihood PPCA model along with a
base class for all future PPCA/factor analysis models.
  • Loading branch information
austinTalbot7241993 authored Sep 6, 2023
1 parent c694619 commit 6dbf870
Show file tree
Hide file tree
Showing 10 changed files with 719 additions and 23 deletions.
18 changes: 12 additions & 6 deletions python/python/bystro/_template_npr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
Objects
-------
_BaseNPRModel(object)
BaseNumpyroModel(mcmc_options=None,hp_options=None)
This is the template Numpyro model
"""
Expand All @@ -19,7 +19,11 @@
import numpyro # type: ignore


class _BaseNumpyroModel:
class BaseNumpyroModel(abc.ABC):
"""
The template for a numpyro-based model
"""

def __init__(self, mcmc_options=None, hp_options=None):
"""
Expand Down Expand Up @@ -75,11 +79,12 @@ def pickle(self, path):
-------
"""
assert self.samples is not None, "Fit model first"
mydict = {"model": self}
with open(path, "wb") as f:
cloudpickle.dump(self.samples, f)
cloudpickle.dump(mydict, f)

def unpickle(self, path):
@classmethod
def unpickle(cls, path):
"""
This loads samples from a previously saved model
Expand All @@ -91,7 +96,8 @@ def unpickle(self, path):
"""
with open(path, "rb") as f:
self.samples = cloudpickle.load(f)
myDict = cloudpickle.load(f)
return myDict["model"]

def _fill_mcmc_options(self, mcmc_options):
"""
Expand Down
34 changes: 17 additions & 17 deletions python/python/bystro/_template_sgd_np.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
Objects
-------
_BaseSGDModel(object):
BaseSGDModel(training_options=None)
Methods
-------
Expand All @@ -14,12 +15,13 @@
import cloudpickle # type: ignore


class _BaseSGDModel(object):
class BaseSGDModel(abc.ABC):
"""
The base class of a model relying on stochastic gradient descent for
inference
"""

def __init__(self, training_options=None):
"""
The base class of a model relying on stochastic gradient descent for
inference
"""
if training_options is None:
training_options = {}
self.training_options = self._fill_training_options(training_options)
Expand Down Expand Up @@ -51,8 +53,8 @@ def pickle(self, path):
with open(path, "wb") as f:
cloudpickle.dump(mydict, f)

@abc.abstractmethod
def unpickle(self, path):
@classmethod
def unpickle(cls, path):
"""
Method for loading the model
Expand All @@ -61,6 +63,9 @@ def unpickle(self, path):
path : str
The directory to load the model from
"""
with open(path, "rb") as f:
myDict = cloudpickle.load(f)
return myDict["model"]

def _fill_training_options(self, training_options):
"""
Expand All @@ -79,16 +84,15 @@ def _fill_training_options(self, training_options):
return training_opts

@abc.abstractmethod
def _save_variables(self, training_variables):
def _store_instance_variables(self, trainable_variables):
"""
This saves the final parameter values after training
Saves the learned variables
Parameters
----------
training_variables :list
The variables trained
trainable_variables : list
List of variables to save
"""
raise NotImplementedError("_save_variables")

@abc.abstractmethod
def _initialize_save_losses(self):
Expand All @@ -99,7 +103,6 @@ def _initialize_save_losses(self):
Parameters
----------
"""
raise NotImplementedError("_initialize_save_losses")

@abc.abstractmethod
def _save_losses(self, *args):
Expand All @@ -109,7 +112,6 @@ def _save_losses(self, *args):
Parameters
----------
"""
raise NotImplementedError("_save_losses")

@abc.abstractmethod
def _test_inputs(self, *args):
Expand All @@ -119,7 +121,6 @@ def _test_inputs(self, *args):
Parameters
----------
"""
raise NotImplementedError("_transform_training_data")

@abc.abstractmethod
def _transform_training_data(self, *args):
Expand All @@ -129,4 +130,3 @@ def _transform_training_data(self, *args):
Parameters
----------
"""
raise NotImplementedError("_transform_training_data")
Loading

0 comments on commit 6dbf870

Please sign in to comment.