Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix benchmark collection gradient check #2564

Merged
merged 4 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion tests/benchmark-models/test_benchmark_collection.sh
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,22 @@ for model in $models; do
yaml="${model_dir}"/"${model}"/problem.yaml
fi

# problems we need to flatten
to_flatten=(
"Bruno_JExpBot2016" "Chen_MSB2009" "Crauste_CellSystems2017"
"Fiedler_BMCSystBiol2016" "Fujita_SciSignal2010" "SalazarCavazos_MBoC2020"
)
flatten=""
for item in "${to_flatten[@]}"; do
if [[ "$item" == "$model" ]]; then
flatten="--flatten"
break
fi
done

amici_model_dir=test_bmc/"${model}"
mkdir -p "$amici_model_dir"
cmd_import="amici_import_petab ${yaml} -o ${amici_model_dir} -n ${model} --flatten"
cmd_import="amici_import_petab ${yaml} -o ${amici_model_dir} -n ${model} ${flatten}"
cmd_run="$script_path/test_petab_model.py -y ${yaml} -d ${amici_model_dir} -m ${model} -c"

printf '=%.0s' {1..40}
Expand Down
29 changes: 18 additions & 11 deletions tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from amici.petab.petab_import import import_petab_problem
import benchmark_models_petab
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import dataclass, field
from amici import SensitivityMethod
from petab.v1.lint import measurement_table_has_timepoint_specific_mappings
from fiddy import MethodId, get_derivative
from fiddy.derivative_check import NumpyIsCloseDerivativeCheck
from fiddy.extensions.amici import simulate_petab_to_cached_functions
Expand Down Expand Up @@ -58,14 +59,18 @@ class GradientCheckSettings:
atol_consistency: float = 1e-5
rtol_consistency: float = 1e-1
# Step sizes for finite difference gradient checks.
step_sizes = [
1e-1,
5e-2,
1e-2,
1e-3,
1e-4,
1e-5,
]
step_sizes: list[float] = field(
default_factory=lambda: [
2e-1,
1e-1,
5e-2,
1e-2,
5e-1,
1e-3,
1e-4,
1e-5,
]
)
rng_seed: int = 0
ss_sensitivity_mode: amici.SteadyStateSensitivityMode = (
amici.SteadyStateSensitivityMode.integrateIfNewtonFails
Expand Down Expand Up @@ -97,7 +102,6 @@ class GradientCheckSettings:
noise_level=0.01,
atol_consistency=1e-3,
)
settings["Okuonghae_ChaosSolitonsFractals2020"].step_sizes.extend([0.2, 0.005])
settings["Oliveira_NatCommun2021"] = GradientCheckSettings(
# Avoid "root after reinitialization"
atol_sim=1e-12,
Expand Down Expand Up @@ -176,7 +180,10 @@ def test_benchmark_gradient(model, scale, sensitivity_method, request):
pytest.skip()

petab_problem = benchmark_models_petab.get_problem(model)
petab.flatten_timepoint_specific_output_overrides(petab_problem)
if measurement_table_has_timepoint_specific_mappings(
petab_problem.measurement_df,
):
petab.flatten_timepoint_specific_output_overrides(petab_problem)

# Only compute gradient for estimated parameters.
parameter_ids = petab_problem.x_free_ids
Expand Down
7 changes: 6 additions & 1 deletion tests/benchmark-models/test_petab_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
simulate_petab,
)
from petab.v1.visualize import plot_problem
from petab.v1.lint import measurement_table_has_timepoint_specific_mappings

logger = get_logger(f"amici.{__name__}", logging.WARNING)

Expand Down Expand Up @@ -115,7 +116,11 @@ def main():

# load PEtab files
problem = petab.Problem.from_yaml(args.yaml_file_name)
petab.flatten_timepoint_specific_output_overrides(problem)

if measurement_table_has_timepoint_specific_mappings(
problem.measurement_df
):
petab.flatten_timepoint_specific_output_overrides(problem)

# load model
if args.model_directory:
Expand Down
Loading