-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Make models inherit from base model #176
Conversation
@effigies @oesteban Following this comment #166 (comment), I had a look at inheriting the models from The main point is that the This is also related to issue #174. So comments would be appreciated. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this will make the code much more readable -- left some comments.
df51602
to
5d58f56
Compare
src/eddymotion/model/base.py
Outdated
|
||
def _exec_fit(model, data, chunk=None): | ||
retval = model.fit(data) | ||
return retval, chunk | ||
|
||
|
||
def _exec_predict(model, gradient, chunk=None, **kwargs): | ||
def _exec_predict_dwi(model, gradient, chunk=None, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I renamed this method to contain the dwi
label, as it requires a gradient and optionally uses a S0
argument. If gradient
in reality should be an index
, it may be renamed back. Also, as things are right now, I do not see the need to pass an S0
either.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should move this back to be general, and make gradient an index -- will work on this through a PR to this branch.
src/eddymotion/model/base.py
Outdated
|
||
gradient = _rasb2dipy(gradient) | ||
self._gtab = _rasb2dipy(self._gtab) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure about this: gradient
contains the whole gtab according to _rasb2dipy
: https://github.com/nipreps/eddymotion/pull/176/files#diff-a875f501910044a7d95658fb83740e2c5c6c1693e7e6808703d282441db82be8L412
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there is a naming confusion here: the models expect a RAS+b gradient object (which is different from the dipy
gtab
object) into a gtab
parameter. Am I correct @oesteban ?
src/eddymotion/model/base.py
Outdated
"""Predict asynchronously chunk-by-chunk the diffusion signal.""" | ||
if self._b_max is not None: | ||
gradient[-1] = min(gradient[-1], self._b_max) | ||
index[-1] = min(index[-1], self._b_max) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is not OK. If the gradients are capped, not sure how the indices get affected/how they should be checked.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is capping the b value (only the last item of gradient
. For some models, very high b-values 'saturate' and it's better to model as if they were lower. This only kicks in after setting b_max so you need to be explicit about it.
src/eddymotion/model/base.py
Outdated
((gtab[3] >= self._th_low) & (gtab[3] <= self._th_high)) | ||
if gtab is not None | ||
((self._gtab[3] >= self._th_low) & (self._gtab[3] <= self._th_high)) | ||
if self._gtab is not None | ||
else np.ones((data.shape[-1],), dtype=bool) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd dare to say that self._gtab
will not be None
, so this if/else
block is not necessary to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we do want to use the input gtab, as opposed to the global gtab, which potentially contains the left-out gradient.
Gave this another go. Adjusted some docstrings. More questions/comments (inline and below, long to digest, sorry):
|
Sorry to ping you again this morning @oesteban. |
|
||
model_str = getattr(self, "_model_class", None) | ||
if not model_str: | ||
raise TypeError("No model defined") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@oesteban Some tests are now failing because the _model_class
is None
in this base class, and I am not setting any particular value in the derived classes. What is this property supposed to contain?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am confused about the use of this. I now see that _model_class
and _modelargs
are properties of the wrappers DTIModel
and DKIModel
. This adds to this docstring https://github.com/nipreps/eddymotion/pull/176/files#diff-a875f501910044a7d95658fb83740e2c5c6c1693e7e6808703d282441db82be8L79
If I say here
kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs}
from importlib import import_module
model_str = "eddymotion.model.AverageDWModel"
module_name, class_name = model_str.rsplit(".", 1)
my_model = getattr(import_module(module_name), class_name)(**kwargs)
and leaving aside that AverageDWModel
requires a gtab
(as it inherits from BaseDWIModel
; setting it to None
in its init
would make it), the above statement produces a recursive call, since instantiating AverageDWModel
calls the superclass init
method.
So I am not following what was intended with this block.
Also, I am not sure what we want to do here with the DTI and DKI wrappers either.
Edit: if the DTI/DKI wrappers make sense here, it looks as if the BaseDWIModel
should not inherit from BaseModel
, or at least, the init
method of the latter and its docstring suggest that it is intended to be a superclass for the wrapper classes; however, the TrivialB0Model
, AverageDWModel
, etc. are not intended to be wrappers around dipy
objects, and it does not make sense IMO for them to have _model_class
and _modelargs
properties. So there seems to be 2 things that are mixed here. The model factory will also need to be adapted following all this.
@oesteban Can you please clarify these aspects?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and leaving aside that AverageDWModel requires a gtab (as it inherits from BaseDWIModel;
The other two tests fail because of this reason.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll have a look ASAP - sorry for my slow turnaround
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delayed response. I can now answer this question. I'm sorry for this particular case—you are prey to an undocumented feature.
The idea was that models can be fit in two ways:
- Pure leave-one-out fashion: at every iteration of the Estimator, a fully-fledged model is fit without the particular index/orientation. This is typically very slow.
- Single model: the model is fit on all the data, and each iteration produces the left-out index. These are enable by adding the prefix
Full
to the model name.
This is implemented in the Estimator, under the understanding that the model is the same, what changes is how you use it.
eddymotion/src/eddymotion/estimator.py
Lines 117 to 135 in ce8de17
single_model = model.lower() in ( | |
"b0", | |
"s0", | |
"avg", | |
"average", | |
"mean", | |
) or model.lower().startswith("full") | |
dwmodel = None | |
if single_model: | |
if model.lower().startswith("full"): | |
model = model[4:] | |
# Factory creates the appropriate model and pipes arguments | |
dwmodel = ModelFactory.init( | |
model=model, | |
**kwargs, | |
) | |
dwmodel.fit(dwdata.dataobj, n_jobs=n_jobs) |
ATM I cannot comment on why this has some effect on the model itself so a _model_class
is necessary, I can't recall the reason. I bet it is just to inform the estimator that fit should not be called every time (which probably should be handled here!)
That said, let's take the average model for example. When instantiated as FullAverage
, then it is fit only once before entering the iterator loop of the estimator. If not, at every iteration an average without the particular direction will be calculated in the fit call.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I'll leave my above comment because it explains something useful --- but it is totally unrelated to @jhlegarreta's question. Apologies for the confusion.
After working on the PR and re-reading the code, I understand that _model_class
and _modelargs
enable using DIPY models without much overhead (see DKI and DTI at the end).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have to checkout this code to make a better review of it. A nit pick for the time being.
|
||
model_str = getattr(self, "_model_class", None) | ||
if not model_str: | ||
raise TypeError("No model defined") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll have a look ASAP - sorry for my slow turnaround
Make models inherit from base model.
5d58f56
to
324d2ee
Compare
Let's get #166 over the final line and then I move onto this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like where this is going. I'm going to add the docstring of constants and then work locally on this PR.
Improving the documentation of constants. cc/ @jhlegarreta
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I responded to the two major questions in this PR. Happy to chat about the _model_class
as it seems the feature may not be implemented in an intuitive way and it is definitely not sufficiently documented.
src/eddymotion/model/base.py
Outdated
|
||
def _exec_fit(model, data, chunk=None): | ||
retval = model.fit(data) | ||
return retval, chunk | ||
|
||
|
||
def _exec_predict(model, gradient, chunk=None, **kwargs): | ||
def _exec_predict_dwi(model, gradient, chunk=None, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should move this back to be general, and make gradient an index -- will work on this through a PR to this branch.
|
||
model_str = getattr(self, "_model_class", None) | ||
if not model_str: | ||
raise TypeError("No model defined") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delayed response. I can now answer this question. I'm sorry for this particular case—you are prey to an undocumented feature.
The idea was that models can be fit in two ways:
- Pure leave-one-out fashion: at every iteration of the Estimator, a fully-fledged model is fit without the particular index/orientation. This is typically very slow.
- Single model: the model is fit on all the data, and each iteration produces the left-out index. These are enable by adding the prefix
Full
to the model name.
This is implemented in the Estimator, under the understanding that the model is the same, what changes is how you use it.
eddymotion/src/eddymotion/estimator.py
Lines 117 to 135 in ce8de17
single_model = model.lower() in ( | |
"b0", | |
"s0", | |
"avg", | |
"average", | |
"mean", | |
) or model.lower().startswith("full") | |
dwmodel = None | |
if single_model: | |
if model.lower().startswith("full"): | |
model = model[4:] | |
# Factory creates the appropriate model and pipes arguments | |
dwmodel = ModelFactory.init( | |
model=model, | |
**kwargs, | |
) | |
dwmodel.fit(dwdata.dataobj, n_jobs=n_jobs) |
ATM I cannot comment on why this has some effect on the model itself so a _model_class
is necessary, I can't recall the reason. I bet it is just to inform the estimator that fit should not be called every time (which probably should be handled here!)
That said, let's take the average model for example. When instantiated as FullAverage
, then it is fit only once before entering the iterator loop of the estimator. If not, at every iteration an average without the particular direction will be calculated in the fit call.
@oesteban Have gone through the comments. Will wait after this #176 (comment). The main difficulty to make this work now lies in https://github.com/nipreps/eddymotion/pull/176/files#r1616403642. Although you answered to the thread, not sure if the question was addressed: the point is that I do not see why |
* enh: revise code * sty: ruff format
* enh: revise code * sty: ruff format
47cd1e6
to
250d23a
Compare
* enh: revise code * sty: ruff format
250d23a
to
abae1cc
Compare
* enh: revise code * sty: ruff format
abae1cc
to
ee73cf9
Compare
* enh: revise code * sty: ruff format
ee73cf9
to
104bc3d
Compare
* enh: revise code * sty: ruff format
104bc3d
to
a768e72
Compare
@oesteban Had to adjust the test: due to this now being applied to the |
Do not overwrite the gradient table in prediction. Co-authored-by: Oscar Esteban <[email protected]>
Make models inherit from base model.