Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable rocm and vulkan build in CI workflow for PJRT plugin #19279

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
10 changes: 5 additions & 5 deletions .github/workflows/pkgci_test_pjrt.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ on:
jobs:
build_and_test:
strategy:
fail-fast: false
matrix:
include:
- runner: ubuntu-20.04
pjrt_platform: cpu
# TODO: cuda runner is not available yet, refer to #18814
# - runner: some-cuda-available-runner
# pjrt_platform: cuda
# TODO: enable these AMD runners
# - runner: nodai-amdgpu-w7900-x86-64
# pjrt_platform: rocm
# - runner: nodai-amdgpu-w7900-x86-64
# pjrt_platform: vulkan
- runner: nodai-amdgpu-w7900-x86-64
pjrt_platform: rocm
- runner: nodai-amdgpu-w7900-x86-64
pjrt_platform: vulkan
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved
name: Build and test
runs-on: ${{ matrix.runner }}
env:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ requires = [
"setuptools>=42",
"wheel",
"ninja",
"cmake",
]
build-backend = "setuptools.build_meta"
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ requires = [
"setuptools>=42",
"wheel",
"ninja",
"cmake",
]
build-backend = "setuptools.build_meta"
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ requires = [
"setuptools>=42",
"wheel",
"ninja",
"cmake",
]
build-backend = "setuptools.build_meta"
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def build_default_configuration(self):
print("*****************************", file=sys.stderr)
self.build_configuration(
os.path.join(THIS_DIR, "build", "cmake"),
extra_cmake_args=("-DIREE_EXTERNAL_HAL_DRIVERS=rocm",),
extra_cmake_args=("-DIREE_HAL_DRIVER_HIP=ON",),
)
print("Target populated.", file=sys.stderr)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ requires = [
"setuptools>=42",
"wheel",
"ninja",
"cmake",
]
build-backend = "setuptools.build_meta"
2 changes: 1 addition & 1 deletion integrations/pjrt/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ endif()
if(IREE_HAL_DRIVER_CUDA)
add_subdirectory(iree_pjrt/cuda)
endif()
if("rocm" IN_LIST IREE_EXTERNAL_HAL_DRIVERS)
if(IREE_HAL_DRIVER_HIP)
add_subdirectory(iree_pjrt/rocm)
endif()
if(IREE_HAL_DRIVER_VULKAN)
Expand Down
4 changes: 2 additions & 2 deletions integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ iree_cc_library(
"client.cc"
DEPS
iree_pjrt::common
iree::experimental::rocm
iree::experimental::rocm::registration
iree::hal::drivers::hip
iree::hal::drivers::hip::registration
Comment on lines -16 to +17
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oof sorry that this code has bit-rotted so much, and thanks for finding what needed to be updated.

)

