Skip to content

Commit

Permalink
Merge pull request #28 from quantum-exeter/spin_dyn_speedup
Browse files Browse the repository at this point in the history
Spin dynamics speedup
  • Loading branch information
cerisola authored Sep 16, 2024
2 parents 64aaa45 + d4459a2 commit fa109c3
Showing 1 changed file with 86 additions and 40 deletions.
126 changes: 86 additions & 40 deletions src/Dynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,49 +50,17 @@ function diffeqsolver(
if length(Jlist) != length(matrix) || length(Jlist) != length(bcoupling)
throw(DimensionMismatch("The dimension of Jlist, bcoupling, and matrix must match."))
end

M = length(Jlist)
u0 = [s0; zeros(6*M)]
invsqrtS0 = 1/sqrt(S0)
Cω2 = []
b = []
for i in 1:M
push!(Cω2, matrix[i].C*transpose(matrix[i].C))
push!(b, t -> invsqrtS0*matrix[i].C*[bfield[i][1](t), bfield[i][2](t), bfield[i][3](t)]);
end
function f(du, u, (Cω2v, Beff), t)
Cω2v = get_tmp(Cω2v, u)
Beff = get_tmp(Beff, u)
s = @view u[1:3*N]
v = @view u[1+3*N:3*N+3*M]
w = @view u[1+3*N+3*M:3*N+6*M]
ds = @view du[1:3*N]
dv = @view du[1+3*N:3*N+3*M]
dw = @view du[1+3*N+3*M:3*N+6*M]
for i in 1:N
Beff[i, :] .= Bext
for j in 1:M
Beff[i, :] .+= bcoupling[j][i]*(b[j](t) + mul!(Cω2v, Cω2[j], v[1+(j-1)*3:3+(j-1)*3]))
end
end
for i in 1:N
ds[1+(i-1)*3] = -(s[2+(i-1)*3]*Beff[i,3]-s[3+(i-1)*3]*Beff[i,2])
ds[2+(i-1)*3] = -(s[3+(i-1)*3]*Beff[i,1]-s[1+(i-1)*3]*Beff[i,3])
ds[3+(i-1)*3] = -(s[1+(i-1)*3]*Beff[i,2]-s[2+(i-1)*3]*Beff[i,1])
for j in 1:N
ds[1+(i-1)*3] += -(s[2+(i-1)*3]*JH[i,j]*s[3+(j-1)*3]-s[3+(i-1)*3]*JH[i,j]*s[2+(j-1)*3])
ds[2+(i-1)*3] += -(s[3+(i-1)*3]*JH[i,j]*s[1+(j-1)*3]-s[1+(i-1)*3]*JH[i,j]*s[3+(j-1)*3])
ds[3+(i-1)*3] += -(s[1+(i-1)*3]*JH[i,j]*s[2+(j-1)*3]-s[2+(i-1)*3]*JH[i,j]*s[1+(j-1)*3])
end
end
dv .= w
for i in 1:M
dw[1+3*(i-1):3+3*(i-1)] = -(Jlist[i].ω0^2)*v[1+3*(i-1):3+3*(i-1)] -Jlist[i].Γ*w[1+3*(i-1):3+3*(i-1)]
for j in 1:N
dw[1+3*(i-1):3+3*(i-1)] += -Jlist[i].α*bcoupling[i][j]*s[1+3*(j-1):3+3*(j-1)]
end
end
end
prob = ODEProblem(f, u0, tspan, (dualcache(zeros(3)), dualcache(zeros(N,3))))
= [matrix[i].C for i in 1:M]
Cω2 = [matrix[i].C*transpose(matrix[i].C) for i in 1:M]

b, Cb, Cω2v, Beff = dualcache(zeros(3)), dualcache(zeros(3)), dualcache(zeros(3)), dualcache(zeros(N, 3))
params = (N, M, invsqrtS0, Bext, JH, Jlist, Cω, Cω2, bfield, bcoupling, b, Cb, Cω2v, Beff)
prob = ODEProblem(_spin_time_step!, u0, tspan, params)

condition(u, t, integrator) = true
function affect!(integrator) # projection
for n in 1:N
Expand All @@ -102,15 +70,93 @@ function diffeqsolver(
end
cb = DiscreteCallback(condition, affect!, save_positions=(false,false))
skwargs = projection ? (callback=cb,) : NamedTuple()

if save_fields
save_idxs = 1:(3*N+6*M)
else
save_idxs = 1:3*N
end

sol = solve(prob, alg; abstol=atol, reltol=rtol, maxiters=Int(1e9), save_idxs=save_idxs, saveat=saveat, kwargs..., skwargs...)

return sol
end

function _spin_time_step!(
du,
u,
(N, M, invsqrtS0, Bext, JH, Jlist, Cω, Cω2, bfields, bcoupling, b, Cb, Cω2v, Beff),
t
)
b = get_tmp(b, u)
Cb = get_tmp(Cb, u)
Cω2v = get_tmp(Cω2v, u)
Beff = get_tmp(Beff, u)

s = @view u[1:3*N]
v = @view u[1+3*N:3*N+3*M]
w = @view u[1+3*N+3*M:3*N+6*M]
ds = @view du[1:3*N]
dv = @view du[1+3*N:3*N+3*M]
dw = @view du[1+3*N+3*M:3*N+6*M]

for i in 1:N
Beff[i, :] .= Bext
end

for j in 1:M
vj = @view v[1+(j-1)*3:3+(j-1)*3]

for k in 1:3
b[k] = bfields[j][k](t)
end

mul!(Cb, Cω[j], b)
lmul!(invsqrtS0, Cb)
mul!(Cω2v, Cω2[j], vj)
Cb .+= Cω2v

for i in 1:N
Beffi = @view Beff[i, :]

@. Beffi += bcoupling[j][i]*Cb
end
end

for i in 1:N
si = @view s[1+(i-1)*3:3+(i-1)*3]
dsi = @view ds[1+(i-1)*3:3+(i-1)*3]

dsi[1] = -(si[2]*Beff[i,3] - si[3]*Beff[i,2])
dsi[2] = -(si[3]*Beff[i,1] - si[1]*Beff[i,3])
dsi[3] = -(si[1]*Beff[i,2] - si[2]*Beff[i,1])

for j in 1:N
sj = @view s[1+(j-1)*3:3+(j-1)*3]

dsi[1] += -(si[2]*JH[i,j]*sj[3] - si[3]*JH[i,j]*sj[2])
dsi[2] += -(si[3]*JH[i,j]*sj[1] - si[1]*JH[i,j]*sj[3])
dsi[3] += -(si[1]*JH[i,j]*sj[2] - si[2]*JH[i,j]*sj[1])
end
end

dv .= w

for i in 1:M
vi = @view v[1+(i-1)*3:3+(i-1)*3]
wi = @view w[1+(i-1)*3:3+(i-1)*3]
dwi = @view dw[1+(i-1)*3:3+(i-1)*3]

@. dwi = -(Jlist[i].ω0^2)*vi - Jlist[i].Γ*wi

for j in 1:N
sj = @view s[1+3*(j-1):3+3*(j-1)]

@. dwi += -Jlist[i].α*bcoupling[i][j]*sj
end
end
end

"""
function diffeqsolver(s0, tspan, J::LorentzianSD, Jshared::LorentzianSD, bfields, bfieldshared, matrix::Coupling; JH=zero(I), S0=1/2, Bext=[0, 0, 1], saveat=[], save_fields=false, projection=true, alg=Tsit5(), atol=1e-3, rtol=1e-3)
Expand Down

0 comments on commit fa109c3

Please sign in to comment.