Skip to content

Commit

Permalink
Fix bug with device selection
Browse files Browse the repository at this point in the history
  • Loading branch information
utkinis committed Nov 9, 2023
1 parent eab0977 commit 646fb71
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 10 deletions.
4 changes: 2 additions & 2 deletions ext/AMDGPUExt/AMDGPUExt.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
module AMDGPUExt

using AMDGPU
using AMDGPU, AMDGPU.ROCKernels

import FastIce.Architecture: heuristic_groupsize, set_device!

set_device!(dev::HIPDevice) = AMDGPU.device!(dev)

set_device!(::HIPDevice, id::Integer) = AMDGPU.device_id!(id)
get_device(::ROCBackend, id::Integer) = HIPDevice(id)

heuristic_groupsize(::HIPDevice, ::Val{1}) = (256, )
heuristic_groupsize(::HIPDevice, ::Val{2}) = (128, 2, )
Expand Down
10 changes: 5 additions & 5 deletions ext/CUDAExt/CUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
module CUDAExt

using CUDA
using CUDA, CUDA.CUDAKernels

import FastIce.Architecture: heuristic_groupsize, set_device!

set_device!(dev::CuDevice) = CUDA.device!(dev)

set_device!(::CuDevice, id::Integer) = CUDA.device!(id-1)
get_device(::CUDABackend, id::Integer) = CuDevice(id - 1)

heuristic_groupsize(::CuDevice, ::Val{1}) = (256, )
heuristic_groupsize(::CuDevice, ::Val{2}) = (32, 8, )
heuristic_groupsize(::CuDevice, ::Val{3}) = (32, 8, 1, )
heuristic_groupsize(::CuDevice, ::Val{1}) = (256,)
heuristic_groupsize(::CuDevice, ::Val{2}) = (32, 8)
heuristic_groupsize(::CuDevice, ::Val{3}) = (32, 8, 1)

end
3 changes: 1 addition & 2 deletions src/Architectures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ end
struct SingleDevice end

function Architecture(backend::Backend, device_id::Integer=1)
device = set_device!(backend, device_id)
device = get_device(backend, device_id)
return Architecture{SingleDevice,typeof(backend),typeof(device),Nothing}(backend, device, nothing)
end

Expand All @@ -37,7 +37,6 @@ function set_device_and_priority!(arch::Architecture, prio::Symbol)
end

set_device!(::Architecture{Kind,CPU}) where {Kind} = nothing
set_device!(::CPU, id::Integer) = nothing

heuristic_groupsize(arch::Architecture, ::Val{N}) where {N} = heuristic_groupsize(arch.device, Val(N))
heuristic_groupsize(::Architecture{Kind,CPU}, N) where {Kind} = 256
Expand Down
2 changes: 1 addition & 1 deletion src/Distributed/Distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct DistributedMPI end

function Architectures.Architecture(backend::Backend, dims::NTuple{N,Int}, comm::MPI.Comm=MPI.COMM_WORLD) where {N}
topo = CartesianTopology(dims; comm)
device = set_device!(backend, shared_rank(topo))
device = get_device(backend, shared_rank(topo))
return Architecture{DistributedMPI,typeof(backend),typeof(device),typeof(topo)}(backend, device, topo)
end

Expand Down

0 comments on commit 646fb71

Please sign in to comment.