Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable rocm and vulkan build in CI workflow for PJRT plugin #19279

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

PragmaTwice
Copy link
Member

@PragmaTwice PragmaTwice commented Nov 23, 2024

This PR trys to enable rocm and vulkan PJRT plugin in CI workflow (via the runner nodai-amdgpu-w7900-x86-64, mentioned here: #19222 (comment)).

Currently it's marked as a draft PR for potentially CI debuging.

ci-exactly: build_packages, test_pjrt

@PragmaTwice PragmaTwice force-pushed the pjrt-ci-rocm branch 5 times, most recently from 311bd17 to f39168f Compare November 24, 2024 07:08
.github/workflows/pkgci_test_pjrt.yml Show resolved Hide resolved
Comment on lines -16 to +17
iree::experimental::rocm
iree::experimental::rocm::registration
iree::hal::drivers::hip
iree::hal::drivers::hip::registration
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oof sorry that this code has bit-rotted so much, and thanks for finding what needed to be updated.

@PragmaTwice
Copy link
Member Author

+ python test/test_add.py
WARNING:jax._src.xla_bridge:Platform 'iree_rocm' is experimental and not all JAX functionality may be correctly supported!
[IREE-PJRT] DEBUG: Using IREE compiler binary: /home/esaimana/actions-runner/_work/iree/iree/.venv/lib/python3.11/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so
[IREE-PJRT] DEBUG: Compiler Version: 3.1.0.dev+da45650503b955e[35](https://github.com/iree-org/iree/actions/runs/12031928938/job/33543044454?pr=19279#step:7:36)a45ea8e34b5f10b30a2f912 @ da45650503b955e35a45ea8e34b5f10b30a2f912 (API version 1.4)
[IREE-PJRT] DEBUG: Partitioner was not enabled. The partitioner can be enabled by setting the 'PARTITIONER_LIB_PATH' config var ('IREE_PJRT_PARTITIONER_LIB_PATH' env var)
[IREE-PJRT] DEBUG: ROCM driver created
[IREE-PJRT] DEBUG: HIP target detected: gfx1100
iree/runtime/src/iree/hal/drivers/hip/native_executable.c:306: UNKNOWN; HIP driver error 'hipErrorFileNotFound' (301): file not found; mismatched target chip? missing/wrong bitcode directory?; while invoking native function hal.executable.create; while calling import; 
[ 0] bytecode jit__add.__init:1284 "jit(_add)/jit(main)/add"("<module>"(/home/esaimana/actions-runner/_work/iree/iree/integrations/pjrt/test/test_add.py:9:6))
build_tools/testing/run_jax_tests.sh: line 31: 888283 Aborted                 (core dumped) JAX_PLATFORMS=$actual_jax_platform python $test_py_file > $actual_tmp_out

Hmm the error in rocm plugin comes with a hint (mismatched target chip? missing/wrong bitcode directory?) but it seems weird that I've already add some new logic to detect the HIP target and pass to IREE (via --iree-hip-target) before the compilation phase.

I'll investigate these errors in these days.

Comment on lines +48 to +65
// TODO: here we just use the target name of the first available device,
// but ideally we should find the device which will run the program
if (device_info_count > 0) {
hipDeviceProp_tR0000 props;
IREE_RETURN_IF_ERROR(iree_hal_hip_get_device_properties(
*out_driver, device_infos->device_id, &props));

// `gcnArchName` comes back like gfx90a:sramecc+:xnack- for a fully
// specified target. However the IREE target-chip flag only expects the
// prefix. refer to
// https://github.com/iree-org/iree-turbine/blob/965247e/iree/turbine/runtime/device.py#L495
std::string_view target = props.gcnArchName;
if (auto pos = target.find(':'); pos != target.npos) {
target = target.substr(0, pos);
}

hip_target_ = target;
logger().debug("HIP target detected: " + hip_target_);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+ python test/test_add.py
WARNING:jax._src.xla_bridge:Platform 'iree_rocm' is experimental and not all JAX functionality may be correctly supported!
[IREE-PJRT] DEBUG: Using IREE compiler binary: /home/esaimana/actions-runner/_work/iree/iree/.venv/lib/python3.11/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so
[IREE-PJRT] DEBUG: Compiler Version: 3.1.0.dev+da45650503b955e[35](https://github.com/iree-org/iree/actions/runs/12031928938/job/33543044454?pr=19279#step:7:36)a45ea8e34b5f10b30a2f912 @ da45650503b955e35a45ea8e34b5f10b30a2f912 (API version 1.4)
[IREE-PJRT] DEBUG: Partitioner was not enabled. The partitioner can be enabled by setting the 'PARTITIONER_LIB_PATH' config var ('IREE_PJRT_PARTITIONER_LIB_PATH' env var)
[IREE-PJRT] DEBUG: ROCM driver created
[IREE-PJRT] DEBUG: HIP target detected: gfx1100
iree/runtime/src/iree/hal/drivers/hip/native_executable.c:306: UNKNOWN; HIP driver error 'hipErrorFileNotFound' (301): file not found; mismatched target chip? missing/wrong bitcode directory?; while invoking native function hal.executable.create; while calling import; 
[ 0] bytecode jit__add.__init:1284 "jit(_add)/jit(main)/add"("<module>"(/home/esaimana/actions-runner/_work/iree/iree/integrations/pjrt/test/test_add.py:9:6))
build_tools/testing/run_jax_tests.sh: line 31: 888283 Aborted                 (core dumped) JAX_PLATFORMS=$actual_jax_platform python $test_py_file > $actual_tmp_out

Hmm the error in rocm plugin comes with a hint (mismatched target chip? missing/wrong bitcode directory?) but it seems weird that I've already add some new logic to detect the HIP target and pass to IREE (via --iree-hip-target) before the compilation phase.

I'll investigate these errors in these days.

For test workflows, we typically set an environment variable explicitly, matched with the hardware installed into the runner chosen using the runs-on property:

IREE_HIP_TEST_TARGET_CHIP: "gfx1100"

In PJRT, matching the code from iree-turbine that detects available devices makes sense to me, instead of that environment variable approach. This is a JIT scenario, where the compiler is being used to generate code for the current device, not an arbitrary deployment target.

Can you have the PJRT client log the compile command used, including all flags?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants