Skip to content

Commit

Permalink
Merge pull request #90 from JuliaGPU/jps/forwarddiff
Browse files Browse the repository at this point in the history
Add ForwardDiff integrations
  • Loading branch information
jpsamaroo authored Jan 28, 2021
2 parents e9bc3d8 + e464866 commit a83bca6
Show file tree
Hide file tree
Showing 9 changed files with 307 additions and 54 deletions.
28 changes: 5 additions & 23 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ version = "0.5.0"

[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "87491f7d03ae1b423a353aff99cf61a45e3c993a"
git-tree-sha1 = "4146c39f29be88c3f0cef732f86e5ab640d2e22d"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.1.0"
version = "3.1.1"

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
Expand Down Expand Up @@ -77,9 +77,7 @@ version = "6.2.0"

[[GPUCompiler]]
deps = ["DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "52477f9ef0a4b7da2fcb8671a500be054db29127"
repo-rev = "master"
repo-url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"
git-tree-sha1 = "6ab1bc883bc13919c25acc0fe0dea707f61ae39c"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.9.1"

Expand All @@ -98,17 +96,13 @@ deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"

[[LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
deps = ["Libdl"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"

[[LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
deps = ["NetworkOptions", "Printf"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

Expand All @@ -129,10 +123,6 @@ version = "0.5.6"
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

Expand Down Expand Up @@ -226,11 +216,3 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"

[[nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand All @@ -33,10 +34,11 @@ julia = "1.6"
[extras]
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["FFTW", "FillArrays", "InteractiveUtils", "Pkg", "SpecialFunctions", "Test"]
test = ["FFTW", "FillArrays", "ForwardDiff", "InteractiveUtils", "Pkg", "SpecialFunctions", "Test"]
5 changes: 4 additions & 1 deletion src/AMDGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ include("array.jl")
#include("subarray.jl")
#include("utils.jl")
#include("indexing.jl")
#include("broadcast.jl")
include("broadcast.jl")
#include("matmul.jl")
#include("mapreduce.jl")
#include("gpuarray_interface.jl")
Expand Down Expand Up @@ -164,6 +164,9 @@ function __init__()

# Load optional OpenCL integrations
@require OpenCL="08131aa3-fb12-5dee-8b74-c09406e224a2" include("opencl.jl")

# Load optional @requires packages
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl")
end

end # module
17 changes: 0 additions & 17 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,23 +147,6 @@ ROCArray{T,N}(xs::ROCArray{T,N}) where {T,N} = xs
Base.convert(::Type{T}, x::T) where T <: ROCArray = x


## broadcast

using Base.Broadcast: BroadcastStyle, Broadcasted

struct ROCArrayStyle{N} <: AbstractGPUArrayStyle{N} end
ROCArrayStyle(::Val{N}) where N = ROCArrayStyle{N}()
ROCArrayStyle{M}(::Val{N}) where {N,M} = ROCArrayStyle{N}()

BroadcastStyle(::Type{ROCArray{T,N}}) where {T,N} = ROCArrayStyle{N}()

# Allocating the output container
Base.similar(bc::Broadcasted{ROCArrayStyle{N}}, ::Type{T}) where {N,T} =
similar(ROCArray{T}, axes(bc))
Base.similar(bc::Broadcasted{ROCArrayStyle{N}}, ::Type{T}, dims...) where {N,T} =
ROCArray{T}(undef, dims...)


## memory operations

function Base.copyto!(dest::Array{T}, d_offset::Integer,
Expand Down
112 changes: 112 additions & 0 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# broadcasting

using Base.Broadcast: BroadcastStyle, Broadcasted

struct ROCArrayStyle{N} <: AbstractGPUArrayStyle{N} end
ROCArrayStyle(::Val{N}) where N = ROCArrayStyle{N}()
ROCArrayStyle{M}(::Val{N}) where {N,M} = ROCArrayStyle{N}()

BroadcastStyle(::Type{ROCArray{T,N}}) where {T,N} = ROCArrayStyle{N}()

Base.similar(bc::Broadcasted{ROCArrayStyle{N}}, ::Type{T}) where {N,T} =
similar(ROCArray{T}, axes(bc))

Base.similar(bc::Broadcasted{ROCArrayStyle{N}}, ::Type{T}, dims) where {N,T} =
ROCArray{T}(undef, dims)


## replace base functions with libdevice alternatives

rocfunc(f) = f
rocfunc(::Type{T}) where T = (x...) -> T(x...) # broadcasting type ctors isn't GPU compatible

Broadcast.broadcasted(::ROCArrayStyle{N}, f, args...) where {N} =
Broadcasted{ROCArrayStyle{N}}(rocfunc(f), args, nothing)

const device_intrinsics = :[
cos, cospi, sin, sinpi, tan, acos, asin, atan,
cosh, sinh, tanh, acosh, asinh, atanh, angle,
log, log10, log1p, log2, logb, ilogb,
exp, exp2, exp10, expm1, ldexp,
erf, erfinv, erfc, erfcinv, erfcx,
brev, clz, ffs, byte_perm, popc,
isfinite, isinf, isnan, nearbyint,
nextafter, signbit, copysign, abs,
sqrt, rsqrt, cbrt, rcbrt, pow,
ceil, floor, saturate,
lgamma, tgamma,
j0, j1, jn, y0, y1, yn,
normcdf, normcdfinv, hypot,
fma, sad, dim, mul24, mul64hi, hadd, rhadd, scalbn].args

for f in device_intrinsics
isdefined(Base, f) || continue
@eval rocfunc(::typeof(Base.$f)) = $f
end

# broadcast ^

rocliteral_pow(::typeof(^), x::T, ::Val{0}) where {T<:Real} = one(x)
rocliteral_pow(::typeof(^), x::T, ::Val{1}) where {T<:Real} = x
rocliteral_pow(::typeof(^), x::T, ::Val{2}) where {T<:Real} = x * x
rocliteral_pow(::typeof(^), x::T, ::Val{3}) where {T<:Real} = x * x * x
rocliteral_pow(::typeof(^), x::T, ::Val{p}) where {T<:Real,p} = pow(x, Int32(p))

rocfunc(::typeof(Base.literal_pow)) = rocliteral_pow
rocfunc(::typeof(Base.:(^))) = pow

using MacroTools

const _rocfuncs = [copy(device_intrinsics); :^]
rocfuncs() = (global _rocfuncs; _rocfuncs)

_rocint(x::Int) = Int32(x)
_rocint(x::Expr) = x.head == :call && x.args[1] == :Int32 && x.args[2] isa Int ? Int32(x.args[2]) : x
_rocint(x) = x

function _rocpowliteral(x::Expr)
if x.head == :call && x.args[1] == :(AMDGPU.rocfunc(^)) && x.args[3] isa Int32
num = x.args[3]
if 0 <= num <= 3
sym = gensym(:x)
new_x = Expr(:block, :($sym = $(x.args[2])))

if iszero(num)
push!(new_x.args, :(one($sym)))
else
unroll = Expr(:call, :*)
for x = one(num):num
push!(unroll.args, sym)
end
push!(new_x.args, unroll)
end

x = new_x
end
end
x
end
_rocpowliteral(x) = x

function replace_device(ex)
global _rocfuncs
MacroTools.postwalk(ex) do x
x = x in _rocfuncs ? :(AMDGPU.rocfunc($x)) : x
x = _rocint(x)
x = _rocpowliteral(x)
x
end
end

macro rocfunc(ex)
global _rocfuncs
def = MacroTools.splitdef(ex)
f = def[:name]
def[:name] = Symbol(:cu, f)
def[:body] = replace_device(def[:body])
push!(_rocfuncs, f)
quote
$(esc(MacroTools.combinedef(def)))
AMDGPU.rocfunc(::typeof($(esc(f)))) = $(esc(def[:name]))
end
end
15 changes: 6 additions & 9 deletions src/device/gcn/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ for jltype in (
:erf, :erfinv, :erfc, :erfcinv, :erfcx,
# TODO: :brev, :clz, :ffs, :byte_perm, :popc,
:isnormal, :nearbyint, :nextafter,
:pow, :pown, :powr,
:tgamma, :j0, :j1, :y0, :y1,
); inp_args=(jltype,), out_arg=jltype))

Expand All @@ -32,20 +31,18 @@ for jltype in (
push!(MATH_INTRINSICS, GCNIntrinsic(:abs, :fabs; inp_args=(jltype,), out_arg=jltype))
# TODO: abs(::Union{Int32,Int64})

# FIXME: Multi-argument functions
#=
push!(MATH_INTRINSICS, = map(intr->GCNIntrinsic(intr), (
:sincos, :frexp, :ldexp, :copysign,
)))
=#
# Multi-argument functions
push!(MATH_INTRINSICS, GCNIntrinsic(:pow; inp_args=(jltype,jltype), out_arg=jltype))
push!(MATH_INTRINSICS, GCNIntrinsic(:pow, :pown; inp_args=(jltype,Union{UInt32,Int32}), out_arg=jltype))
# TODO: push!(MATH_INTRINSICS, GCNIntrinsic(:pow, :pown; inp_args=(jltype,Union{UInt32,Int32}), out_arg=jltype))
# TODO: :sincos, :frexp, :ldexp, :copysign,
#push!(MATH_INTRINSICS, GCNIntrinsic(:ldexp; inp_args=(jltype,), out_arg=(jltype, Int32), isinverted=true))

# Multi-output functions
push!(MATH_INTRINSICS, GCNIntrinsic(:sincospi; inp_args=(jltype,), out_arg=jltype, isbroken=true))
end

let jltype=Float32
# TODO: Float64 is broken for some reason, try to re-enable on a newer LLVM
for jltype in (Float32, Float64)
push!(MATH_INTRINSICS, GCNIntrinsic(:isfinite; inp_args=(jltype,), out_arg=Int32))
push!(MATH_INTRINSICS, GCNIntrinsic(:isinf; inp_args=(jltype,), out_arg=Int32))
push!(MATH_INTRINSICS, GCNIntrinsic(:isnan; inp_args=(jltype,), out_arg=Int32))
Expand Down
87 changes: 87 additions & 0 deletions src/forwarddiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# ForwardDiff integration

byhand = [:exp2, :log2, :exp10, :log10, :abs]

for f in device_intrinsics
if haskey(ForwardDiff.DiffRules.DEFINED_DIFFRULES, (:Base,f,1))
f byhand && continue
diffrule = ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:Base,f,1)]
ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:AMDGPU,f,1)] =
(args...) -> replace_device(diffrule(args...))
eval(ForwardDiff.unary_dual_definition(:AMDGPU, f))
end
end

# byhand: exp2
ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:AMDGPU, :exp2, 1)] = x ->
:((AMDGPU.rocfunc(exp2))(x) * (AMDGPU.rocfunc(log))(oftype(x, 2)))
eval(ForwardDiff.unary_dual_definition(:AMDGPU, :exp2))

# byhand: log2
ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:AMDGPU, :log2, 1)] = x ->
:(inv(x) / (AMDGPU.rocfunc(log))(oftype(x, 2)))
eval(ForwardDiff.unary_dual_definition(:AMDGPU, :log2))

# byhand: exp10
ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:AMDGPU, :exp10, 1)] = x ->
:((AMDGPU.rocfunc(exp10))(x) * (AMDGPU.rocfunc(log))(oftype(x, 10)))
eval(ForwardDiff.unary_dual_definition(:AMDGPU, :exp10))

