Skip to content

Commit

Permalink
docs: use GridTraining for a couple examples which trains better and …
Browse files Browse the repository at this point in the history
…faster
  • Loading branch information
sathvikbhagavan committed Mar 4, 2024
1 parent 6e46d21 commit 5fb889a
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 21 deletions.
10 changes: 3 additions & 7 deletions docs/src/examples/linear_parabolic.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ w(t, 1) = \frac{e^{\lambda_1} cos(\frac{x}{a})-e^{\lambda_2}cos(\frac{x}{a})}{\l
with a physics-informed neural network.

```@example
using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimJL
using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimJL, LineSearches
using Plots
import ModelingToolkit: Interval, infimum, supremum
Expand Down Expand Up @@ -71,7 +71,7 @@ input_ = length(domains)
n = 15
chain = [Lux.Chain(Dense(input_, n, Lux.σ), Dense(n, n, Lux.σ), Dense(n, 1)) for _ in 1:2]
strategy = QuadratureTraining()
strategy = GridTraining(0.01)
discretization = PhysicsInformedNN(chain, strategy)
@named pdesystem = PDESystem(eqs, bcs, domains, [t, x], [u(t, x), w(t, x)])
Expand All @@ -92,7 +92,7 @@ callback = function (p, l)
return false
end
res = Optimization.solve(prob, BFGS(); callback = callback, maxiters = 5000)
res = Optimization.solve(prob, LBFGS(linesearch = BackTracking()); callback = callback, maxiters = 500)
phi = discretization.phi
Expand All @@ -110,9 +110,5 @@ for i in 1:2
p2 = plot(ts, xs, u_predict[i], linetype = :contourf, title = "predict")
p3 = plot(ts, xs, diff_u[i], linetype = :contourf, title = "error")
plot(p1, p2, p3)
savefig("sol_u$i")
end
```

![linear_parabolic_sol_u1](https://user-images.githubusercontent.com/26853713/125745625-49c73760-0522-4ed4-9bdd-bcc567c9ace3.png)
![linear_parabolic_sol_u2](https://user-images.githubusercontent.com/26853713/125745637-b12e1d06-e27b-46fe-89f3-076d415fcd7e.png)
12 changes: 2 additions & 10 deletions docs/src/examples/nonlinear_elliptic.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,14 @@ input_ = length(domains)
n = 15
chain = [Lux.Chain(Dense(input_, n, Lux.σ), Dense(n, n, Lux.σ), Dense(n, 1)) for _ in 1:6] # 1:number of @variables
strategy = QuadratureTraining()
strategy = GridTraining(0.01)
discretization = PhysicsInformedNN(chain, strategy)
vars = [u(x, y), w(x, y), Dxu(x, y), Dyu(x, y), Dxw(x, y), Dyw(x, y)]
@named pdesystem = PDESystem(eqs_, bcs__, domains, [x, y], vars)
prob = NeuralPDE.discretize(pdesystem, discretization)
sym_prob = NeuralPDE.symbolic_discretize(pdesystem, discretization)
strategy = NeuralPDE.QuadratureTraining()
discretization = PhysicsInformedNN(chain, strategy)
sym_prob = NeuralPDE.symbolic_discretize(pdesystem, discretization)
pde_inner_loss_functions = sym_prob.loss_functions.pde_loss_functions
bcs_inner_loss_functions = sym_prob.loss_functions.bc_loss_functions[1:6]
aprox_derivative_loss_functions = sym_prob.loss_functions.bc_loss_functions[7:end]
Expand All @@ -107,7 +103,7 @@ callback = function (p, l)
return false
end
res = Optimization.solve(prob, BFGS(); callback = callback, maxiters = 5000)
res = Optimization.solve(prob, BFGS(); callback = callback, maxiters = 100)
phi = discretization.phi
Expand All @@ -125,9 +121,5 @@ for i in 1:2
p2 = plot(xs, ys, u_predict[i], linetype = :contourf, title = "predict")
p3 = plot(xs, ys, diff_u[i], linetype = :contourf, title = "error")
plot(p1, p2, p3)
savefig("non_linear_elliptic_sol_u$i")
end
```

![non_linear_elliptic_sol_u1](https://user-images.githubusercontent.com/26853713/125745550-0b667c10-b09a-4659-a543-4f7a7e025d6c.png)
![non_linear_elliptic_sol_u2](https://user-images.githubusercontent.com/26853713/125745571-45a04739-7838-40ce-b979-43b88d149028.png)
6 changes: 2 additions & 4 deletions docs/src/tutorials/constraints.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ function norm_loss_function(phi, θ, p)
end
discretization = PhysicsInformedNN(chain,
QuadratureTraining(),
GridTraining(0.01),
additional_loss = norm_loss_function)
@named pdesystem = PDESystem(eq, bcs, domains, [x], [p(x)])
Expand All @@ -86,7 +86,7 @@ end
res = Optimization.solve(prob, LBFGS(), callback = cb_, maxiters = 400)
prob = remake(prob, u0 = res.u)
res = Optimization.solve(prob, BFGS(), callback = cb_, maxiters = 2000)
res = Optimization.solve(prob, BFGS(), callback = cb_, maxiters = 500)
```

And some analysis:
Expand All @@ -103,5 +103,3 @@ u_predict = [first(phi(x, res.u)) for x in xs]
plot(xs, u_real, label = "analytic")
plot!(xs, u_predict, label = "predict")
```

![fp](https://user-images.githubusercontent.com/12683885/129405830-3d00c24e-adf1-443b-aa36-6af0e5305821.png)

0 comments on commit 5fb889a

Please sign in to comment.