iree_cc_library(
Expand Down
61 changes: 55 additions & 6 deletions integrations/pjrt/src/iree_pjrt/rocm/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

#include "iree_pjrt/rocm/client.h"

#include "experimental/rocm/registration/driver_module.h"
#include "iree/hal/drivers/hip/api.h"
#include "iree/hal/drivers/hip/hip_device.h"
#include "iree/hal/drivers/hip/registration/driver_module.h"

namespace iree::pjrt::rocm {

Expand All @@ -17,21 +19,68 @@ ROCMClientInstance::ROCMClientInstance(std::unique_ptr<Platform> platform)
// TODO: Get this when constructing the client so it is guaranteed to
// match.
cached_platform_name_ = "iree_rocm";
IREE_CHECK_OK(iree_hal_rocm_driver_module_register(driver_registry_));
}

ROCMClientInstance::~ROCMClientInstance() {}

iree_status_t ROCMClientInstance::CreateDriver(iree_hal_driver_t** out_driver) {
iree_string_view_t driver_name = iree_make_cstring_view("rocm");
IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create(
driver_registry_, driver_name, host_allocator_, out_driver));
iree_string_view_t driver_name = iree_make_cstring_view("hip");

// Device params.
iree_hal_hip_device_params_t default_params;
iree_hal_hip_device_params_initialize(&default_params);

// Driver params.
iree_hal_hip_driver_options_t driver_options;
iree_hal_hip_driver_options_initialize(&driver_options);

IREE_RETURN_IF_ERROR(iree_hal_hip_driver_create(driver_name, &driver_options,
&default_params,
host_allocator_, out_driver));
logger().debug("ROCM driver created");

// retrieve the target name of current available device
iree_host_size_t device_info_count;
iree_hal_device_info_t* device_infos;
IREE_RETURN_IF_ERROR(iree_hal_driver_query_available_devices(
*out_driver, host_allocator_, &device_info_count, &device_infos));

// TODO: here we just use the target name of the first available device,
// but ideally we should find the device which will run the program
if (device_info_count > 0) {
hipDeviceProp_tR0000 props;
IREE_RETURN_IF_ERROR(iree_hal_hip_get_device_properties(
*out_driver, device_infos->device_id, &props));

// `gcnArchName` comes back like gfx90a:sramecc+:xnack- for a fully
// specified target. However the IREE target-chip flag only expects the
// prefix. refer to
// https://github.com/iree-org/iree-turbine/blob/965247e/iree/turbine/runtime/device.py#L495
std::string_view target = props.gcnArchName;
if (auto pos = target.find(':'); pos != target.npos) {
target = target.substr(0, pos);
}

hip_target_ = target;
logger().debug("HIP target detected: " + hip_target_);
Comment on lines +48 to +65
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+ python test/test_add.py
WARNING:jax._src.xla_bridge:Platform 'iree_rocm' is experimental and not all JAX functionality may be correctly supported!
[IREE-PJRT] DEBUG: Using IREE compiler binary: /home/esaimana/actions-runner/_work/iree/iree/.venv/lib/python3.11/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so
[IREE-PJRT] DEBUG: Compiler Version: 3.1.0.dev+da45650503b955e[35](https://github.com/iree-org/iree/actions/runs/12031928938/job/33543044454?pr=19279#step:7:36)a45ea8e34b5f10b30a2f912 @ da45650503b955e35a45ea8e34b5f10b30a2f912 (API version 1.4)
[IREE-PJRT] DEBUG: Partitioner was not enabled. The partitioner can be enabled by setting the 'PARTITIONER_LIB_PATH' config var ('IREE_PJRT_PARTITIONER_LIB_PATH' env var)
[IREE-PJRT] DEBUG: ROCM driver created
[IREE-PJRT] DEBUG: HIP target detected: gfx1100
iree/runtime/src/iree/hal/drivers/hip/native_executable.c:306: UNKNOWN; HIP driver error 'hipErrorFileNotFound' (301): file not found; mismatched target chip? missing/wrong bitcode directory?; while invoking native function hal.executable.create; while calling import; 
[ 0] bytecode jit__add.__init:1284 "jit(_add)/jit(main)/add"("<module>"(/home/esaimana/actions-runner/_work/iree/iree/integrations/pjrt/test/test_add.py:9:6))
build_tools/testing/run_jax_tests.sh: line 31: 888283 Aborted                 (core dumped) JAX_PLATFORMS=$actual_jax_platform python $test_py_file > $actual_tmp_out

Hmm the error in rocm plugin comes with a hint (mismatched target chip? missing/wrong bitcode directory?) but it seems weird that I've already add some new logic to detect the HIP target and pass to IREE (via --iree-hip-target) before the compilation phase.

I'll investigate these errors in these days.

For test workflows, we typically set an environment variable explicitly, matched with the hardware installed into the runner chosen using the runs-on property:

IREE_HIP_TEST_TARGET_CHIP: "gfx1100"

In PJRT, matching the code from iree-turbine that detects available devices makes sense to me, instead of that environment variable approach. This is a JIT scenario, where the compiler is being used to generate code for the current device, not an arbitrary deployment target.

Can you have the PJRT client log the compile command used, including all flags?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! I'll dump more information for debugging.

}

return iree_ok_status();
}

bool ROCMClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) {
return compiler_job->SetFlag("--iree-hal-target-backends=rocm");
std::vector<std::string> flags = {
"--iree-hal-target-backends=rocm",
};

if (!hip_target_.empty()) {
flags.push_back("--iree-hip-target=" + hip_target_);
}

for (auto flag : flags) {
if (!compiler_job->SetFlag(flag.c_str())) return false;
}
return true;
}

} // namespace iree::pjrt::rocm
3 changes: 2 additions & 1 deletion integrations/pjrt/src/iree_pjrt/rocm/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#ifndef IREE_PJRT_PLUGIN_PJRT_ROCM_CLIENT_H_
#define IREE_PJRT_PLUGIN_PJRT_ROCM_CLIENT_H_

#include "experimental/rocm/api.h"
#include "iree/hal/drivers/hip/api.h"
#include "iree_pjrt/common/api_impl.h"

namespace iree::pjrt::rocm {
Expand All @@ -20,6 +20,7 @@ class ROCMClientInstance final : public ClientInstance {
bool SetDefaultCompilerFlags(CompilerJob* compiler_job) override;

private:
std::string hip_target_;
};

} // namespace iree::pjrt::rocm
Expand Down
2 changes: 1 addition & 1 deletion integrations/pjrt/src/iree_pjrt/vulkan/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ iree_status_t VulkanClientInstance::CreateDriver(
}

bool VulkanClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) {
return compiler_job->SetFlag("--iree-hal-target-backends=vulkan");
return compiler_job->SetFlag("--iree-hal-target-backends=vulkan-spirv");
}

} // namespace iree::pjrt::vulkan
5 changes: 5 additions & 0 deletions runtime/src/iree/hal/drivers/hip/hip_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ iree_status_t iree_hal_hip_device_create_stream_command_buffer(
// contexts and the context may be in use on other threads.
hipCtx_t iree_hal_hip_device_context(iree_hal_device_t* device);

// Retrieve device properties for the given |device_id| to |out_props|
iree_status_t iree_hal_hip_get_device_properties(
iree_hal_driver_t* driver, iree_hal_device_id_t device_id,
hipDeviceProp_tR0000* out_props);

// Returns the dynamic symbol table from the |device| if it is a HIP device
// and otherwise returns NULL.
//
Expand Down
14 changes: 14 additions & 0 deletions runtime/src/iree/hal/drivers/hip/hip_driver.c
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,20 @@ static iree_status_t iree_hal_hip_driver_query_available_devices(
return status;
}

iree_status_t iree_hal_hip_get_device_properties(
iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
hipDeviceProp_tR0000* out_props) {
iree_hal_hip_driver_t* driver = iree_hal_hip_driver_cast(base_driver);

hipDevice_t device = IREE_DEVICE_ID_TO_HIPDEVICE(device_id);

IREE_HIP_RETURN_IF_ERROR(&driver->hip_symbols,
hipGetDeviceProperties(out_props, device),
"hipGetDeviceProperties");

return iree_ok_status();
}

static iree_status_t iree_hal_hip_driver_dump_device_info(
iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
iree_string_builder_t* builder) {
Expand Down
Loading