diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index bf5517a..d567b09 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -559,3 +559,19 @@ function expect(ψ::Chain, observables) return contract(tn) end + +overlap(a::Chain, b::Chain) = overlap(socket(a), a, socket(b), b) + +# TODO fix optimal path +function overlap(::State, a::Chain, ::State, b::Chain) + @assert issetequal(sites(a), sites(b)) "Ansatzes must have the same sites" + + b = copy(b) + b = @reindex! outputs(a) => outputs(b) + + tn = merge(TensorNetwork(a), TensorNetwork(b')) + return contract(tn) +end + +overlap(a::Product, b::Chain) = overlap(convert(Chain, a), b) +overlap(a::Chain, b::Product) = overlap(a, convert(Chain, b)) diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl index c5f613f..21b89b1 100644 --- a/src/Ansatz/Product.jl +++ b/src/Ansatz/Product.jl @@ -54,3 +54,13 @@ function LinearAlgebra.normalize!(::Union{State,Operator}, tn::Product, p::Real) end tn end + +overlap(a::Product, b::Product) = overlap(socket(a), a, socket(b), b) + +function overlap(::State, a::Product, ::State, b::Product) + @assert issetequal(sites(a), sites(b)) "Ansatzes must have the same sites" + + mapreduce(*, zip(tensors(a), tensors(b))) do (ta, tb) + dot(parent(ta), conj(parent(tb))) + end +end diff --git a/src/Qrochet.jl b/src/Qrochet.jl index 1c88052..2f5f365 100644 --- a/src/Qrochet.jl +++ b/src/Qrochet.jl @@ -24,9 +24,8 @@ export MPS, pMPS, MPO, pMPO export leftindex, rightindex, isleftcanonical, isrightcanonical export canonize_site, canonize_site!, truncate! export canonize, canonize!, mixed_canonize, mixed_canonize! -export expect -export evolve! +export evolve!, expect, overlap # reexports from Tenet using Tenet