Skip to content

Commit

Permalink
Detect HIP target automatically
Browse files Browse the repository at this point in the history
Signed-off-by: PragmaTwice <[email protected]>
  • Loading branch information
PragmaTwice committed Nov 24, 2024
1 parent f378748 commit f39168f
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 6 deletions.
40 changes: 34 additions & 6 deletions integrations/pjrt/src/iree_pjrt/rocm/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree_pjrt/rocm/client.h"

#include "iree/hal/drivers/hip/hip_device.h"
#include "iree/hal/drivers/hip/registration/driver_module.h"

namespace iree::pjrt::rocm {
Expand All @@ -27,20 +28,47 @@ iree_status_t ROCMClientInstance::CreateDriver(iree_hal_driver_t** out_driver) {
IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create(
driver_registry_, driver_name, host_allocator_, out_driver));
logger().debug("ROCM driver created");

// retrieve the target name of current available device
iree_host_size_t device_info_count;
iree_hal_device_info_t* device_infos;
IREE_RETURN_IF_ERROR(iree_hal_driver_query_available_devices(
*out_driver, host_allocator_, &device_info_count, &device_infos));

// TODO: here we just use the target name of the first available device,
// but ideally we should find the device which will run the program
if (device_info_count > 0) {
hipDeviceProp_tR0000 props;
IREE_RETURN_IF_ERROR(iree_hal_hip_get_device_properties(
*out_driver, device_infos->device_id, &props));

// `gcnArchName` comes back like gfx90a:sramecc+:xnack- for a fully
// specified target. However the IREE target-chip flag only expects the
// prefix. refer to
// https://github.com/iree-org/iree-turbine/blob/965247e/iree/turbine/runtime/device.py#L495
std::string_view target = props.gcnArchName;
if (auto pos = target.find(':'); pos != target.npos) {
target = target.substr(0, pos);
}

hip_target_ = target;
logger().debug("HIP target detected: " + hip_target_);
}

return iree_ok_status();
}

bool ROCMClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) {
auto flags = {
std::vector<std::string> flags = {
"--iree-hal-target-backends=rocm",

// TODO: gfx908 is just a placeholder here to make it work,
// we should instead detect the device target on the fly
"--iree-hip-target=gfx908",
};

if (!hip_target_.empty()) {
flags.push_back("--iree-hip-target=" + hip_target_);
}

for (auto flag : flags) {
if (!compiler_job->SetFlag(flag)) return false;
if (!compiler_job->SetFlag(flag.c_str())) return false;
}
return true;
}
Expand Down
1 change: 1 addition & 0 deletions integrations/pjrt/src/iree_pjrt/rocm/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class ROCMClientInstance final : public ClientInstance {
bool SetDefaultCompilerFlags(CompilerJob* compiler_job) override;

private:
std::string hip_target_;
};

} // namespace iree::pjrt::rocm
Expand Down
5 changes: 5 additions & 0 deletions runtime/src/iree/hal/drivers/hip/hip_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ iree_status_t iree_hal_hip_device_create_stream_command_buffer(
// contexts and the context may be in use on other threads.
hipCtx_t iree_hal_hip_device_context(iree_hal_device_t* device);

// Retrieve device properties for the given |device_id| to |out_props|
iree_status_t iree_hal_hip_get_device_properties(
iree_hal_driver_t* driver, iree_hal_device_id_t device_id,
hipDeviceProp_tR0000* out_props);

// Returns the dynamic symbol table from the |device| if it is a HIP device
// and otherwise returns NULL.
//
Expand Down
14 changes: 14 additions & 0 deletions runtime/src/iree/hal/drivers/hip/hip_driver.c
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,20 @@ static iree_status_t iree_hal_hip_driver_query_available_devices(
return status;
}

iree_status_t iree_hal_hip_get_device_properties(
iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
hipDeviceProp_tR0000* out_props) {
iree_hal_hip_driver_t* driver = iree_hal_hip_driver_cast(base_driver);

hipDevice_t device = IREE_DEVICE_ID_TO_HIPDEVICE(device_id);

IREE_HIP_RETURN_IF_ERROR(&driver->hip_symbols,
hipGetDeviceProperties(out_props, device),
"hipGetDeviceProperties");

return iree_ok_status();
}

static iree_status_t iree_hal_hip_driver_dump_device_info(
iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
iree_string_builder_t* builder) {
Expand Down

0 comments on commit f39168f

Please sign in to comment.