Skip to content

Commit

Permalink
Use _pullback inside rules instead of pullback
Browse files Browse the repository at this point in the history
This will give us more flexibility to implement internal changes such as
#603 without changing the user-facing API.
  • Loading branch information
ToucheSir committed Mar 5, 2023
1 parent 4e342fe commit 49a1184
Show file tree
Hide file tree
Showing 12 changed files with 62 additions and 55 deletions.
2 changes: 1 addition & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using LinearAlgebra, Statistics
using LinearAlgebra: copytri!, AbstractTriangular

import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
literal_getproperty, literal_getfield
literal_getproperty, literal_getfield, unthunk_tangent

using ChainRulesCore
using ChainRules: ChainRules, rrule, unthunk, canonicalize
Expand Down
44 changes: 19 additions & 25 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,10 @@ end
enumerate(xs), back
end

@adjoint Iterators.Filter(f, x) = pullback(filter, f, collect(x))
function _pullback(cx::AContext, ::Type{<:Iterators.Filter}, f, x)
res, back = _pullback(cx, filter, f, collect(x))
return res, back unthunk_tangent
end

_ndims(::Base.HasShape{d}) where {d} = d
_ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x)) : 1
Expand Down Expand Up @@ -321,18 +324,12 @@ end
end
end

@adjoint function sum(f, xs::AbstractArray{<:AbstractArray}; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs)
end

@adjoint function sum(xs::AbstractArray{Bool}; dims = :)
sum(xs, dims = dims), Δ -> (nothing,)
end

function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
y, back = pullback((f, xs) -> prod(f.(xs)), cx, f, xs)
y, ȳ -> (nothing, back(ȳ)...)
return _pullback(cx, (f, xs) -> prod(f.(xs)), f, xs)
end

@adjoint real(x::AbstractArray) = real(x), r̄ -> (real(r̄),)
Expand All @@ -357,8 +354,14 @@ function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
end
_kron(a::AbstractVector, b::AbstractVector) = vec(_kron(reshape(a, :, 1), reshape(b, :, 1)))

@adjoint kron(a::AbstractMatrix, b::AbstractMatrix) = pullback(_kron, a, b)
@adjoint kron(a::AbstractVector, b::AbstractVector) = pullback(_kron, a, b)
function _pullback(cx::AContext, ::typeof(kron), a::AbstractVector, b::AbstractVector)
res, back = _pullback(cx, _kron, a, b)
return res, back unthunk_tangent
end
function _pullback(cx::AContext, ::typeof(kron), a::AbstractMatrix, b::AbstractMatrix)
res, back = _pullback(cx, _kron, a, b)
return res, back unthunk_tangent
end

@adjoint logabsdet(xs::AbstractMatrix) = logabsdet(xs), Δ -> (Δ[1] * inv(xs)',)

Expand Down Expand Up @@ -432,15 +435,6 @@ end
@adjoint LinearAlgebra.UnitLowerTriangular(A) = UnitLowerTriangular(A), Δ->(UnitLowerTriangular(Δ)-I,)
@adjoint LinearAlgebra.UnitUpperTriangular(A) = UnitUpperTriangular(A), Δ->(UnitUpperTriangular(Δ)-I,)

# This is basically a hack while we don't have a working `ldiv!`.
@adjoint function \(A::Cholesky, B::AbstractVecOrMat)
Y, back = Zygote.pullback((U, B)->U \ (U' \ B), A.U, B)
return Y, function(Ȳ)
Ā_factors, B̄ = back(Ȳ)
return ((uplo=nothing, info=nothing, factors=Ā_factors), B̄)
end
end

function _symmetric_back(Δ, uplo)
L, U, D = LowerTriangular(Δ), UpperTriangular(Δ), Diagonal(Δ)
return uplo == 'U' ? U .+ transpose(L) - D : L .+ transpose(U) - D
Expand Down Expand Up @@ -572,14 +566,14 @@ _hermsympow(A::Hermitian, p::Integer) = A^p

@adjoint function _hermsympow(A::Hermitian, p::Integer)
if p < 0
B, back = Zygote.pullback(A->Base.power_by_squaring(inv(A), -p), A)
B, back = _pullback(__context__, A -> Base.power_by_squaring(inv(A), -p), A)
else
B, back = Zygote.pullback(A->Base.power_by_squaring(A, p), A)
B, back = _pullback(__context__, A -> Base.power_by_squaring(A, p), A)
end
Ω = Hermitian(_realifydiag!(B))
return Ω, function (Ω̄)
= _hermitian_back(Ω̄, 'U')
Ā = back(B̄)[1]
Ā = last(back(B̄))
return (Ā, nothing)
end
end
Expand Down Expand Up @@ -812,8 +806,8 @@ end
# =======================

@adjoint function broadcasted(op, r::AbstractFill{<:Real})
y, _back = Zygote.pullback(op, getindex_value(r))
back::AbstractFill) = (nothing, Fill(_back(getindex_value(Δ))[1], size(r)))
back::AbstractArray) = (nothing, getindex.(_back.(Δ), 1))
y, _back = _pullback(__context__, op, getindex_value(r))
back::AbstractFill) = (nothing, Fill(last(_back(getindex_value(Δ))), size(r)))
back::AbstractArray) = (nothing, last.(_back.(Δ)))
return Fill(y, size(r)), back
end
7 changes: 3 additions & 4 deletions src/lib/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,9 @@ end
# For merge between NamedTuple and Dict, we will just convert the Dict to a NamedTuple.
# and then call `pullback`, which should overall be pretty efficient code generated,
# and it avoids trying to AD the problematic generic `merge(::NamedTuple, ::iter)` method which uses `push!`.
if VERSION >= v"1.6"
@adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, NamedTuple(dict))
else
@adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, (;dict...))
function _pullback(cx::AContext, ::typeof(merge), a::NamedTuple, b::Dict{Symbol})
res, back = _pullback(cx, merge, a, NamedTuple(b))
return res, back unthunk_tangent
end

