Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spin dynamics speedup #28

Merged
merged 8 commits into from
Sep 16, 2024
126 changes: 86 additions & 40 deletions src/Dynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,49 +50,17 @@ function diffeqsolver(
if length(Jlist) != length(matrix) || length(Jlist) != length(bcoupling)
throw(DimensionMismatch("The dimension of Jlist, bcoupling, and matrix must match."))
end

M = length(Jlist)
u0 = [s0; zeros(6*M)]
invsqrtS0 = 1/sqrt(S0)
Cω2 = []
b = []
for i in 1:M
push!(Cω2, matrix[i].C*transpose(matrix[i].C))
push!(b, t -> invsqrtS0*matrix[i].C*[bfield[i][1](t), bfield[i][2](t), bfield[i][3](t)]);
end
function f(du, u, (Cω2v, Beff), t)
Cω2v = get_tmp(Cω2v, u)
Beff = get_tmp(Beff, u)
s = @view u[1:3*N]
v = @view u[1+3*N:3*N+3*M]
w = @view u[1+3*N+3*M:3*N+6*M]
ds = @view du[1:3*N]
dv = @view du[1+3*N:3*N+3*M]
dw = @view du[1+3*N+3*M:3*N+6*M]
for i in 1:N
Beff[i, :] .= Bext
for j in 1:M
Beff[i, :] .+= bcoupling[j][i]*(b[j](t) + mul!(Cω2v, Cω2[j], v[1+(j-1)*3:3+(j-1)*3]))
end
end
for i in 1:N
ds[1+(i-1)*3] = -(s[2+(i-1)*3]*Beff[i,3]-s[3+(i-1)*3]*Beff[i,2])
ds[2+(i-1)*3] = -(s[3+(i-1)*3]*Beff[i,1]-s[1+(i-1)*3]*Beff[i,3])
ds[3+(i-1)*3] = -(s[1+(i-1)*3]*Beff[i,2]-s[2+(i-1)*3]*Beff[i,1])
for j in 1:N
ds[1+(i-1)*3] += -(s[2+(i-1)*3]*JH[i,j]*s[3+(j-1)*3]-s[3+(i-1)*3]*JH[i,j]*s[2+(j-1)*3])
ds[2+(i-1)*3] += -(s[3+(i-1)*3]*JH[i,j]*s[1+(j-1)*3]-s[1+(i-1)*3]*JH[i,j]*s[3+(j-1)*3])
ds[3+(i-1)*3] += -(s[1+(i-1)*3]*JH[i,j]*s[2+(j-1)*3]-s[2+(i-1)*3]*JH[i,j]*s[1+(j-1)*3])
end
end
dv .= w
for i in 1:M
dw[1+3*(i-1):3+3*(i-1)] = -(Jlist[i].ω0^2)*v[1+3*(i-1):3+3*(i-1)] -Jlist[i].Γ*w[1+3*(i-1):3+3*(i-1)]
for j in 1:N
dw[1+3*(i-1):3+3*(i-1)] += -Jlist[i].α*bcoupling[i][j]*s[1+3*(j-1):3+3*(j-1)]
end
end
end
prob = ODEProblem(f, u0, tspan, (dualcache(zeros(3)), dualcache(zeros(N,3))))
Cω = [matrix[i].C for i in 1:M]
Cω2 = [matrix[i].C*transpose(matrix[i].C) for i in 1:M]

b, Cb, Cω2v, Beff = dualcache(zeros(3)), dualcache(zeros(3)), dualcache(zeros(3)), dualcache(zeros(N, 3))
params = (N, M, invsqrtS0, Bext, JH, Jlist, Cω, Cω2, bfield, bcoupling, b, Cb, Cω2v, Beff)
prob = ODEProblem(_spin_time_step!, u0, tspan, params)

condition(u, t, integrator) = true
function affect!(integrator) # projection
for n in 1:N
Expand All @@ -102,15 +70,93 @@ function diffeqsolver(
end
cb = DiscreteCallback(condition, affect!, save_positions=(false,false))
skwargs = projection ? (callback=cb,) : NamedTuple()

if save_fields
save_idxs = 1:(3*N+6*M)
else
save_idxs = 1:3*N
end

sol = solve(prob, alg; abstol=atol, reltol=rtol, maxiters=Int(1e9), save_idxs=save_idxs, saveat=saveat, kwargs..., skwargs...)

return sol
end

function _spin_time_step!(
du,
u,
(N, M, invsqrtS0, Bext, JH, Jlist, Cω, Cω2, bfields, bcoupling, b, Cb, Cω2v, Beff),
t
)
b = get_tmp(b, u)
Cb = get_tmp(Cb, u)
Cω2v = get_tmp(Cω2v, u)
Beff = get_tmp(Beff, u)

s = @view u[1:3*N]
v = @view u[1+3*N:3*N+3*M]
w = @view u[1+3*N+3*M:3*N+6*M]
ds = @view du[1:3*N]
dv = @view du[1+3*N:3*N+3*M]
dw = @view du[1+3*N+3*M:3*N+6*M]

for i in 1:N
Beff[i, :] .= Bext
end

for j in 1:M
vj = @view v[1+(j-1)*3:3+(j-1)*3]

for k in 1:3
b[k] = bfields[j][k](t)
end

mul!(Cb, Cω[j], b)
lmul!(invsqrtS0, Cb)
mul!(Cω2v, Cω2[j], vj)
Cb .+= Cω2v

for i in 1:N
Beffi = @view Beff[i, :]

@. Beffi += bcoupling[j][i]*Cb
end
end

for i in 1:N
si = @view s[1+(i-1)*3:3+(i-1)*3]
dsi = @view ds[1+(i-1)*3:3+(i-1)*3]

dsi[1] = -(si[2]*Beff[i,3] - si[3]*Beff[i,2])
dsi[2] = -(si[3]*Beff[i,1] - si[1]*Beff[i,3])
dsi[3] = -(si[1]*Beff[i,2] - si[2]*Beff[i,1])

for j in 1:N
sj = @view s[1+(j-1)*3:3+(j-1)*3]

dsi[1] += -(si[2]*JH[i,j]*sj[3] - si[3]*JH[i,j]*sj[2])
dsi[2] += -(si[3]*JH[i,j]*sj[1] - si[1]*JH[i,j]*sj[3])
dsi[3] += -(si[1]*JH[i,j]*sj[2] - si[2]*JH[i,j]*sj[1])
end
end

dv .= w

for i in 1:M
vi = @view v[1+(i-1)*3:3+(i-1)*3]
wi = @view w[1+(i-1)*3:3+(i-1)*3]
dwi = @view dw[1+(i-1)*3:3+(i-1)*3]

@. dwi = -(Jlist[i].ω0^2)*vi - Jlist[i].Γ*wi

for j in 1:N
sj = @view s[1+3*(j-1):3+3*(j-1)]

@. dwi += -Jlist[i].α*bcoupling[i][j]*sj
end
end
end

"""
function diffeqsolver(s0, tspan, J::LorentzianSD, Jshared::LorentzianSD, bfields, bfieldshared, matrix::Coupling; JH=zero(I), S0=1/2, Bext=[0, 0, 1], saveat=[], save_fields=false, projection=true, alg=Tsit5(), atol=1e-3, rtol=1e-3)

Expand Down
Loading