
Refactor the convergence logic #373

Merged: 8 commits merged into main from refactor/372 on Dec 13, 2023
Conversation

@RaunoArike (Contributor) commented Dec 8, 2023

This pull request is part of the refactor to simplify the core struct and related methods. In it, I propose the following changes:

  • Create new structs for all the different convergence criteria, and dispatch on the types of these structs when calling the converged() method. This removes a ton of kwargs from the main struct and makes the code clearer in general. For casual users, nothing changes after this refactor: they can still use the :converge_when keyword argument to specify a convergence criterion with the default arguments. Defining custom convergence criteria becomes slightly more involved for advanced users: they will have to construct one of the new convergence structs themselves and pass it to the generate_counterfactual() method (see the sketch after this list). However, the improved clarity of the code seems to justify the added hassle for those users.
  • Remove the converged() method that had been created specifically for Growing Spheres. There's no reason Growing Spheres shouldn't fit into the pre-existing convergence framework.
  • Remove the :early_stopping convergence type, which appears to be functionally the same as the :max_iter convergence type.
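A rough sketch of the two usage patterns described above. The convergence keyword and the Convergence.DecisionThresholdConvergence constructor appear in the snippets in this PR; the package name, the positional arguments of generate_counterfactual(), and the threshold value are assumptions for illustration only:

using CounterfactualExplanations  # package name assumed from context

# x, target, data, M, and generator are assumed to be set up as usual for the package.

# Casual usage: the Symbol shorthand still works and picks the default arguments.
ce = generate_counterfactual(x, target, data, M, generator;
    convergence=:decision_threshold)

# Advanced usage: construct a convergence struct explicitly and pass it in.
conv = Convergence.DecisionThresholdConvergence(; decision_threshold=0.9)
ce = generate_counterfactual(x, target, data, M, generator;
    convergence=conv)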

@RaunoArike RaunoArike changed the title Refactor convergence logic Refactor the convergence logic Dec 8, 2023
@@ -30,7 +30,7 @@ end
     initialization::Symbol = :add_perturbation,
     generative_model_params::NamedTuple = (;),
     min_success_rate::AbstractFloat=0.99,
-    converge_when::Symbol=:decision_threshold,
+    convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
     invalidation_rate::AbstractFloat=0.5,
@RaunoArike (Contributor, Author) commented:

Redundancies such as the unnecessary keyword arguments here will be removed as part of another PR.
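For context, here is one way the Symbol shorthand accepted by the convergence keyword above could be resolved to a concrete convergence struct. The helper name and the struct names other than DecisionThresholdConvergence are assumptions (loosely based on the file names in the Codecov report below), not code taken from this PR:

# Hypothetical helper: map the Symbol shorthand to a convergence struct.
function resolve_convergence(convergence::Union{AbstractConvergence,Symbol}, y_levels)
    convergence isa AbstractConvergence && return convergence  # already a struct
    if convergence == :decision_threshold
        return Convergence.DecisionThresholdConvergence(;
            decision_threshold=(1 / length(y_levels)),
        )
    elseif convergence == :max_iter
        return Convergence.MaxIterConvergence()          # assumed struct name
    elseif convergence == :invalidation_rate
        return Convergence.InvalidationRateConvergence() # assumed struct name
    else
        error("Unknown convergence criterion: $convergence")
    end
end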

@VincentPikand (Contributor) left a comment:

LGTM when tests pass and no major changes are implemented

Resolved review thread on src/generators/gradient_based/loss.jl

codecov bot commented Dec 11, 2023

Codecov Report

Attention: 6 lines in your changes are missing coverage. Please review.

Comparison is base (2e36ea7) 77.04% compared to head (2fac9ca) 77.71%.

Files                                          Patch %   Missing
src/convergence/invalidation_rate.jl           90.47%    2 lines ⚠️
src/convergence/max_iter.jl                    33.33%    2 lines ⚠️
src/counterfactuals/latent_space_mappings.jl    0.00%    1 line  ⚠️
src/evaluation/measures.jl                      0.00%    1 line  ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #373      +/-   ##
==========================================
+ Coverage   77.04%   77.71%   +0.66%     
==========================================
  Files          69       74       +5     
  Lines        1577     1557      -20     
==========================================
- Hits         1215     1210       -5     
+ Misses        362      347      -15     


@RaunoArike RaunoArike marked this pull request as draft December 11, 2023 15:25
@RaunoArike RaunoArike self-assigned this Dec 11, 2023
@RaunoArike RaunoArike marked this pull request as ready for review December 11, 2023 15:25
@kmariuszk (Contributor) left a comment:

Well done! Left just one comment that needs to be addressed before I approve it.

Comment on lines +56 to +58
convergence::Union{AbstractConvergence,Symbol}=Convergence.DecisionThresholdConvergence(;
decision_threshold=(1 / length(data.y_levels))
),
Contributor commented:

Doesn't it always overwrite the convergence value?

@RaunoArike (Contributor, Author) replied:

I don't think so - it's the same kind of default argument as the ones the other keyword arguments here use. If the user provides a different convergence value, that will be used instead. The tests should also fail if this were always used as the convergence value. Am I missing something here?
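A minimal sketch of the point about Julia keyword defaults: the default expression is only evaluated when the caller omits the keyword. The function and struct definitions below are made up for illustration and are not code from this PR:

abstract type AbstractConvergence end
struct DecisionThresholdConvergence <: AbstractConvergence
    decision_threshold::Float64
end

# The default is only used when the caller omits the keyword argument.
f(; convergence::Union{AbstractConvergence,Symbol}=DecisionThresholdConvergence(0.5)) = convergence

f()                          # DecisionThresholdConvergence(0.5)
f(convergence=:max_iter)     # :max_iter (the default is not evaluated)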

Member commented:

This looks fine to me

@RaunoArike RaunoArike requested a review from kmariuszk December 12, 2023 18:26
@pat-alt (Member) left a comment:

Great, this looks good to me

Member commented:

As mentioned in #371, I don't think this should be handled as its own convergence class, but rather under generator conditions. Happy to park that for later, though, since we want to move on with this.

Comment on lines +56 to +58
convergence::Union{AbstractConvergence,Symbol}=Convergence.DecisionThresholdConvergence(;
decision_threshold=(1 / length(data.y_levels))
),
Member commented:

This looks fine to me

Resolved review thread on src/generators/gradient_based/loss.jl
@RaunoArike RaunoArike merged commit bbb2dde into main Dec 13, 2023
7 of 11 checks passed
@RaunoArike RaunoArike deleted the refactor/372 branch December 13, 2023 16:32