-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor the convergence logic #373
Conversation
@@ -30,7 +30,7 @@ end | |||
initialization::Symbol = :add_perturbation, | |||
generative_model_params::NamedTuple = (;), | |||
min_success_rate::AbstractFloat=0.99, | |||
converge_when::Symbol=:decision_threshold, | |||
convergence::Union{AbstractConvergence,Symbol}=:decision_threshold, | |||
invalidation_rate::AbstractFloat=0.5, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Redundancies such as the unneccessary keyword arguments here will be removed as part of another PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM when tests pass and no major changes are implemented
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## main #373 +/- ##
==========================================
+ Coverage 77.04% 77.71% +0.66%
==========================================
Files 69 74 +5
Lines 1577 1557 -20
==========================================
- Hits 1215 1210 -5
+ Misses 362 347 -15 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well done! Left just one comment that needs to be addressed before I approve it.
convergence::Union{AbstractConvergence,Symbol}=Convergence.DecisionThresholdConvergence(; | ||
decision_threshold=(1 / length(data.y_levels)) | ||
), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't it always overwrites the convergence
value?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so - it's the same kind of default argument as the ones other keyword arguments here are using. If the user provides a different convergence
value, that will be used instead. The tests should also fail if this is always used as the convergence
value. Am I missing something here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks fine to me
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great, this looks good to me
src/convergence/invalidation_rate.jl
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As mentioned in #371, I don't think this should be handled as it's own convergence class, but rather under generator conditions. Happy to park that for later though, since we want to move on with this.
convergence::Union{AbstractConvergence,Symbol}=Convergence.DecisionThresholdConvergence(; | ||
decision_threshold=(1 / length(data.y_levels)) | ||
), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks fine to me
This pull request is part of the refactor to simplify the core struct and related methods. In it, I propose the following changes:
converged()
method. This removes a ton of kwargs from the main struct and makes the code clearer in general. For the more casual users, nothing changes after this refactor: they can still use the:converge_when
keyword argument to specify a convergence criterion with the default arguments. The more advanced user will need to be a bit more involved than before to define custom convergence criteria: they will have to define one of the new convergence structs on their own and pass that to thegenerate_counterfactual()
method. However, the improved clarity of the code seems to justify the added hassle for those users.converged()
method that had been created specifically for Growing Spheres. There's no reason Growing Spheres shouldn't fit into the pre-existing convergence framework.:early_stopping
convergence type, which appears to be functionally the same as the:max_iter
convergence type.