# Keyword arguments pretend to be a Dict, but are secretly wrapping a NamedTuple.
Expand Down
17 changes: 12 additions & 5 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,13 @@ using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize
# Utilities
# =========

# ChainRules already marks this non-differentiable,
# But inference can still give up because of the Zygote -> CR wrapper layer
@nograd Broadcast.combine_styles
# ChainRules already marks this non-differentiable,# But inference can still give up because of the Zygote -> CR wrapper layer.
# This has been desugared from the (deprecated) @nograd macro.
@inline function Zygote._pullback(::AContext, ::typeof(Broadcast.combine_styles), args...)
dargs = ntuple(_ -> nothing, length(args) + 1)
combine_styles_pullback(_) = dargs
return Broadcast.combine_styles(args...), combine_styles_pullback
end

accum_sum(xs; dims = :) = reduce(accum, xs, dims = dims)

Expand Down Expand Up @@ -358,9 +362,12 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve

# Make sure sum(f, ::CuArray) uses broadcase through forward-mode defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
@adjoint function sum(f, xs::AbstractGPUArray; kws...)
function _pullback(cx::AContext, ::Core.kwftype(typeof(sum)), kws, ::typeof(sum), f,
xs::AbstractGPUArray)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs)
res, back = _pullback(cx, (f, xs) -> sum(f.(xs); kws...), f, xs)
sum_gpuarray_pullback(Δ) = last(back(unthunk_tangent(Δ)))
return res, sum_gpuarray_pullback
end

@adjoint function Base.convert(::Type{T}, xs::Array) where {T<:AbstractGPUArray}
Expand Down
18 changes: 14 additions & 4 deletions src/lib/distances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,32 @@ end

_sqrt_if_positive(d, δ) = d > δ ? sqrt(d) : zero(d)

@adjoint function pairwise(dist::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
function _pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
kws::@NamedTuple{dims::Int}, ::typeof(pairwise), dist::Euclidean,
X::AbstractMatrix, Y::AbstractMatrix)
# Modify the forwards-pass slightly to ensure stability on the reverse.
dims = kws.dims
function _pairwise_euclidean(sqdist::SqEuclidean, X, Y)
D2 = pairwise(sqdist, X, Y; dims=dims)
δ = eps(eltype(D2))
return _sqrt_if_positive.(D2, δ)
end
return pullback(_pairwise_euclidean, SqEuclidean(dist.thresh), X, Y)
res, back = _pullback(cx, _pairwise_euclidean, SqEuclidean(dist.thresh), X, Y)
pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(unthunk_tangent(Δ))...)
return res, pairwise_Euclidean_pullback
end

