# adjoint_nonlin_diff_1D_v2.jl — 179 lines (168 loc) · 5.31 KB
# (GitHub page chrome and line-number gutter from the original scrape removed.)
using Enzyme,Plots,Printf
# Evaluate the residual of the nonlinear diffusion operator on interior cells:
#     R[i] = ∂²/∂x² ( H^(n+1) ) / (n[i]+1)
# discretized with a central three-point stencil of spacing `dx`, where the
# exponent field `npow` varies per cell. Boundary entries R[1] and R[end]
# are never written (they keep whatever value they already hold).
function residual!(R, H, npow, dx)
    @inbounds @simd for i = 2:length(R)-1
        qL = H[i-1]^(npow[i-1]+1.0)   # left  nonlinear "flux potential"
        qC = H[i]^(npow[i]+1.0)       # center
        qR = H[i+1]^(npow[i+1]+1.0)   # right
        R[i] = (qL - 2.0*qC + qR)/dx/dx/(npow[i]+1.0)
    end
    return
end
# Return the least-squares misfit J = 0.5 * Σᵢ (H[i] - H_obs[i])².
function cost(H, H_obs)
    acc = 0.0
    @inbounds @simd for i in eachindex(H)
        acc += abs2(H[i] - H_obs[i])
    end
    return acc/2
end
# State and parameters of the damped pseudo-transient forward solve.
#   H      — solution field, updated in place by solve!
#   R      — damped update rate (pseudo-velocity)
#   dR     — residual of the nonlinear diffusion operator
#   npow   — spatially variable power-law exponent field
#   niter  — iteration cap; ncheck — convergence-check interval
#   ϵtol   — residual tolerance; dx — grid spacing; dmp — damping parameter
mutable struct ForwardProblem{T<:Real,A<:AbstractArray{T}}
    H::A; R::A; dR::A; npow::A
    niter::Int; ncheck::Int; ϵtol::T
    dx::T; dmp::T
end
# Outer constructor: allocate the rate and residual work arrays to match H.
function ForwardProblem(H, npow, niter, ncheck, ϵtol, dx, dmp)
    rate, resid = similar(H), similar(H)
    return ForwardProblem(H, rate, resid, npow, niter, ncheck, ϵtol, dx, dmp)
end
# Drive the forward nonlinear-diffusion problem to steady state with a
# damped pseudo-transient iteration: the residual dR is folded into a rate
# R (damping factor 1-dmp/nx) and H is advanced in pseudo-time until
# max|dR| < ϵtol. Errors out if the residual turns non-finite or niter
# iterations are exhausted; prints a convergence summary on success.
function solve!(problem::ForwardProblem)
    (;H,R,dR,npow,niter,ncheck,ϵtol,dx,dmp) = problem
    nx = length(H)
    dt = dx/6                     # pseudo-time step
    fill!(R, 0.0); fill!(dR, 0.0)
    err = 2ϵtol
    it  = 1
    while err >= ϵtol && it < niter
        residual!(dR, H, npow, dx)
        @. R = R*(1.0-dmp/nx) + dt*dR   # damped rate accumulation
        @. H += dt*R                    # advance the solution
        if it % ncheck == 0
            err = maximum(abs.(dR))
            isfinite(err) || error("forward solve failed")
        end
        it += 1
    end
    (it == niter && err >= ϵtol) && error("forward solve not converged")
    @printf("  forward solve converged: #iter/nx = %.1f, err = %.1e\n",it/nx,err)
    return
end
# State and parameters of the adjoint (dual) solve.
#   Ψ        — adjoint state; R — damped update rate; dR — adjoint residual
#   tmp1/tmp2 — Enzyme scratch arrays (primal output / adjoint seed)
#   ∂J_∂H    — gradient of the misfit w.r.t. the state H
#   H, H_obs, npow — forward solution, observations, exponent field
#   niter/ncheck/ϵtol/dx/dmp — same meaning as in ForwardProblem
mutable struct AdjointProblem{T<:Real,A<:AbstractArray{T}}
    # FIX: `dR` was declared without a type (implicitly ::Any), which boxes
    # the array and makes every access to it type-unstable; annotate it ::A
    # like every other array field.
    Ψ::A; R::A; dR::A; tmp1::A; tmp2::A; ∂J_∂H::A; H::A; H_obs::A; npow::A
    niter::Int; ncheck::Int; ϵtol::T
    dx::T; dmp::T
