Skip to content

Commit

Permalink
Merge pull request #171 from JuliaGPU/vc/norm
Browse files Browse the repository at this point in the history
Fix norm
  • Loading branch information
vchuravy authored Oct 31, 2021
2 parents a046daa + 9d8f1fa commit 5f69e4f
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 6 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ ROCmDeviceLibs_jll = "873c0968-716b-5aa7-bb8d-d1e2e2aeff2d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
hsa_rocr_jll = "dd59ff1a-a01a-568d-8b29-0669330f116a"

[compat]
Expand Down
1 change: 1 addition & 0 deletions src/AMDGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ include("broadcast.jl")
#include("matmul.jl")
include("mapreduce.jl")
#include("gpuarray_interface.jl")
include("compat.jl")

allowscalar(x::Bool) = GPUArrays.allowscalar(x)

Expand Down
2 changes: 1 addition & 1 deletion src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ for (f, froc) in (
(:threadidx, :threadIdx),
(:griddim, :gridDimWG)
)
@eval GPUArrays.$f(::ROCKernelContext) = AMDGPU.$froc().x
@eval @inline GPUArrays.$f(::ROCKernelContext) = AMDGPU.$froc().x
end

# math
Expand Down
66 changes: 66 additions & 0 deletions src/compat.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Hacks that can be ripped out once we support device overrides

"""
    LinearAlgebra.norm(v::ROCArray{T}, p::Real=2)

Compute the `p`-norm of the elements of `v`.

- `p == Inf` returns the largest absolute value.
- `p == -Inf` returns the smallest absolute value.
- Any other `p` sums `AMDGPU.pow(AMDGPU.abs(x), p)` over all elements and raises
  the total to the power `1/p`. `p == 0` is not special-cased.

The reduction is seeded with `zero(real(T))`. Every summed term is the power of an
absolute value and therefore real, so the accumulator stays real even when `T` is
a complex type. For real `T`, `real(T) === T`.
"""
function LinearAlgebra.norm(v::ROCArray{T}, p::Real=2) where {T}
    if p == Inf
        return maximum(abs.(v))
    elseif p == -Inf
        return minimum(abs.(v))
    else
        # Use the real counterpart of T as the neutral element; seeding with a
        # complex zero would promote the whole sum (and the result) to complex.
        total = mapreduce(x->AMDGPU.pow(AMDGPU.abs(x), p), +, v; init=zero(real(T)))
        return total^(1/p)
    end
end

# Scaled sum of squares: return `(ρ, k)` with `ρ == ldexp(x,-k)^2 + ldexp(y,-k)^2`,
# i.e. `ρ * 2^(2k)` represents `x^2 + y^2`.
#
# * Normally `k == 0` and `ρ == x*x + y*y`.
# * If the plain sum is not finite and one of the inputs is infinite, `ρ` is `Inf`.
# * If the plain sum overflowed, underflowed to zero for non-zero inputs, or is
#   below `nextfloat(zero(T))/(2*eps(T)^2)`, both inputs are rescaled by `2^-k`,
#   with `k` the binary exponent of `max(|x|, |y|)`, before squaring.
@inline function ssqs(x::T, y::T) where T<:Real
    k::Int = 0
    ρ = x*x + y*y
    # FIXME: isfinite, isinf returns i32
    # The predicates are compared against Int32 literals so the tests work whether
    # they yield an i32 or a Bool (`true == Int32(1)`).
    if isfinite(ρ) == Int32(0) && (isinf(x) == Int32(1) || isinf(y) == Int32(1))
        ρ = convert(T, Inf)
    elseif isinf(ρ) == Int32(1) || (ρ==0 && (x!=0 || y!=0)) || ρ<nextfloat(zero(T))/(2*eps(T)^2)
        m::T = max(abs(x), abs(y))
        # m == 0 only when both inputs are zero; no rescaling is needed then.
        k = m==0 ? 0 : exponent(m)
        xk, yk = ldexp(x,-k), ldexp(y,-k)
        ρ = xk*xk + yk*yk
    end
    return ρ, k
end

# Complex square root built on `ssqs`, so that intermediate squares neither
# overflow nor underflow.
#
# The input is converted with `float` and split into `x` (real) and `y`
# (imaginary). A zero input returns `Complex(zero(x), y)`, which keeps the sign of
# the imaginary zero. Otherwise `ρ` is computed in scaled form from
# `ssqs(x, y)`, and the real and imaginary parts `ξ` and `η` are derived from it.
@inline function sqrt(z::Complex)
    z = float(z)
    x, y = reim(z)
    if x==y==0
        return Complex(zero(x),y)
    end
    ρ, k::Int = ssqs(x, y)
    # FIXME: isfinite returns i32
    if isfinite(x) == Int32(1)
        ρ=ldexp(abs(x),-k)+sqrt(ρ)
    end
    # Halve the exponent for the outer square root; an even k first borrows one
    # factor of two into ρ so the halved exponent stays an integer.
    if isodd(k)
        k = div(k-1,2)
    else
        k = div(k,2)-1
        ρ += ρ
    end
    ρ = ldexp(sqrt(ρ),k) #sqrt((abs(z)+abs(x))/2) without over/underflow
    ξ = ρ
    η = y
    if ρ != 0
        # FIXME: isfinite returns i32
        if isfinite(η) == Int32(1)
            η=(η/ρ)/2
        end
        # For a negative real part the two magnitudes swap roles, and the
        # imaginary part takes the sign of y.
        if x<0
            ξ = abs(η)
            η = copysign(ρ,y)
        end
    end
    return Complex(ξ,η)
