Skip to content

Commit

Permalink
Adapt to GPUArrays@10 (#580)
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th authored Jan 14, 2024
1 parent 1b3fbb8 commit 8e6480a
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 25 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ hsa_rocr_jll = "dd59ff1a-a01a-568d-8b29-0669330f116a"

[compat]
AbstractFFTs = "1.0"
Adapt = "3.0"
Adapt = "4"
Atomix = "0.1"
CEnum = "0.4, 0.5"
ExprTools = "0.1"
GPUArrays = "9"
GPUArrays = "10"
GPUCompiler = "0.25"
HIP_jll = "5.4"
KernelAbstractions = "0.9.2"
Expand Down
12 changes: 9 additions & 3 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ Return the device associated with the array `A`.
"""
function device(A::ROCArray)
    # `A.buf` is dereferenced with `[]`; the device is a field of the resulting buffer.
    buffer = A.buf[]
    return buffer.device
end

# Buffer type parameter `B` of a `ROCArray` instance, looked up through its type.
buftype(x::ROCArray) = buftype(typeof(x))

# Buffer type parameter `B` of a `ROCArray` type.
# For a partially-specified type such as `ROCArray{Float32}` the static parameter `B`
# is unbound and reading it would throw `UndefVarError`; fall back to `Any` in that case.
buftype(::Type{<:ROCArray{<:Any, <:Any, B}}) where {B} = @isdefined(B) ? B : Any

## aliases

# One-dimensional `ROCArray` with element type `T`.
const ROCVector{T} = ROCArray{T,1}
Expand Down Expand Up @@ -102,9 +105,12 @@ end
# Zero-argument vector constructor: an `undef`-initialized vector of length 0.
function ROCArray{T,1}() where {T}
    return ROCArray{T,1}(undef, 0)
end

# `similar` methods that return an `undef`-initialized `ROCArray{T,N}`.

# Same element type and size as `a`.
function Base.similar(a::ROCArray{T,N}) where {T,N}
    return ROCArray{T,N}(undef, size(a))
end

# Same element type, new dimensions.
function Base.similar(::ROCArray{T}, dims::Base.Dims{N}) where {T,N}
    return ROCArray{T,N}(undef, dims)
end

# New element type and new dimensions.
function Base.similar(::ROCArray, ::Type{T}, dims::Base.Dims{N}) where {T,N}
    return ROCArray{T,N}(undef, dims)
end
# Allocate an `undef`-initialized array with the element type, size and third type
# parameter `B` of `a`.
function Base.similar(a::ROCArray{T, N, B}) where {T, N, B}
    return ROCArray{T, N, B}(undef, size(a))
end

# Same element type and `B`, new dimensions.
function Base.similar(::ROCArray{T, <:Any, B}, dims::Base.Dims{N}) where {T, N, B}
    return ROCArray{T, N, B}(undef, dims)
end

# New element type and dimensions; only `B` is carried over from the source array.
function Base.similar(
    ::ROCArray{<:Any, <:Any, B}, ::Type{T}, dims::Base.Dims{N},
) where {T, N, B}
    return ROCArray{T, N, B}(undef, dims)
end

## array interface

Expand Down
31 changes: 18 additions & 13 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
# broadcasting

using Base.Broadcast: BroadcastStyle, Broadcasted

# Broadcast style for `ROCArray`s, parameterized by the dimensionality `N` only.
struct ROCArrayStyle{N} <: AbstractGPUArrayStyle{N} end

# Build a style from `Val(N)`.
function ROCArrayStyle(::Val{N}) where {N}
    return ROCArrayStyle{N}()
end

# Called on an already-parameterized style: the existing `M` is discarded in favor of `N`.
function ROCArrayStyle{M}(::Val{N}) where {N,M}
    return ROCArrayStyle{N}()
end
# Broadcast style for `ROCArray`s, parameterized by dimensionality `N` and a second
# parameter `B`.
struct ROCArrayStyle{N, B} <: AbstractGPUArrayStyle{N} end

# Re-dimension a style: the requested `N` replaces `M`, while `B` is kept unchanged.
function ROCArrayStyle{M, B}(::Val{N}) where {N, M, B}
    return ROCArrayStyle{N, B}()
end

# Map any `ROCArray{T,N}` type to the style with matching dimensionality.
function BroadcastStyle(::Type{<:ROCArray{T,N}}) where {T,N}
    return ROCArrayStyle{N}()
end
# Broadcast style of a plain `ROCArray`: `N` and `B` are read straight from the
# array's type parameters.
function BroadcastStyle(::Type{<:ROCArray{T, N, B}}) where {T, N, B}
    return ROCArrayStyle{N, B}()
end

# Identify the broadcast style of a wrapped ROCArray: the wrapper type is first
# unwrapped with `Adapt.unwrap_type`, then `buftype` supplies the style's `B`.
function BroadcastStyle(W::Type{<:AnyROCArray{T, N}}) where {T, N}
    parent_type = Adapt.unwrap_type(W)
    return ROCArrayStyle{N, buftype(parent_type)}()
end

# Allocate the broadcast output with element type `T` from the axes of `bc`.
function Base.similar(bc::Broadcasted{ROCArrayStyle{N}}, ::Type{T}) where {N,T}
    return similar(ROCArray{T}, axes(bc))
end
# TODO handle broadcast of different buffer types (use unified memory).

# Allocate an `undef`-initialized output with element type `T` and the given `dims`.
function Base.similar(bc::Broadcasted{ROCArrayStyle{N}}, ::Type{T}, dims...) where {N,T}
    return ROCArray{T}(undef, dims...)
end
# Allocation of output arrays: element type `T`, rank taken from `length(dims)`,
# and the `B` parameter recorded in the broadcast style.
Base.similar(bc::Broadcasted{ROCArrayStyle{N, B}}, ::Type{T}, dims) where {N, B, T} =
    similar(ROCArray{T, length(dims), B}, dims)

# Broadcasting a type constructor directly isn't GPU compatible, so the type is
# wrapped in a closure that forwards its arguments to `T`.
function Broadcast.broadcasted(::ROCArrayStyle{N}, f::Type{T}, args...) where {N, T}
    construct = (x...) -> T(x...)
    return Broadcasted{ROCArrayStyle{N}}(construct, args, nothing)
end
# TODO: revise
# Broadcasting a type constructor directly isn't GPU compatible, so the type is
# wrapped in a closure that forwards its arguments to `T`; the style keeps `N` and `B`.
function Broadcast.broadcasted(
    ::ROCArrayStyle{N, B}, f::Type{T}, args...
) where {N, B, T}
    construct = (x...) -> T(x...)
    return Broadcasted{ROCArrayStyle{N, B}}(construct, args, nothing)
end
4 changes: 3 additions & 1 deletion src/gpuarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ GPUArrays.device(x::ROCArray) = x.buf[].device

# Every `ROCArray` type maps to a freshly constructed `ROCArrayBackend()` value.
function GPUArrays.backend(::Type{<:ROCArray})
    return ROCArrayBackend()
end

function GPUArrays.derive(::Type{T}, N::Int, x::ROCArray, dims::Dims, offset::Int) where T
function GPUArrays.derive(
::Type{T}, x::ROCArray, dims::Dims{N}, offset::Int,
) where {N, T}
ref = copy(x.buf)
offset += (x.offset * Base.elsize(x)) ÷ sizeof(T)
ROCArray{T, N}(ref, dims; offset)
Expand Down
14 changes: 8 additions & 6 deletions src/runtime/memory/hip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,22 +174,24 @@ function HostBuffer()
HostBuffer(s.device, s.ctx, C_NULL, C_NULL, 0, true)
end

function HostBuffer(bytesize::Integer, flags = 0)
function HostBuffer(
bytesize::Integer, flags = 0; stream::HIP.HIPStream = AMDGPU.stream(),
)
bytesize == 0 && return HostBuffer()

ptr_ref = Ref{Ptr{Cvoid}}()
HIP.hipHostMalloc(ptr_ref, bytesize, flags) |> HIP.check
ptr = ptr_ref[]
dev_ptr = get_device_ptr(ptr)
s = AMDGPU.stream()
HostBuffer(s.device, s.ctx, ptr, dev_ptr, bytesize, true)
HostBuffer(stream.device, stream.ctx, ptr, dev_ptr, bytesize, true)
end

function HostBuffer(ptr::Ptr{Cvoid}, sz::Integer)
function HostBuffer(
ptr::Ptr{Cvoid}, sz::Integer; stream::HIP.HIPStream = AMDGPU.stream(),
)
HIP.hipHostRegister(ptr, sz, HIP.hipHostRegisterMapped) |> HIP.check
dev_ptr = get_device_ptr(ptr)
s = AMDGPU.stream()
HostBuffer(s.device, s.ctx, ptr, dev_ptr, sz, false)
HostBuffer(stream.device, stream.ctx, ptr, dev_ptr, sz, false)
end

function view(buf::HostBuffer, bytesize::Int)
Expand Down

0 comments on commit 8e6480a

Please sign in to comment.