Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Commit

Permalink
Fix ChainRules rrules tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Apr 17, 2024
1 parent 1952c9c commit 544877b
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 2 deletions.
47 changes: 47 additions & 0 deletions ext/QrochetChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,51 @@ Quantum_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
Quantum_pullback(ȳ::AbstractThunk) = Quantum_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Quantum}, x::TensorNetwork, sites) = Quantum(x, sites), Quantum_pullback

Base.zero(x::Dict{Site,Symbol}) = x

ChainRulesCore.ProjectTo(x::T) where {T<:Ansatz} = ProjectTo{T}(; super = ProjectTo(Quantum(x)))
(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Ansatz} = T(projector.super.super), Δ.boundary)

# NOTE edge case: `Product` has no `boundary`. should it?
(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Product} = T(projector.super.super))

ChainRulesCore.frule((_, ẋ), ::Type{T}, x::Quantum) where {T<:Ansatz} = T(x), Tangent{T}(; super = ẋ)

Ansatz_pullback(ȳ) = (NoTangent(), ȳ.super)
Ansatz_pullback(ȳ::AbstractThunk) = Ansatz_pullback(unthunk(ȳ))
function ChainRulesCore.rrule(::Type{T}, x::Quantum) where {T<:Ansatz}
y = T(x)
y, Ansatz_pullback
end

function ChainRulesCore.frule((_, ẋ, _), ::Type{T}, x::Quantum, boundary) where {T<:Ansatz}
T(x, boundary), Tangent{T}(; super = ẋ, boundary = NoTangent())
end

Ansatz_boundary_pullback(ȳ) = (NoTangent(), ȳ.super, NoTangent())
Ansatz_boundary_pullback(ȳ::AbstractThunk) = Ansatz_boundary_pullback(unthunk(ȳ))
function ChainRulesCore.rrule(::Type{T}, x::Quantum, boundary) where {T<:Ansatz}
T(x, boundary), Ansatz_boundary_pullback
end

# Ansatz_from_arrays_pullback(ȳ) = (NoTangent(), NoTangent(), NoTangent(), parent.(tensors(ȳ.super.tn)))
# Ansatz_from_arrays_pullback(ȳ::AbstractThunk) = Ansatz_from_arrays_pullback(unthunk(ȳ))
# function ChainRulesCore.rrule(
# ::Type{T},
# socket::Qrochet.Socket,
# boundary::Qrochet.Boundary,
# arrays;
# kwargs...,
# ) where {T<:Ansatz}
# y = T(socket, boundary, arrays; kwargs...)
# y, Ansatz_from_arrays_pullback
# end

copy_pullback(ȳ) = (NoTangent(), ȳ)
copy_pullback(ȳ::AbstractThunk) = unthunk(ȳ)
function ChainRulesCore.rrule(::typeof(copy), x::Quantum)
y = copy(x)
y, copy_pullback
end

end
32 changes: 30 additions & 2 deletions test/integration/ChainRulesCore_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,35 @@
using ChainRulesTestUtils

@testset "Quantum" begin
test_frule(Quantum, TensorNetwork([Tensor(fill(1.0, 2), [:i])]), Dict{Site,Symbol}(site"1" => :i))
test_rrule(Quantum, TensorNetwork([Tensor(fill(1.0, 2), [:i])]), Dict{Site,Symbol}(site"1" => :i))
test_frule(Quantum, TensorNetwork([Tensor(ones(2), [:i])]), Dict{Site,Symbol}(site"1" => :i))
test_rrule(Quantum, TensorNetwork([Tensor(ones(2), [:i])]), Dict{Site,Symbol}(site"1" => :i))
end

@testset "Ansatz" begin
@testset "Product" begin
tn = TensorNetwork([Tensor(ones(2), [:i]), Tensor(ones(2), [:j]), Tensor(ones(2), [:k])])
qtn = Quantum(tn, Dict([site"1" => :i, site"2" => :j, site"3" => :k]))

test_frule(Product, qtn)
test_rrule(Product, qtn)
end

@testset "Chain" begin
tn = Chain(State(), Open(), [ones(2, 2), ones(2, 2, 2), ones(2, 2)])
# test_frule(Chain, Quantum(tn), Open())
test_rrule(Chain, Quantum(tn), Open())

tn = Chain(State(), Periodic(), [ones(2, 2, 2), ones(2, 2, 2), ones(2, 2, 2)])
# test_frule(Chain, Quantum(tn), Periodic())
test_rrule(Chain, Quantum(tn), Periodic())

tn = Chain(Operator(), Open(), [ones(2, 2, 2), ones(2, 2, 2, 2), ones(2, 2, 2)])
# test_frule(Chain, Quantum(tn), Open())
test_rrule(Chain, Quantum(tn), Open())

tn = Chain(Operator(), Periodic(), [ones(2, 2, 2, 2), ones(2, 2, 2, 2), ones(2, 2, 2, 2)])
# test_frule(Chain, Quantum(tn), Periodic())
test_rrule(Chain, Quantum(tn), Periodic())
end
end
end

0 comments on commit 544877b

Please sign in to comment.