end


import Statistics
# `Statistics.corzm` method for two-dimensional `ROCArray`s: builds the matrix
# returned by `Statistics.unscaled_covzm(x, vardim)` and normalizes it in place
# with `Statistics.cov2cor!`, using the square roots of its diagonal.
function Statistics.corzm(x::ROCArray{<:Any, 2}, vardim::Int=1)
    covariance = Statistics.unscaled_covzm(x, vardim)
    scales = sqrt.(LinearAlgebra.diag(covariance))
    return Statistics.cov2cor!(covariance, scales)
end
42 changes: 42 additions & 0 deletions src/device/gcn/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,45 @@ end

abs(z::Complex) = hypot(real(z), imag(z))
abs(i::Integer) = Base.abs(i)

# Integer power of a Float32. Exponents -1 through 3 are computed directly with
# `inv` or repeated multiplication; any other exponent is converted to Float32
# and passed to the `pow(::Float32, ::Float32)` method.
@inline function pow(x::Float32, y::Int64)
    if y == -1
        return inv(x)
    elseif y == 0
        return one(x)
    elseif y == 1
        return x
    elseif y == 2
        return x*x
    elseif y == 3
        return x*x*x
    end
    return pow(x, Float32(y))
end
# Integer power of a Float64. Exponents -1 through 3 are computed directly with
# `inv` or repeated multiplication; any other exponent is converted to Float64
# and passed to the `pow(::Float64, ::Float64)` method.
@inline function pow(x::Float64, y::Int64)
    if y == -1
        return inv(x)
    elseif y == 0
        return one(x)
    elseif y == 1
        return x
    elseif y == 2
        return x*x
    elseif y == 3
        return x*x*x
    end
    return pow(x, Float64(y))
end

# Integer base with a floating-point exponent: convert the base to the exponent's
# type and defer to the floating-point `pow` method.
function pow(x::Integer, p::F) where {F<:Union{Float32, Float64}}
    return pow(convert(F, x), p)
end
# Integer power by repeated squaring.
#
# Exponents 0 through 3 are handled directly. A negative exponent throws the
# string "Negative integer power not supported". For larger exponents the bits of
# `p` are consumed from the least significant end: each run of zero bits squares
# the running base, and each set bit multiplies the running base into the result.
@inline function pow(x::Integer, p::Integer)
    p < 0 && throw("Negative integer power not supported")
    if p == 0
        return one(x)
    elseif p == 1
        return x
    elseif p == 2
        return x*x
    elseif p == 3
        return x*x*x
    end

    # Strip the trailing zero bits and the lowest set bit, squaring once per
    # trailing zero so that `x` holds the power for that lowest set bit.
    nzeros = trailing_zeros(p)
    p >>= nzeros + 1
    for _ in 1:nzeros
        x *= x
    end
    acc = x
    # Each remaining set bit: advance `x` past the zeros and the bit itself,
    # then fold it into the accumulator.
    while p > 0
        nshift = trailing_zeros(p) + 1
        p >>= nshift
        for _ in 1:nshift
            x *= x
        end
        acc *= x
    end
    return acc
end
5 changes: 4 additions & 1 deletion src/device/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,10 @@ function link_device_libs!(target, mod::LLVM.Module)
for (option, value) in options
toggle = value ? "on" : "off"
lib = locate_lib("oclc_$(option)_$(toggle)")
@assert lib !== nothing
if lib === nothing
@warn "Could not find OCLC library for option $option=$value"
continue
end
load_and_link!(mod, lib)
end
end
9 changes: 5 additions & 4 deletions src/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function HSAKernelInstance(agent::HSAAgent, exe::HSAExecutable, symbol::String,
if kernarg_segment_size[] == 0
# FIXME: Hidden arguments!
if length(args) > 0
kernarg_segment_size[] = sum(sizeof.(args))
kernarg_segment_size[] = sum(sizeof(typeof(arg)) for arg in args)
else
# Allocate some memory anyway, #10
kernarg_segment_size[] = max(kernarg_segment_size[], 8)
Expand Down Expand Up @@ -76,13 +76,14 @@ function HSAKernelInstance(agent::HSAAgent, exe::HSAExecutable, symbol::String,
if rem > 0
ctr += align-rem
end
#@info "Storing $(typeof(arg)) at offset $ctr"
sz = sizeof(typeof(arg))
ccall(:memcpy, Cvoid,
(Ptr{Cvoid}, Ptr{Cvoid}, Csize_t),
kernarg_address[]+ctr, rarg, sizeof(arg))
ctr += sizeof(arg)
kernarg_address[]+ctr, rarg, sz)
ctr += sz
end


kernel = HSAKernelInstance(agent, exe, symbol, args, kernel_object[],
kernarg_segment_size[], group_segment_size[],
private_segment_size[], kernarg_address[])
Expand Down

0 comments on commit 5f69e4f

Please sign in to comment.