diff --git a/ext/QrochetChainRulesCoreExt.jl b/ext/QrochetChainRulesCoreExt.jl index 12d8f92..3db9cd9 100644 --- a/ext/QrochetChainRulesCoreExt.jl +++ b/ext/QrochetChainRulesCoreExt.jl @@ -24,4 +24,51 @@ Quantum_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent()) Quantum_pullback(ȳ::AbstractThunk) = Quantum_pullback(unthunk(ȳ)) ChainRulesCore.rrule(::Type{Quantum}, x::TensorNetwork, sites) = Quantum(x, sites), Quantum_pullback +Base.zero(x::Dict{Site,Symbol}) = x + +ChainRulesCore.ProjectTo(x::T) where {T<:Ansatz} = ProjectTo{T}(; super = ProjectTo(Quantum(x))) +(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Ansatz} = T(projector.super(Δ.super), Δ.boundary) + +# NOTE edge case: `Product` has no `boundary`. should it? +(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Product} = T(projector.super(Δ.super)) + +ChainRulesCore.frule((_, ẋ), ::Type{T}, x::Quantum) where {T<:Ansatz} = T(x), Tangent{T}(; super = ẋ) + +Ansatz_pullback(ȳ) = (NoTangent(), ȳ.super) +Ansatz_pullback(ȳ::AbstractThunk) = Ansatz_pullback(unthunk(ȳ)) +function ChainRulesCore.rrule(::Type{T}, x::Quantum) where {T<:Ansatz} + y = T(x) + y, Ansatz_pullback +end + +function ChainRulesCore.frule((_, ẋ, _), ::Type{T}, x::Quantum, boundary) where {T<:Ansatz} + T(x, boundary), Tangent{T}(; super = ẋ, boundary = NoTangent()) +end + +Ansatz_boundary_pullback(ȳ) = (NoTangent(), ȳ.super, NoTangent()) +Ansatz_boundary_pullback(ȳ::AbstractThunk) = Ansatz_boundary_pullback(unthunk(ȳ)) +function ChainRulesCore.rrule(::Type{T}, x::Quantum, boundary) where {T<:Ansatz} + T(x, boundary), Ansatz_boundary_pullback +end + +# Ansatz_from_arrays_pullback(ȳ) = (NoTangent(), NoTangent(), NoTangent(), parent.(tensors(ȳ.super.tn))) +# Ansatz_from_arrays_pullback(ȳ::AbstractThunk) = Ansatz_from_arrays_pullback(unthunk(ȳ)) +# function ChainRulesCore.rrule( +# ::Type{T}, +# socket::Qrochet.Socket, +# boundary::Qrochet.Boundary, +# arrays; +# kwargs..., +# ) where {T<:Ansatz} +# y = T(socket, boundary, arrays; kwargs...) +# y, Ansatz_from_arrays_pullback +# end + +copy_pullback(ȳ) = (NoTangent(), ȳ) +copy_pullback(ȳ::AbstractThunk) = unthunk(ȳ) +function ChainRulesCore.rrule(::typeof(copy), x::Quantum) + y = copy(x) + y, copy_pullback +end + end diff --git a/test/integration/ChainRulesCore_test.jl b/test/integration/ChainRulesCore_test.jl index 619a189..01ab80f 100644 --- a/test/integration/ChainRulesCore_test.jl +++ b/test/integration/ChainRulesCore_test.jl @@ -4,7 +4,35 @@ using ChainRulesTestUtils @testset "Quantum" begin - test_frule(Quantum, TensorNetwork([Tensor(fill(1.0, 2), [:i])]), Dict{Site,Symbol}(site"1" => :i)) - test_rrule(Quantum, TensorNetwork([Tensor(fill(1.0, 2), [:i])]), Dict{Site,Symbol}(site"1" => :i)) + test_frule(Quantum, TensorNetwork([Tensor(ones(2), [:i])]), Dict{Site,Symbol}(site"1" => :i)) + test_rrule(Quantum, TensorNetwork([Tensor(ones(2), [:i])]), Dict{Site,Symbol}(site"1" => :i)) + end + + @testset "Ansatz" begin + @testset "Product" begin + tn = TensorNetwork([Tensor(ones(2), [:i]), Tensor(ones(2), [:j]), Tensor(ones(2), [:k])]) + qtn = Quantum(tn, Dict([site"1" => :i, site"2" => :j, site"3" => :k])) + + test_frule(Product, qtn) + test_rrule(Product, qtn) + end + + @testset "Chain" begin + tn = Chain(State(), Open(), [ones(2, 2), ones(2, 2, 2), ones(2, 2)]) + # test_frule(Chain, Quantum(tn), Open()) + test_rrule(Chain, Quantum(tn), Open()) + + tn = Chain(State(), Periodic(), [ones(2, 2, 2), ones(2, 2, 2), ones(2, 2, 2)]) + # test_frule(Chain, Quantum(tn), Periodic()) + test_rrule(Chain, Quantum(tn), Periodic()) + + tn = Chain(Operator(), Open(), [ones(2, 2, 2), ones(2, 2, 2, 2), ones(2, 2, 2)]) + # test_frule(Chain, Quantum(tn), Open()) + test_rrule(Chain, Quantum(tn), Open()) + + tn = Chain(Operator(), Periodic(), [ones(2, 2, 2, 2), ones(2, 2, 2, 2), ones(2, 2, 2, 2)]) + # test_frule(Chain, Quantum(tn), Periodic()) + test_rrule(Chain, Quantum(tn), Periodic()) + end end end