@adjoint function pairwise(dist::Euclidean, X::AbstractMatrix; dims=2)
function _pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
kws::@NamedTuple{dims::Int}, ::typeof(pairwise), dist::Euclidean,
X::AbstractMatrix)
# Modify the forwards-pass slightly to ensure stability on the reverse.
dims = kws.dims
function _pairwise_euclidean(sqdist::SqEuclidean, X)
D2 = pairwise(sqdist, X; dims=dims)
δ = eps(eltype(D2))
return _sqrt_if_positive.(D2, δ)
end
return pullback(_pairwise_euclidean, SqEuclidean(dist.thresh), X)
res, back = _pullback(cx, _pairwise_euclidean, SqEuclidean(dist.thresh), X)
pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(unthunk_tangent(Δ))...)
return res, pairwise_Euclidean_pullback
end
20 changes: 12 additions & 8 deletions src/lib/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,32 +137,36 @@ end

# Use this to allow second derivatives -- this is forward-over-forward,
# see https://github.com/FluxML/Zygote.jl/issues/769 for a forward-over-reverse proposal
@adjoint function ForwardDiff.gradient(f, x)
function _pullback(cx::AContext, ::typeof(ForwardDiff.gradient), f, x)
F = typeof(f)
Base.issingletontype(F) || @warn """`ForwardDiff.gradient(f, x)` within Zygote cannot track gradients with respect to `f`,
and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
typeof(f) = $F""" maxlog=1 _id=hash(F)
pullback(forwarddiff, x -> ForwardDiff.gradient(f, x), x)
res, back = _pullback(cx, forwarddiff, x -> ForwardDiff.gradient(f, x), x)
return res, back unthunk_tangent
end

@adjoint function ForwardDiff.jacobian(f::F, x) where F
function _pullback(cx::AContext, ::typeof(ForwardDiff.jacobian), f::F, x) where F
Base.issingletontype(F) || @warn """`ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
typeof(f) = $F""" maxlog=1 _id=hash(F)
pullback(forwarddiff, x -> ForwardDiff.jacobian(f, x), x)
res, back = _pullback(cx, forwarddiff, x -> ForwardDiff.jacobian(f, x), x)
return res, back unthunk_tangent
end

@adjoint function ForwardDiff.derivative(f::F, x) where F
function _pullback(cx::AContext, ::typeof(ForwardDiff.derivative), f::F, x) where F
Base.issingletontype(F) || @warn """`ForwardDiff.derivative(f, x)` within Zygote cannot track gradients with respect to `f`,
and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
typeof(f) = $F""" maxlog=1 _id=hash(F)
pullback(forwarddiff, x -> ForwardDiff.derivative(f, x), x)
res, back = _pullback(cx, forwarddiff, x -> ForwardDiff.derivative(f, x), x)
return res, back unthunk_tangent
end

@adjoint function ForwardDiff.hessian(f::F, x) where F
function _pullback(cx::AContext, ::typeof(ForwardDiff.hessian), f::F, x) where F
Base.issingletontype(F) || @warn """`ForwardDiff.hessian(f, x)` within Zygote cannot track gradients with respect to `f`,
and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
typeof(f) = $F""" maxlog=1 _id=hash(F)
pullback(forwarddiff, x -> ForwardDiff.hessian(f, x), x)
res, back = _pullback(cx, forwarddiff, x -> ForwardDiff.hessian(f, x), x)
return res, back unthunk_tangent
end

1 change: 0 additions & 1 deletion test/cuda.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using CUDA
using Zygote: Grads
using LinearAlgebra
using Random: randn!
import FiniteDifferences
CUDA.allowscalar(false)
Expand Down
2 changes: 0 additions & 2 deletions test/forward/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ end == 0
@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1+0im)[1]
@test real(D(x -> abs(x+2im), 1)) == gradient(x -> abs(x+2im), 1)[1] # ProjectTo means gradient here is real

using LinearAlgebra

@test D(3) do x
A = zeros(5, 5)
B = zeros(5, 5)
Expand Down
1 change: 0 additions & 1 deletion test/lib/array.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using ChainRulesTestUtils
using LinearAlgebra
using Zygote: ZygoteRuleConfig, _pullback

# issue 897
Expand Down
2 changes: 0 additions & 2 deletions test/lib/base.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using LinearAlgebra;

@testset "base.jl" begin
@testset "Dict getindex with implicit params" begin
d = Dict{String, Vector{Float64}}("key"=>ones(4))
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Zygote, Test
using Zygote, Test, LinearAlgebra
using Zygote: gradient, ZygoteRuleConfig
using CUDA
using CUDA: has_cuda
Expand Down
1 change: 0 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using LinearAlgebra
using ForwardDiff
using Zygote: hessian_dual, hessian_reverse

Expand Down

0 comments on commit 49a1184

Please sign in to comment.