end
# Outer constructor: allocate the adjoint state, its rate, and the Enzyme
# scratch/gradient buffers, all shaped like H.
function AdjointProblem(H, H_obs, npow, niter, ncheck, ϵtol, dx, dmp)
    Ψ, R, dR, tmp1, tmp2, ∂J_∂H = ntuple(_ -> similar(H), 6)
    return AdjointProblem(Ψ, R, dR, tmp1, tmp2, ∂J_∂H, H, H_obs, npow,
                          niter, ncheck, ϵtol, dx, dmp)
end
# Solve the adjoint (dual) problem for Ψ with the same damped
# pseudo-transient scheme as the forward solve. Each iteration forms
#   dR = -∂J/∂H + (∂residual/∂H)ᵀ Ψ
# by seeding the shadow dR with -∂J/∂H and letting Enzyme accumulate the
# vector-Jacobian product into it; tmp2 carries the adjoint seed Ψ
# (consumed by Enzyme) and tmp1 receives the discarded primal output.
function solve!(problem::AdjointProblem)
    (;Ψ,R,dR,tmp1,tmp2,∂J_∂H,H,H_obs,npow,niter,ncheck,ϵtol,dx,dmp) = problem
    nx = length(Ψ)
    dt = dx/3                   # pseudo-time step (larger than forward's dx/6)
    Ψ .= 0; R .= 0; dR .= 0
    @. ∂J_∂H = H - H_obs        # misfit gradient w.r.t. the state
    merr = 2ϵtol; iter = 1
    while merr >= ϵtol && iter < niter
        # re-seed every iteration: shadows are accumulated into by Enzyme
        dR .= .-∂J_∂H; tmp2 .= Ψ
        # NOTE(review): this is the old Enzyme.autodiff API (no explicit
        # Reverse mode argument); newer Enzyme releases require
        # autodiff(Reverse, ...) — confirm the pinned Enzyme version.
        Enzyme.autodiff(residual!,Duplicated(tmp1,tmp2),Duplicated(H,dR),Const(npow),Const(dx))
        @. R = R*(1.0-dmp/nx) + dt*dR
        @. Ψ += dt*R
        Ψ[1] = 0; Ψ[end] = 0    # pin the adjoint to zero at both boundaries
        if iter % ncheck == 0
            # convergence measured on interior cells only (boundaries are pinned)
            merr = maximum(abs.(dR[2:end-1]))
            isfinite(merr) || error("adjoint solve failed")
        end
        iter += 1
    end
    if iter == niter && merr >= ϵtol
        error("adjoint solve not converged")
    end
    @printf("  adjoint solve converged: #iter/nx = %.1f, err = %.1e\n",iter/nx,merr)
    return
end
# Accumulate the cost-function gradient ∂J/∂npow into Jn via the adjoint:
# Jn = -Ψᵀ ∂residual/∂npow. tmp1 holds the adjoint seed -Ψ, tmp2 receives
# the (discarded) primal residual output, and the shadow Jn accumulates the
# npow-gradient — hence the Jn .= 0 reset beforehand. Endpoint entries are
# copied from their neighbors because the residual loop never writes the
# boundary cells, so Enzyme produces no gradient there.
function cost_gradient!(Jn,problem::AdjointProblem)
    (;Ψ,tmp1,tmp2,H,npow,dx) = problem
    tmp1 .= .-Ψ; Jn .= 0.0
    # NOTE(review): old Enzyme.autodiff API (no explicit Reverse argument);
    # newer Enzyme requires autodiff(Reverse, ...) — confirm pinned version.
    Enzyme.autodiff(residual!,Duplicated(tmp2,tmp1),Const(H),Duplicated(npow,Jn),Const(dx))
    Jn[1] = Jn[2]; Jn[end] = Jn[end-1]
    return
