Skip to content

Commit

Permalink
Merge pull request #812 from sathvikbhagavan/sb/neuraladapter
Browse files Browse the repository at this point in the history
fix: bounds of variables in neural adapter setup
  • Loading branch information
ChrisRackauckas authored Mar 3, 2024
2 parents 9c068e9 + fa90f4a commit e68f778
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 237 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
- AdaptiveLoss
- Logging
- Forward
- NeuralAdapter
version:
- "1"
steps:
Expand Down
13 changes: 6 additions & 7 deletions src/neural_adapter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ function get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strateg
for d in domains])
args = get_argument(eqs, dict_indvars, dict_depvars)

bounds = map(args) do pd
bounds = first(map(args) do pd
span = map(p -> get(dict_span, p, p), pd)
map(s -> adapt(eltypeθ, s), span)
end
bounds
end)
bounds = [getindex.(bounds, 1), getindex.(bounds, 2)]
return bounds
end

function get_loss_function_(loss, init_params, pde_system, strategy::StochasticTraining)
Expand All @@ -46,8 +47,7 @@ function get_loss_function_(loss, init_params, pde_system, strategy::StochasticT
pde_system.depvars)

eltypeθ = eltype(init_params)
bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)[1]

bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)
get_loss_function(loss, bound, eltypeθ, strategy)
end

Expand All @@ -62,8 +62,7 @@ function get_loss_function_(loss, init_params, pde_system, strategy::QuasiRandom
pde_system.depvars)

eltypeθ = eltype(init_params)
bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)[1]

bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)
get_loss_function(loss, bound, eltypeθ, strategy)
end

Expand Down
1 change: 0 additions & 1 deletion src/training_strategies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ end
function get_loss_function(loss_function, bound, eltypeθ, strategy::StochasticTraining;
τ = nothing)
points = strategy.points

loss = (θ) -> begin
sets = generate_random_points(points, bound, eltypeθ)
sets_ = adapt(parameterless_type(ComponentArrays.getdata(θ)), sets)
Expand Down
Loading

0 comments on commit e68f778

Please sign in to comment.