Skip to content

Commit

Permalink
153 add hipblas (#154)
Browse files Browse the repository at this point in the history
* move bindings to src/sys/. Add hipblas/

* add hipBLAS link to build.rs
  • Loading branch information
smedegaard authored Dec 30, 2024
1 parent 82c975f commit 29fc0ed
Show file tree
Hide file tree
Showing 12 changed files with 269 additions and 6 deletions.
5 changes: 4 additions & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ fn main() {
return;
}

// link hipBLAS
println!("cargo:rustc-link-lib=dylib=hipblas");

// Tell cargo when to rerun this build script
println!("cargo:rerun-if-changed=src/core/sys/wrapper.h");
println!("cargo:rerun-if-changed=build.rs");
Expand All @@ -38,7 +41,7 @@ fn main() {
fn generate_bindings(hip_include_path: &str) {
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
let bindings = bindgen::Builder::default()
.header("src/core/sys/wrapper.h")
.header("src/sys/wrapper.h")
.clang_arg(&format!("-I{}", hip_include_path))
.clang_arg("-D__HIP_PLATFORM_AMD__")
// Blocklist problematic items
Expand Down
2 changes: 1 addition & 1 deletion src/core/device.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::sys;
use super::{DeviceP2PAttribute, HipError, HipErrorKind, HipResult, MemPool, PCIBusId, Result};
use crate::sys;
use semver::Version;
use std::ffi::CStr;
use std::i32;
Expand Down
2 changes: 1 addition & 1 deletion src/core/device_types.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::sys;
use super::{HipError, HipErrorKind, HipResult, Result};
use crate::sys;
use std::ffi::CStr;
use std::i32;

Expand Down
3 changes: 2 additions & 1 deletion src/core/hip_call.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::{sys, HipResult, Result};
use super::{HipResult, Result};
use crate::sys;

#[macro_export]
macro_rules! hip_call {
Expand Down
2 changes: 1 addition & 1 deletion src/core/init.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::sys;
use crate::sys;
use crate::{HipResult, Result};
use semver::Version;
use std::i32;
Expand Down
2 changes: 1 addition & 1 deletion src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ mod init;
mod memory;
mod result;
mod stream;
pub mod sys;

// use crate::sys::*;
// Re-export core functionality
pub use device::*;
pub use device_types::*;
Expand Down
198 changes: 198 additions & 0 deletions src/hipblas/handle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
use crate::sys;
use crate::{HipResult, Result};
use std::fmt;

/// A handle to a hipBLAS library context.
///
/// This handle is required for all hipBLAS library calls and encapsulates the
/// hipBLAS library context. The context includes the HIP device number and
/// stream used for all hipBLAS operations using this handle.
///
/// # Thread Safety
///
/// The handle is thread-safe and can be shared between threads. It implements
/// Send and Sync traits.
///
/// # Examples
///
/// ```
/// use hip_rs::BlasHandle;
///
/// let handle = BlasHandle::new().unwrap();
/// // Use handle for hipBLAS operations
/// ```
#[derive(Debug)]
pub struct BlasHandle {
handle: sys::hipblasHandle_t,
}

impl BlasHandle {
/// Creates a new hipBLAS library context.
///
/// # Returns
///
/// * `Ok(BlasHandle)` - A new handle for hipBLAS operations
/// * `Err(HipError)` - If handle creation fails
///
/// # Examples
///
/// ```
/// use hip_rs::BlasHandle;
///
/// let handle = BlasHandle::new().unwrap();
/// ```
pub fn new() -> Result<Self> {
let mut handle = std::ptr::null_mut();
unsafe {
let status = sys::hipblasCreate(&mut handle);
(Self { handle }, status).to_result()
}
}

/// Returns the raw hipBLAS handle.
///
/// # Safety
///
/// The returned handle should not be destroyed manually or used after
/// the BlasHandle is dropped.
pub fn handle(&self) -> sys::hipblasHandle_t {
self.handle
}
}

// Implement Drop to clean up the handle
impl Drop for BlasHandle {
fn drop(&mut self) {
if !self.handle.is_null() {
unsafe {
let status = sys::hipblasDestroy(self.handle);
if status != 0 {
log::error!("Failed to destroy hipBLAS handle: {}", status);
}
}
}
}
}

// Implement Display for better error messages
impl fmt::Display for BlasHandle {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "BlasHandle({:p})", self.handle)
}
}

// Implement Send and Sync as hipBLAS handles are thread-safe
unsafe impl Send for BlasHandle {}
unsafe impl Sync for BlasHandle {}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_handle_create() {
let handle = BlasHandle::new();
assert!(handle.is_ok(), "Failed to create BlasHandle");
let handle = handle.unwrap();
assert!(!handle.handle().is_null(), "Handle is null after creation");
}

#[test]
fn test_handle_drop() {
let handle = BlasHandle::new().unwrap();
drop(handle); // Should not panic or cause memory leaks
}

#[test]
fn test_multiple_handles() {
// Create multiple handles to ensure they don't interfere
let handle1 = BlasHandle::new().unwrap();
let handle2 = BlasHandle::new().unwrap();

assert!(!handle1.handle().is_null());
assert!(!handle2.handle().is_null());
assert_ne!(
handle1.handle(),
handle2.handle(),
"Handles should be unique"
);
}

#[test]
fn test_handle_clone_not_implemented() {
let handle = BlasHandle::new().unwrap();
// This should fail to compile if you try to uncomment it
// let _cloned = handle.clone();
}

#[test]
fn test_handle_send_sync() {
// Test that handle can be sent between threads
let handle = BlasHandle::new().unwrap();
let handle_ptr = handle.handle();

let handle = std::thread::spawn(move || {
assert!(!handle.handle().is_null());
handle
})
.join()
.unwrap();

assert_eq!(handle.handle(), handle_ptr);
}

#[test]
fn test_handle_concurrent_use() {
use std::sync::Arc;
use std::thread;

let handle = Arc::new(BlasHandle::new().unwrap());
let mut threads = vec![];

// Spawn multiple threads using the same handle
for _ in 0..4 {
let handle_clone = Arc::clone(&handle);
threads.push(thread::spawn(move || {
assert!(!handle_clone.handle().is_null());
}));
}

// Wait for all threads to complete
for thread in threads {
thread.join().unwrap();
}
}

#[test]
fn test_handle_in_closure() {
let handle = BlasHandle::new().unwrap();
let closure = || {
assert!(!handle.handle().is_null());
};
closure();
}

#[test]
fn test_handle_debug_format() {
let handle = BlasHandle::new().unwrap();
let debug_str = format!("{:?}", handle);
assert!(!debug_str.is_empty(), "Debug formatting failed");
println!("Debug format of BlasHandle: {}", debug_str);
}

#[test]
fn test_handle_memory_stress() {
// Create and destroy multiple handles in a loop
for _ in 0..100 {
let handle = BlasHandle::new().unwrap();
assert!(!handle.handle().is_null());
drop(handle);
}
}

#[test]
fn test_handle_null_check() {
let handle = BlasHandle::new().unwrap();
assert!(!handle.handle().is_null(), "Handle should not be null");
}
}
7 changes: 7 additions & 0 deletions src/hipblas/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod handle;
mod types;

use crate::sys;

pub use handle::*;
pub use types::*;
50 changes: 50 additions & 0 deletions src/hipblas/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use crate::sys;

#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Operation {
None = 0, // HIPBLAS_OP_N
Transpose = 1, // HIPBLAS_OP_T
Conjugate = 2, // HIPBLAS_OP_C
}