# byhand: log10
ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:AMDGPU, :log10, 1)] = x ->
:(inv(x) / (AMDGPU.rocfunc(log))(oftype(x, 10)))
eval(ForwardDiff.unary_dual_definition(:AMDGPU, :log10))

# byhand: abs
ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:AMDGPU, :abs, 1)] = x ->
:(signbit(x) ? -one(x) : one(x))
eval(ForwardDiff.unary_dual_definition(:AMDGPU, :abs))


ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:AMDGPU, :pow, 2)] = (x, y) ->
replace_device.(ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:Base, :^, 2)](x, y))

@eval begin
ForwardDiff.@define_binary_dual_op(
AMDGPU.pow,
begin
vx = ForwardDiff.value(x)
vy = ForwardDiff.value(y)
expv = (AMDGPU.pow)(vx, vy)

powval = vy * AMDGPU.pow(vx , vy - Int32(1))

py = ForwardDiff.partials(y)
px = ForwardDiff.partials(x)

cond = all(py.values) do x
x == zero(x)
end

if cond
logval = one(expv)
else
logval = expv * AMDGPU.log(vx)
end

new_partials = powval * px + logval * py
return ForwardDiff.Dual{Txy}(expv, new_partials)
end,
begin
v = ForwardDiff.value(x)
expv = (AMDGPU.pow)(v, y)
if y == zero(y)
new_partials = zero(ForwardDiff.partials(x))
else
new_partials = ForwardDiff.partials(x) * y * (AMDGPU.pow)(v, y - Int32(1))
end
return ForwardDiff.Dual{Tx}(expv, new_partials)
end,
begin
v = ForwardDiff.value(y)
expv = (AMDGPU.pow)(x, v)
deriv = expv*AMDGPU.log(x)
return ForwardDiff.Dual{Ty}(expv, deriv * ForwardDiff.partials(y))
end
)
end
Loading

0 comments on commit a83bca6

Please sign in to comment.