end
# Inverse-problem driver: generate synthetic observations H_obs with the
# "true" exponent npows0, then recover the spatially variable exponent
# field `npow` by gradient descent, using the adjoint solve for the
# gradient and a backtracking line search for the step size.
@views function main()
    # physics
    lx     = 20.0   # domain length
    npows0 = 3.0    # exponent used to synthesize the observations
    npowi0 = 1.0    # initial guess for the exponent
    # numerics
    nx       = 128    # number of grid cells
    niter    = 100nx  # iteration cap for the pseudo-transient solves
    ncheck   = 5nx    # convergence-check interval
    ϵtol     = 1e-8   # pseudo-transient residual tolerance
    gd_ϵtol  = 1e-5   # misfit tolerance for gradient descent
    dmp      = 1/2    # forward-solve damping
    dmp_adj  = 3/2    # adjoint-solve damping
    gd_niter = 500    # max gradient-descent iterations
    bt_niter = 10     # max backtracking steps per line search
    γ0       = 1e2    # initial gradient-descent step size
    # preprocessing
    dx = lx/nx
    xc = LinRange(dx/2,lx-dx/2,nx)   # cell-center coordinates
    # init
    H     = collect(1.0 .- 0.5.*xc./lx)   # linear initial profile
    H_obs = copy(H)
    H_ini = copy(H)
    npow_synt = fill(npows0,nx)
    npow_init = fill(npowi0,nx)
    npow      = copy(npow_init)
    Jn = zeros(nx) # cost function gradient
    # NOTE: fwd_problem and adj_problem deliberately alias the same H and
    # npow arrays, so each forward solve updates the state the adjoint reads.
    synt_problem = ForwardProblem(H_obs, npow_synt,niter,ncheck,ϵtol,dx,dmp )
    fwd_problem  = ForwardProblem(H , npow ,niter,ncheck,ϵtol,dx,dmp )
    adj_problem  = AdjointProblem(H ,H_obs,npow ,niter,ncheck,ϵtol,dx,dmp_adj)
    # action
    println("  generating synthetic data...")
    solve!(synt_problem)
    println("  done.")
    solve!(fwd_problem)
    println("  gradient descent")
    γ = γ0
    J_old = cost(H,H_obs)
    J_evo = Float64[]; iter_evo = Int[]
    for gd_iter = 1:gd_niter
        npow_init .= npow   # snapshot to restore on rejected line-search steps
        # adjoint solve
        solve!(adj_problem)
        # compute cost function gradient
        cost_gradient!(Jn,adj_problem)
        # line search
        for bt_iter = 1:bt_niter
            @. npow -= γ*Jn
            fwd_problem.H .= H_ini   # restart the forward solve from the initial state
            solve!(fwd_problem)
            J_new = cost(H,H_obs)
            if J_new < J_old
                # accept the step and grow the step size
                γ *= 1.2
                J_old = J_new
                break
            else
                # reject: restore the exponent field and shrink the step size
                npow .= npow_init
                γ *= 0.5
            end
        end
        push!(iter_evo,gd_iter); push!(J_evo,J_old)
        if J_old < gd_ϵtol
            @printf("  gradient descent converged, misfit = %.1e\n", J_old)
            break
        else
            @printf("  #iter = %d, misfit = %.1e\n", gd_iter, J_old)
        end
        # visu (skipped on the converged iteration — `break` above exits first)
        p1 = plot(xc,[H,H_obs] ; title="H" , label=["H" "H_obs"])
        p2 = plot(iter_evo,J_evo ; title="misfit", label="", yaxis=:log10)
        p3 = plot(xc,[npow,npow_synt]; title="n" , label=["current" "synthetic"])
        display(plot(p1,p2,p3;layout=(1,3)))
    end
    return
end
main()