Models
Model objects encapsulate the mapping between parameters and the spectrum. The key components are

- A set of parameters and their properties, including limits, symbols, and units
- A function that returns a power spectrum as a function of model parameters (`objective()`)
To understand how a model is put together, we will now step through the definition of `bt.model.full`. The class definition is

```matlab
classdef full < bt.model.template
```

This means that the class `full` inherits member variables and methods from the `bt.model.template` class. The template class specifies a minimal interface for all BrainTrak models. Reading through `+bt/+model/template.m`, you can see a list of all of the variables and methods that must be implemented by the derived classes.
For the member variables, all models store:

- `name`, which identifies the model
- `n_params`, which specifies how many parameters the model has
- `param_names`, which are used when storing the parameters in structs. These names must satisfy normal Matlab variable conventions, i.e., starting with a letter, no spaces
- `param_symbols`, which are used when plotting. These can contain LaTeX commands for superscripts and subscripts etc.
- `param_units`, which are used in some of the axis labels when plotting
- `initial_step_size` - at the start of fitting, samples are drawn from a multivariate normal distribution. The `initial_step_size` defines the standard deviation of the proposal distribution for each parameter. If it is too large, the routine will take a long time to commence fitting, because many proposed steps will be rejected. If it is too small, the initial proposal distribution will be too uniform, and the adaptive proposal distribution method may take a long time to adapt. In general, however, `initial_step_size` should err on the small side
- `limits`, which stores lower and upper bounds for each model parameter
- `n_fitted`, which stores the number of parameters being fitted at a particular time step
- `skip_fit`, which stores which parameters are being fitted at a particular time step
- `electrodes`, which stores a cell array of electrode names. This is important when performing multielectrode fitting
There are also a number of 'temporary' variables that are used internally while fitting, but that are useful to store within the model object. These are
- `target_f`, the frequencies for the data being fitted
- `target_P`, an array of EEG power. This is a matrix for multielectrode fitting, in which case the columns correspond to the electrodes named in the `electrodes` variable
- `weights`, which specifies the frequency weighting when calculating the goodness of fit statistic
- `prior_pp`, which stores the priors
- `prior_size`, which stores the number of points in each of the priors
Note that the priors are stored as a struct, which contains three arrays, `x`, `y`, and `ndx`. These will be examined in more detail below.
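As a concrete but purely illustrative sketch of how these pieces fit together, a minimal model's property block might look like the following. The property names come from the lists above; all of the values are placeholders for a hypothetical two-parameter model, and whether a given property is declared in `template.m` or the subclass is not shown here:

```matlab
classdef mymodel < bt.model.template
    % Illustrative only - property names follow the lists above
    properties
        name = 'mymodel'                    % identifies the model
        n_params = 2                        % number of model parameters
        param_names = {'alpha','beta'}      % valid Matlab variable names
        param_symbols = {'\alpha','\beta'}  % LaTeX labels for plotting
        param_units = {'s^{-1}','s^{-1}'}   % appear in axis labels
        initial_step_size = [1 5]           % proposal std. dev. per parameter
        limits = [10 100; 100 800]          % one [lower upper] row per parameter
        n_fitted                            % number of free parameters, set when fitting
        skip_fit = [0 0]                    % nonzero entries are held fixed
        electrodes = {'Cz'}                 % cell array of electrode names
        target_f                            % frequencies being fitted (temporary)
        target_P                            % experimental power (temporary)
        weights                             % frequency weighting (temporary)
        prior_pp                            % struct of priors: x, y, ndx
        prior_size                          % points per prior
    end
end
```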
An abstract method is one that has a declaration but no implementation. It is used to specify an interface for a class that derived classes will implement. The `template` class contains the following abstract methods

- `p_from_params(self,params)`
- `params_from_p(self,p)`
- `objective(self,pars)`
- `initialize_fit(self,target_f,target_P)`
- `get_xyz(self,params)`
Notice that `template.m` does not contain any code for these functions. When a class like `bt.model.full` inherits from `bt.model.template`, it is required to define these functions (or else an error will be raised). For example, you can open `+bt/+model/full.m` and see that these functions are all defined. This structure has been used because all models need to be able to do these things, but in general they will all do them differently.
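In Matlab, this kind of interface is usually expressed as an abstract methods block. A sketch of how such declarations could look (the actual contents of `template.m` may differ, and the return values shown are assumptions):

```matlab
% Sketch of an abstract methods block - declarations only, no bodies
methods(Abstract)
    p = p_from_params(self,params)  % BrainTrak parameters -> model.params object
    params = params_from_p(self,p)  % model.params object -> BrainTrak parameters
    chisq = objective(self,pars)    % goodness of fit for a parameter set
    [initial_values,prior_pp] = initialize_fit(self,target_f,target_P)
    xyz = get_xyz(self,params)      % convert parameters to XYZ
end
```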
`p_from_params` takes in the parameters for the model, and returns a `model.params` object (from the `corticothalamic-model` repository) that corresponds to the BrainTrak parameters. For example, if you had a BrainTrak model that didn't fit alpha and beta, but instead fitted alpha and beta/alpha, then `p_from_params` would multiply alpha by beta/alpha and assign the result to the output `params` object. You can see this example in `full_b_ratio.m`.
`params_from_p` is the opposite of `p_from_params` - it takes in a `params` object, and returns the BrainTrak parameters for the model. One important question is what to do if your model introduces new parameters that aren't compatible with `model.params`. Notice that `p_from_params` instantiates a new `model.params` object, and `params_from_p` takes in a `params` argument. This means that you can define a new `params` class for use with your model, if you wish. Most of BrainTrak does not use these functions - they exist mainly for convenience when working interactively with the corticothalamic model, and also for interfacing with the spatial variations system that relies on `model.params`.
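To make the beta/alpha example concrete, here is a hypothetical pair of conversions (written as methods of the model class), loosely modeled on the description of `full_b_ratio.m`; the parameter ordering and field names are assumptions:

```matlab
function p = p_from_params(self,params)
    % params = [alpha, beta/alpha, ...]; recover beta for the params object
    p = model.params;               % instantiate a corticothalamic params object
    p.alpha = params(1);
    p.beta = params(1)*params(2);   % beta = alpha * (beta/alpha)
end

function params = params_from_p(self,p)
    % Inverse mapping: recompute the fitted ratio from the params object
    params = [p.alpha, p.beta/p.alpha];
end
```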
`objective()` is the core of the model - it takes in the model parameters, computes the spectrum, and then calculates the goodness-of-fit function. Note that `objective` needs to know the experimental spectrum in order to calculate the goodness-of-fit. Therefore, the experimental spectrum being fitted is also stored in the model object.
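The goodness-of-fit statistic is model-specific, so the following is only an illustrative sketch, assuming a weighted squared error in log power and a hypothetical `spectrum()` helper:

```matlab
function chisq = objective(self,pars)
    % Compute the model spectrum at the stored target frequencies
    P = self.spectrum(pars);  % hypothetical helper returning model power
    if any(~isfinite(P))
        chisq = NaN;          % unstable parameters -> zero probability
        return
    end
    % Weighted squared error against the stored experimental spectrum
    chisq = sum(self.weights(:).*(log(P(:)) - log(self.target_P(:))).^2);
end
```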
Different models need to be started with different initial conditions. `initialize_fit` returns the initial parameters and initial prior distributions to use at the start of fitting. Note that `initialize_fit` takes in the spectrum being fitted. Therefore, you can use an arbitrary function to calculate your initial parameters, taking into account the spectrum being fitted. For the corticothalamic model, an initial fit is performed against a set of precomputed spectra - see database fitting.
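For a simple model, a hedged sketch of `initialize_fit` might just return midpoint starting values and flat priors over the parameter limits. The output signature and the `prior_pp` layout follow the conventions described on this page, but are assumptions:

```matlab
function [initial_values,prior_pp] = initialize_fit(self,target_f,target_P)
    % Start each parameter at the midpoint of its allowed range
    initial_values = mean(self.limits,2).';
    % Flat priors: evenly spaced x, uniform density y, reciprocal spacing ndx
    for j = 1:self.n_params
        lo = self.limits(j,1); hi = self.limits(j,2);
        prior_pp.x(:,j) = linspace(lo,hi,self.prior_size).';
        prior_pp.y(:,j) = ones(self.prior_size,1)/(hi-lo); % integrates to 1
        prior_pp.ndx(j) = (self.prior_size-1)/(hi-lo);     % 1/dx
    end
end
```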
Many plots use XYZ, and these are particularly useful quantities to work with. Different models have different ways of computing XYZ, which is why each model implements `get_xyz`. For example, if the gains are fitted as in `full.m`, then you need to compute the values of XYZ. If you fit the gains directly, like in `reduced.m`, then the parameters can be returned directly.
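As a sketch, a gains-based model could implement `get_xyz` along these lines. The XYZ expressions below are the standard corticothalamic stability parameters from the literature, but the parameter ordering here is made up, so treat the details as assumptions and check against the corticothalamic-model repository:

```matlab
function xyz = get_xyz(self,params)
    % Hypothetical ordering: gains first, then rate constants alpha and beta
    Gee = params(1); Gei = params(2); Gese = params(3);
    Gesre = params(4); Gsrs = params(5);
    alpha = params(6); beta = params(7);
    X = Gee/(1 - Gei);                          % cortical loop
    Y = (Gese + Gesre)/((1 - Gsrs)*(1 - Gei));  % corticothalamic loop
    Z = -Gsrs*alpha*beta/(alpha + beta)^2;      % intrathalamic loop
    xyz = [X Y Z];
end
```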
The 'temporary' variables listed earlier change during the fitting process, e.g., if an artifact is encountered, or if `t0` is not being fitted at a particular timestep. These variables are generally assigned by `prepare_for_fit()`.
To see how model objects are used, we can go through a typical fitting process using `bt.core.fit_spectrum`.
Going through `fit_spectrum.m`, the first parts are mainly concerned with initializing the default arguments. In order, these are

- `debugmode`, which controls whether `chain.m` generates a traceplot. This is false by default
- `skip_fit`, which controls whether any of the parameters are held constant. This is empty (fit all parameters) by default
- `target_state`, which is stored in the output and used as a plot title. This is set to `N/A` by default, since some data does not have an associated sleep stage
- `npoints`, which is simply the default chain length
Now, we handle the initial parameter values and priors. When tracking, these are both obtained from the previous fit. At the first fit, these values must be chosen in a different way. The initial values and priors depend on the model, since different models have different parameters. Therefore, these are delegated to the model's `initialize_fit` method. So to get the initial parameters and priors, `fit_spectrum()` calls `model.initialize_fit()`.
Next, we need to select the electrodes. For a single electrode recording, it doesn't really matter what electrode is used, except that this may affect plotting further on. By default, `fit_spectrum.m` assumes the electrode is Cz, which is a reasonable assumption if the electrode is unknown. You can look in `electrode_positions.m` for a list of electrodes and their coordinates - whatever electrode you specify must be named in that file (you can add entries to it with arbitrary names and positions if you need to). For single-electrode models, `set_electrodes()` does pretty much nothing, as seen in `template.m`. For multi-electrode models, `set_electrodes()` uses `electrode_positions.m` to calculate the spatial positions of the electrodes, which is required as part of calculating the predicted power spectrum at those same positions. Since `set_electrodes()` is a method of the `model` object, if you need to do any extra operations depending on the name of the electrode, you can overload `set_electrodes()` in your model object.
After this initialization, we now get to the fitting itself. The first part of fitting is to call `model.prepare_for_fit()`. This prepares the model object to be used by `chain`. After `prepare_for_fit()` is called, the model is expected to be ready for `chain.m`. `template.m` has a minimal list of the tasks that `prepare_for_fit()` performs. These are
```matlab
self.n_fitted = sum(~self.skip_fit);
self.target_f = target_f(:);
self.target_P = target_P;
self.prior_pp = prior_pp;
self.set_cache(initial_values);
```
So the model's copy of the target frequencies and power spectra is updated, the number of fitted parameters is calculated, the priors are loaded into the model, and then `model.set_cache()` is called. Before going further, note that the chain is called using

```matlab
bt.core.chain(model,initial_values,npoints,debugmode,timelimit)
```
That is, `target_f`, `target_P`, and `prior_pp` are not passed to `chain.m` as arguments - they are passed in via the model object. This is because all three of these quantities are needed to compute the model's objective function, but `chain.m` is not the only place where this needs to be done. Therefore, after `prepare_for_fit()` is called, you can get the model to compute the objective (or the spectrum) given a set of parameters without providing any additional information.
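In other words, once prepared, the model can be used standalone. A sketch of this usage, where the `prepare_for_fit` argument order is an assumption based on the quantities it stores:

```matlab
% Hypothetical argument order - prepare_for_fit stores these on the model
model.prepare_for_fit(target_f,target_P,prior_pp,initial_values);
% The objective can now be evaluated with nothing but a parameter vector
chisq = model.objective(pars);
```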
The `set_cache()` function is implemented by the model class. The idea is simple - in order to compute the power spectrum, you need to perform several calculations. Some of them, like computing `(1-iw/gamma)^2`, only need to be done once (as long as `gamma` is not being fitted). Other parts of the calculation, like computing `L`, need to be recalculated because they depend on the model's fitted parameters. `chain.m` will call the objective function many thousands of times. Therefore, any tasks that do not change based on the fitted parameters can be run in `set_cache()`. `set_cache()` is run before every fit, because the fitted parameters may change over the course of tracking. For example, `exp(iwt_0)` needs to be recalculated if `t_0` is being fitted, but in some models, `t_0` is only fitted when the alpha peak in the experimental data is larger than some threshold. Therefore, `set_cache()` in `full.m` checks whether `t_0` is being fitted, and if not, it calculates `exp(iwt_0)`. Otherwise, `exp(iwt_0)` is calculated in `objective()`.
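A hedged sketch of this caching pattern, with made-up property names for the cached terms:

```matlab
function set_cache(self,initial_values)
    w = 2*pi*self.target_f;   % angular frequencies for the fit
    % Cached on the assumption that gamma is not fitted in this model
    self.gamma_term = (1 - 1i*w/self.gamma).^2;   % hypothetical property
    % Cache exp(1i*w*t0) only if t0 is held fixed during this fit
    t0_index = find(strcmp(self.param_names,'t0'));
    if self.skip_fit(t0_index)
        self.t0_term = exp(1i*w*initial_values(t0_index));
    else
        self.t0_term = [];    % recomputed inside objective() instead
    end
end
```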
Sometimes, other tasks need to be performed by `prepare_for_fit`. For example, as discussed above, `bt.model.full` will only fit `t_0` if the alpha peak is above a certain threshold. This needs to be checked in `prepare_for_fit`. Therefore, `full.m` overloads `prepare_for_fit`, but then calls `bt.model.template.prepare_for_fit()` to ensure all of the essential operations are also performed, as shown in the sketch below.
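In Matlab, an overloaded method calls its superclass implementation with the `method@Superclass` syntax. A sketch of the pattern, where the argument list and the `has_alpha_peak` helper are hypothetical:

```matlab
function prepare_for_fit(self,target_f,target_P,varargin)
    % Model-specific step: only fit t0 when the alpha peak is large enough
    t0_index = strcmp(self.param_names,'t0');
    self.skip_fit(t0_index) = ~has_alpha_peak(target_f,target_P);
    % Then perform the essential template operations as well
    prepare_for_fit@bt.model.template(self,target_f,target_P,varargin{:});
end
```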
`chain.m` computes the random walk - see here for more information about how this works. The model object is used by the chain in several places. Most importantly, the model specifies

- The probability of a set of parameters. This is typically obtained using `template.probability()`. Note that this method in turn calls `self.objective(pars)` and `self.eval_priors(pars)`. That is, the specific model's objective function and prior function are called by the template `probability()` method. So, when `chain.m` calls `model.probability()`, this ends up calling `objective()` (which uses the stored `target_f` and `target_P` to compute the goodness of fit statistic) and `eval_priors()` (which uses the stored `prior_pp` to evaluate the priors)
- The objective function, which returns the goodness of fit statistic for a set of parameters
- The `eval_priors()` function, which returns `p(pars)`, i.e. the product of the prior probabilities of each parameter. Note that for simplicity and tractability, the marginal distributions for each parameter are used, but future work could examine using the full joint distribution as the prior
- The `validate()` method, which performs any last minute checks regarding the internal state of the model. For example, it currently guards against having a zero-frequency component (which would result in an infinite weighting for that frequency)
- The `validate_params()` method, which checks if all of the parameters are within their allowed ranges. If the parameters are outside their allowed ranges, then the probability of those parameters is zero
The method `template.probability()` uses `self.validate_params()` to return zero probability if the parameters are outside their allowed ranges. If the parameters are within their ranges, `self.objective()` is called. If the parameters are unstable, then `self.objective()` returns `NaN`, which then causes `template.probability()` to return zero. So when implementing a new model, you should provide the following methods (see the sketch after this list)

- `validate_params`, which takes in the parameters, and returns `true` if the parameters are acceptable for testing in `objective`, and `false` otherwise
- `objective`, which takes in the parameters, and either returns `NaN`, which will result in a probability of zero, or returns `chisq`, which is then converted to a probability after exponentiation and multiplication by the priors
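Putting this together, the logic of `template.probability()` presumably resembles the following sketch; the structure follows the description above, though the exact details may differ:

```matlab
function prob = probability(self,pars)
    % Zero probability outside the allowed parameter ranges
    if ~self.validate_params(pars)
        prob = 0;
        return
    end
    chisq = self.objective(pars);   % NaN signals unstable parameters
    if isnan(chisq)
        prob = 0;
        return
    end
    % Exponentiate the goodness of fit and multiply by the priors
    prob = exp(-chisq)*self.eval_priors(pars);
end
```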
After the chain is completed, we return to `bt.core.fit_spectrum`. The output steps are analyzed, and the parameters with the highest probability are selected. The model is then used to compute the posterior marginal distributions. These tasks are performed via the `make_posterior()` method and the `xyz_posterior()` method. Both of these functions are provided by `template.m` and rarely need to be modified.
The prior/posterior distributions are evaluated by linear interpolation. The distinction between posterior and prior is somewhat blurred, because the posterior at time `t` is used as the prior at time `t+1` (they are exactly the same). This part of BrainTrak is highly performance-sensitive, because the prior for each parameter is evaluated for each proposed step. For example, a model with 10 parameters, tested with a chain of length 100000, will require around 1 million evaluations of the priors.
The function `make_posterior()` produces a struct called `posterior_pp`, which contains 3 fields. These are

- `x`, which is a matrix with `self.prior_size` rows and `self.n_params` columns, and corresponds to the value of the parameter
- `y`, which is a matrix the same size as `x`, corresponding to the probability
- `ndx`, which is the reciprocal of the difference between adjacent `x` values
As seen in `template.make_posterior`, for each parameter

- A range of x-values is created using `linspace()` between the lower and upper limits for that parameter
- The `hist` function is used to estimate the marginal distribution of the parameter. The bin centers are provided
- Linear interpolation is used to estimate the probability at the bin edges themselves
- Any probabilities less than zero (which would be due to the interpolation) are set to zero
- The y values are normalized using `trapz()` so that the integral over `x` is 1 (which makes `y` a probability density)
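A condensed sketch of these steps for the j-th parameter, using the standard Matlab functions named above; `samples` holds the accepted chain steps, and the variable names are illustrative:

```matlab
% Sketch of the per-parameter steps in template.make_posterior
x = linspace(self.limits(j,1),self.limits(j,2),self.prior_size); % x values
centers = (x(1:end-1) + x(2:end))/2;             % histogram bin centers
counts = hist(samples(:,j),centers);             % marginal distribution estimate
y = interp1(centers,counts,x,'linear','extrap'); % probability at the x values
y(y < 0) = 0;                    % clip undershoot from the interpolation
y = y/trapz(x,y);                % normalize so the integral over x is 1
posterior_pp.x(:,j) = x(:);
posterior_pp.y(:,j) = y(:);
posterior_pp.ndx(j) = 1/(x(2)-x(1)); % reciprocal spacing, used by eval_priors
```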
The next question is, how do we use these data to perform the interpolation? The method for this is contained in `eval_priors()`. When an `x` value is requested, the `floor()` function is used to find the index in `prior_pp.x` preceding the requested value - this is stored as `fxi`. The requested value of `x` then lies between `fxi` and `fxi+1`. Linear interpolation is then used between these two indexes. Note that `ndx` appears when computing `fxi`, which is why it is stored in `prior_pp` - knowing that the `x` values are evenly spaced means that the array index can be found using the `floor()` function, rather than requiring a binary search. The matrix operations in `eval_priors()` enable all of the priors to be evaluated by linear interpolation simultaneously, which makes the interpolation very efficient even when there are many parameters.
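A scalar sketch of this lookup for one parameter (the real `eval_priors()` vectorizes this over all parameters at once; `xq` is the requested value):

```matlab
% Evenly spaced x values: the bracketing index comes from floor(), not a search
fxi = floor((xq - prior_pp.x(1,j))*prior_pp.ndx(j)) + 1;  % index preceding xq
y1 = prior_pp.y(fxi,j);
y2 = prior_pp.y(fxi+1,j);
% Linear interpolation between indexes fxi and fxi+1
p = y1 + (xq - prior_pp.x(fxi,j))*prior_pp.ndx(j)*(y2 - y1);
```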
The method `xyz_posterior` is analogous to `make_posterior`, except that it computes the posterior distributions for the XYZ quantities rather than the model's parameters. It relies on `model.get_xyz()` to convert the parameters into XYZ, and then `model.make_posterior` to compute the posteriors.
You can also construct dummy models that test various aspects of BrainTrak. For example, `bt.model.dummy` tests whether the posterior distributions are calculated correctly, by explicitly setting the underlying distributions. You can use it with

```matlab
f = bt.core.fit_spectrum(bt.model.dummy,1,1,[],[],1e6)
f.plot()
```

to verify that the marginal distributions match the actual distributions.