forked from bystrogenomics/bystro
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' of github.com:bystrogenomics/bystro into featur…
…e/api
- Loading branch information
Showing
24 changed files
with
2,309 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
""" | ||
This provides a basic template for any model that uses numpyro as an | ||
inference method. It has several methods that should be filled by any | ||
object extending the template, namely | ||
fit | ||
_fill_hp_options | ||
These are the methods for running the samples given data and providing | ||
hyperparameter selections, respectively. | ||
Objects | ||
------- | ||
_BaseNumpyroModel(object) | ||
This is the template Numpyro model | ||
""" | ||
import abc | ||
import cloudpickle # type: ignore | ||
import numpyro # type: ignore | ||
|
||
|
||
class _BaseNumpyroModel: | ||
def __init__(self, mcmc_options=None, hp_options=None): | ||
""" | ||
Parameters | ||
---------- | ||
Returns | ||
------- | ||
""" | ||
if mcmc_options is None: | ||
mcmc_options = {} | ||
if hp_options is None: | ||
hp_options = {} | ||
self.mcmc_options = self._fill_mcmc_options(mcmc_options) | ||
self.hp_options = self._fill_hp_options(hp_options) | ||
self.samples = None | ||
|
||
@abc.abstractmethod | ||
def fit(self, *args): | ||
""" | ||
Parameters | ||
---------- | ||
Returns | ||
------- | ||
""" | ||
|
||
def render_model(self): | ||
""" | ||
This provides a graphical representation of the model | ||
Parameters | ||
---------- | ||
Returns | ||
------- | ||
""" | ||
assert self._model is not None, "Fit model first" | ||
return numpyro.render_model(self._model, model_kwargs=self.model_kwargs) | ||
|
||
def pickle(self, path): | ||
""" | ||
This saves samples from a fit model | ||
Parameters | ||
---------- | ||
Returns | ||
------- | ||
""" | ||
assert self.samples is not None, "Fit model first" | ||
with open(path, "wb") as f: | ||
cloudpickle.dump(self.samples, f) | ||
|
||
def unpickle(self, path): | ||
""" | ||
This loads samples from a previously saved model | ||
Parameters | ||
---------- | ||
Returns | ||
------- | ||
""" | ||
with open(path, "rb") as f: | ||
self.samples = cloudpickle.load(f) | ||
|
||
def _fill_mcmc_options(self, mcmc_options): | ||
""" | ||
This fills in default MCMC options of the sampler. Further methods | ||
might override these but these are common/basic enough to leave in | ||
as an implemented method. | ||
Parameters | ||
---------- | ||
Returns | ||
------- | ||
""" | ||
default_options = { | ||
"num_chains": 1, | ||
"num_warmup": 500, | ||
"num_samples": 2000, | ||
} | ||
mopts = {**default_options, **mcmc_options} | ||
return mopts | ||
|
||
@abc.abstractmethod | ||
def _fill_hp_options(self, hp_options): | ||
""" | ||
This fills in default hyperparameters of the model. Since these are | ||
not conserved between models we leave this as an abstract method | ||
to be filled in per model. | ||
Parameters | ||
---------- | ||
Returns | ||
------- | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
""" | ||
This implements a base class for any model using stochastic gradient | ||
descent-based techniques for inference. | ||
Objects | ||
------- | ||
_BaseSGDModel(object): | ||
Methods | ||
------- | ||
None | ||
""" | ||
import abc | ||
import cloudpickle # type: ignore | ||
|
||
|
||
class _BaseSGDModel(object): | ||
def __init__(self, training_options=None): | ||
""" | ||
The base class of a model relying on stochastic gradient descent for | ||
inference | ||
""" | ||
if training_options is None: | ||
training_options = {} | ||
self.training_options = self._fill_training_options(training_options) | ||
|
||
@abc.abstractmethod | ||
def fit(self, *args, **kwargs): | ||
""" | ||
Method for fitting | ||
Parameters | ||
---------- | ||
*args: | ||
List of arguments | ||
*kwargs: | ||
Key word arguments | ||
""" | ||
|
||
def pickle(self, path): | ||
""" | ||
Method for saving the model | ||
Parameters | ||
---------- | ||
path : str | ||
The directory to save the model to | ||
""" | ||
mydict = {"model": self} | ||
with open(path, "wb") as f: | ||
cloudpickle.dump(mydict, f) | ||
|
||
@abc.abstractmethod | ||
def unpickle(self, path): | ||
""" | ||
Method for loading the model | ||
Parameters | ||
---------- | ||
path : str | ||
The directory to load the model from | ||
""" | ||
|
||
def _fill_training_options(self, training_options): | ||
""" | ||
This fills any relevant parameters for the learning algorithm | ||
Parameters | ||
---------- | ||
training_options : dict | ||
Returns | ||
------- | ||
training_opts : dict | ||
""" | ||
default_options = {"n_iterations": 5000} | ||
training_opts = {**default_options, **training_options} | ||
return training_opts | ||
|
||
@abc.abstractmethod | ||
def _save_variables(self, training_variables): | ||
""" | ||
This saves the final parameter values after training | ||
Parameters | ||
---------- | ||
training_variables :list | ||
The variables trained | ||
""" | ||
raise NotImplementedError("_save_variables") | ||
|
||
@abc.abstractmethod | ||
def _initialize_save_losses(self): | ||
""" | ||
This method initializes the arrays to track relevant variables | ||
during training | ||
Parameters | ||
---------- | ||
""" | ||
raise NotImplementedError("_initialize_save_losses") | ||
|
||
@abc.abstractmethod | ||
def _save_losses(self, *args): | ||
""" | ||
This saves the respective losses at each iteration | ||
Parameters | ||
---------- | ||
""" | ||
raise NotImplementedError("_save_losses") | ||
|
||
@abc.abstractmethod | ||
def _test_inputs(self, *args): | ||
""" | ||
This performs error checking on inputs for fit | ||
Parameters | ||
---------- | ||
""" | ||
raise NotImplementedError("_transform_training_data") | ||
|
||
@abc.abstractmethod | ||
def _transform_training_data(self, *args): | ||
""" | ||
This converts training data to adequate format | ||
Parameters | ||
---------- | ||
""" | ||
raise NotImplementedError("_transform_training_data") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.