Skip to content

Commit

Permalink
Utilize AMD GPUs
Browse files Browse the repository at this point in the history
Add compilation procedures for HipBlas.

Set the environment variable ROCM_PATH=<location of rocm, eg /opt/rocm>

If ROCM_PATH is not set, /opt/rocm will be used as the default location
of ROCm.

Tested on Linux w/ 6800xt. Further work may be necessary for Windows.

signed-off-by: Michael Mullin <[email protected]>
  • Loading branch information
masmullin2000 committed Mar 20, 2024
1 parent fa6c834 commit 7ff5bd7
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions crates/llama_cpp_sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,51 @@ fn compile_blis(cx: &mut Build) {
println!("cargo:rustc-link-lib=blis");
}

fn compile_hipblas(cx: &mut Build, cxx: &mut Build, mut hip: Build) -> &'static str {
const DEFAULT_ROCM_PATH_STR: &str = "/opt/rocm/";

let rocm_path_str = env::var("ROCM_PATH").map_err(|_| {
DEFAULT_ROCM_PATH_STR.to_string()
}).unwrap();
println!("Compiling HIPBLAS GGML. Using ROCm from {rocm_path_str}");

let rocm_path = PathBuf::from(rocm_path_str);
let rocm_include = rocm_path.join("include");
let rocm_lib = rocm_path.join("lib");
let rocm_hip_bin = rocm_path.join("bin/hipcc");

let cuda_lib = "ggml-cuda";
let cuda_file = cuda_lib.to_string() + ".cu";
let cuda_header = cuda_lib.to_string() + ".h";

let defines = ["GGML_USE_HIPBLAS", "GGML_USE_CUBLAS"];
for def in defines {
cx.define(def, None);
cxx.define(def, None);
}

cx.include(&rocm_include);
cxx.include(&rocm_include);

hip.compiler(rocm_hip_bin)
.file(LLAMA_PATH.join(cuda_file))
.include(LLAMA_PATH.join(cuda_header))
.define("GGML_USE_HIPBLAS", None)
.compile(cuda_lib);

println!(
"cargo:rustc-link-search=native={}",
rocm_lib.to_string_lossy()
);

let rocm_libs = ["hipblas", "rocblas", "amdhip64"];
for lib in rocm_libs {
println!("cargo:rustc-link-lib={lib}");
}

cuda_lib
}

fn compile_cuda(cx: &mut Build, cxx: &mut Build, featless_cxx: Build) -> &'static str {
println!("Compiling CUDA GGML..");

Expand Down Expand Up @@ -551,6 +596,8 @@ fn main() {
} else if cfg!(feature = "metal") && cfg!(target_os = "macos") {
compile_metal(&mut cx, &mut cxx);
None
} else if cfg!(feature = "hipblas") {
Some(compile_hipblas(&mut cx, &mut cxx, featless_cxx))
} else {
None
};
Expand Down

0 comments on commit 7ff5bd7

Please sign in to comment.