Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Training Strategies to dae_solvers.jl #838

Closed
wants to merge 28 commits into from
Closed
Changes from 2 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
280a43d
Update dae_solve.jl
hippyhippohops Mar 26, 2024
911b3a3
Update dae_solve.jl
hippyhippohops Mar 26, 2024
9d98ae1
Update NNDAE_tests.jl
hippyhippohops Mar 28, 2024
35700c4
Added strategy::QuadratureTraining
hippyhippohops Mar 28, 2024
5773e91
Formatted indentation in strategy::WeightIntervalTraining
hippyhippohops Mar 28, 2024
c9f54de
Formatted Indentation in strategy::QuadratureTraining
hippyhippohops Mar 28, 2024
abc61bb
Refactored generate_losses in ode_solve.jl
hippyhippohops Mar 31, 2024
e4f06d5
Reverted back the ode_solve.jl to the previous set of codes
hippyhippohops Apr 1, 2024
0f4cfde
Edits to dae_solve.jl and NNDAE_tests.jl
hippyhippohops Apr 9, 2024
911d68c
Modified dae_solve.jl and NNDAE_tests
hippyhippohops Apr 12, 2024
2f4c505
Removed param_estim
hippyhippohops Apr 21, 2024
52cdea8
Update dae_solve.jl
hippyhippohops Apr 25, 2024
95457b9
Merge branch 'SciML:master' into patch-1
hippyhippohops Apr 30, 2024
d6f2e5f
Reset the code to match master code. Planning to start from scratch a…
hippyhippohops May 3, 2024
1ed7683
Implemented WeightedIntervalTraining and its test
hippyhippohops May 6, 2024
2f9db68
Formatted Code
hippyhippohops May 6, 2024
c2453d2
Added in failed Quadrature training
hippyhippohops May 8, 2024
7c6c2bf
trying to work out quadrature training strategy.
hippyhippohops May 16, 2024
70e0657
Stochastic training passes
hippyhippohops May 16, 2024
0098c6d
updates on NNDAE_tests.jl
hippyhippohops May 26, 2024
3e9473e
Updates
hippyhippohops May 26, 2024
92ec11c
Merge branch 'SciML:master' into patch-1
hippyhippohops Jun 4, 2024
b00c8cf
removing empty line
hippyhippohops Jun 4, 2024
afd05ee
Merge branch 'patch-1' of https://github.com/hippyhippohops/NeuralPDE…
hippyhippohops Jun 6, 2024
a4e2877
changes to quadrature training
hippyhippohops Jun 7, 2024
41dbf62
Added Quadrature training
hippyhippohops Jun 10, 2024
9c72fee
Changing to float64
hippyhippohops Jul 8, 2024
47b5aea
Merge branch 'SciML:master' into patch-1
hippyhippohops Jul 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions src/dae_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,43 @@ function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
return loss
end

function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p,
        differential_vars::AbstractVector)
    # Build a loss closure that, on every evaluation, resamples `strategy.points`
    # collocation times uniformly from `tspan` and sums the squared DAE residual
    # at those times. Returns the closure `(θ, _) -> loss_value`.
    autodiff && throw(ArgumentError("autodiff not supported for StochasticTraining."))
    function loss(θ, _)
        # Fresh uniform samples in [tspan[1], tspan[2]] each call; `adapt` moves
        # them to the same array type as the parameters `θ` (e.g. GPU arrays).
        ts = adapt(parameterless_type(θ),
            tspan[1] .+ (tspan[2] - tspan[1]) .* rand(strategy.points))
        return sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
    end
    return loss
end


function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan,
        p, differential_vars::AbstractVector)
    # Build a loss closure over a fixed, weighted sample of collocation times:
    # `tspan` is split into `length(weights)` equal sub-intervals and each
    # sub-interval receives a share of `strategy.points` samples proportional
    # to its (normalized) weight. Returns the closure `(θ, _) -> loss_value`.
    autodiff && throw(ArgumentError("autodiff not supported for WeightedIntervalTraining."))
    minT = tspan[1]
    maxT = tspan[2]

    # Normalize so the weights sum to one regardless of how the user scaled them.
    weights = strategy.weights ./ sum(strategy.weights)
    N = length(weights)
    difference = (maxT - minT) / N

    # Sample once, up front (unlike StochasticTraining, the points are not
    # redrawn per loss evaluation). Sub-interval `index` spans
    # [minT + (index - 1) * difference, minT + index * difference) and gets
    # trunc(points * weight) samples — truncation may drop a few points overall.
    ts = Float64[]
    for (index, weight) in enumerate(weights)
        interval_start = minT + (index - 1) * difference
        n_samples = trunc(Int, strategy.points * weight)
        append!(ts, interval_start .+ difference .* rand(n_samples))
    end

    function loss(θ, _)
        return sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
    end
    return loss
end

function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem,
alg::NNDAE,
args...;
Expand Down