impl From<Operation> for sys::hipblasOperation_t {
fn from(op: Operation) -> Self {
op as sys::hipblasOperation_t
}
}

#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
Success = 0,
Handle = 1,
NotInitialized = 2,
InvalidValue = 3,
ArchMismatch = 4,
MappingError = 5,
ExecutionFailed = 6,
InternalError = 7,
NotSupported = 8,
MemoryError = 9,
AllocationFailed = 10,
}

impl From<sys::hipblasStatus_t> for Status {
fn from(status: sys::hipblasStatus_t) -> Self {
match status {
0 => Status::Success,
1 => Status::Handle,
2 => Status::NotInitialized,
3 => Status::InvalidValue,
4 => Status::ArchMismatch,
5 => Status::MappingError,
6 => Status::ExecutionFailed,
7 => Status::InternalError,
8 => Status::NotSupported,
9 => Status::MemoryError,
10 => Status::AllocationFailed,
_ => Status::InternalError,
}
}
}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#![allow(non_upper_case_globals)]
mod core;
mod hipblas;
mod sys;

pub use core::*;
pub use hipblas::*;
File renamed without changes.
1 change: 1 addition & 0 deletions src/core/sys/wrapper.h → src/sys/wrapper.h
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>

0 comments on commit 29fc0ed

Please sign in to comment.