From 58cfe835703f9720b0a3f84d0be94d6b2d5659b7 Mon Sep 17 00:00:00 2001 From: Tim Pokart Date: Wed, 13 Nov 2024 16:58:36 +0100 Subject: [PATCH 1/2] Made sample type agnostic and CUDA compatible --- src/mps.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/mps.jl b/src/mps.jl index 567319d..32a16f2 100644 --- a/src/mps.jl +++ b/src/mps.jl @@ -655,6 +655,8 @@ function sample(rng::AbstractRNG, m::MPS) error("sample: MPS is not normalized, norm=$(norm(m[1]))") end + ElT = promote_itensor_eltype(m) + result = zeros(Int, N) A = m[1] @@ -664,16 +666,16 @@ function sample(rng::AbstractRNG, m::MPS) # Compute the probability of each state # one-by-one and stop when the random # number r is below the total prob so far - pdisc = 0.0 + pdisc = zero(real(ElT)) r = rand(rng) # Will need n,An, and pn below n = 1 An = ITensor() - pn = 0.0 + pn = zero(real(ElT)) while n <= d projn = ITensor(s) - projn[s => n] = 1.0 - An = A * dag(projn) + projn[s => n] = one(ElT) + An = A * dag(adapt(datatype(A), projn)) pn = real(scalar(dag(An) * An)) pdisc += pn (r < pdisc) && break @@ -682,7 +684,7 @@ function sample(rng::AbstractRNG, m::MPS) result[j] = n if j < N A = m[j + 1] * An - A *= (1.0 / sqrt(pn)) + A *= (one(ElT) / sqrt(pn)) end end return result From f56dd19c93ad56784f71e293c7ee500f856c4b98 Mon Sep 17 00:00:00 2001 From: Tim Pokart Date: Fri, 15 Nov 2024 14:08:42 +0100 Subject: [PATCH 2/2] Replaced occurances of `promote_itensor_eltype` by `scalartype` --- src/mps.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mps.jl b/src/mps.jl index 32a16f2..de71318 100644 --- a/src/mps.jl +++ b/src/mps.jl @@ -655,7 +655,7 @@ function sample(rng::AbstractRNG, m::MPS) error("sample: MPS is not normalized, norm=$(norm(m[1]))") end - ElT = promote_itensor_eltype(m) + ElT = scalartype(m) result = zeros(Int, N) A = m[1] @@ -751,7 +751,7 @@ function correlation_matrix( end_site = last(sites) N = length(psi) - ElT = promote_itensor_eltype(psi) + ElT = scalartype(psi) s = siteinds(psi) Op1 = _Op1 #make copies into which we can insert "F" string operators, and then restore. @@ -985,7 +985,7 @@ updens, dndens = expect(psi, "Nup", "Ndn") # pass more than one operator function expect(psi::MPS, ops; sites=1:length(psi), site_range=nothing) psi = copy(psi) N = length(psi) - ElT = promote_itensor_eltype(psi) + ElT = scalartype(psi) s = siteinds(psi) if !isnothing(site_range)