Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

153 add hipblas #154

Merged
merged 2 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ fn main() {
return;
}

// link hipBLAS
println!("cargo:rustc-link-lib=dylib=hipblas");

// Tell cargo when to rerun this build script
println!("cargo:rerun-if-changed=src/core/sys/wrapper.h");
println!("cargo:rerun-if-changed=build.rs");
Expand All @@ -38,7 +41,7 @@ fn main() {
fn generate_bindings(hip_include_path: &str) {
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
let bindings = bindgen::Builder::default()
.header("src/core/sys/wrapper.h")
.header("src/sys/wrapper.h")
.clang_arg(&format!("-I{}", hip_include_path))
.clang_arg("-D__HIP_PLATFORM_AMD__")
// Blocklist problematic items
Expand Down
2 changes: 1 addition & 1 deletion src/core/device.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::sys;
use super::{DeviceP2PAttribute, HipError, HipErrorKind, HipResult, MemPool, PCIBusId, Result};
use crate::sys;
use semver::Version;
use std::ffi::CStr;
use std::i32;
Expand Down
2 changes: 1 addition & 1 deletion src/core/device_types.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::sys;
use super::{HipError, HipErrorKind, HipResult, Result};
use crate::sys;
use std::ffi::CStr;
use std::i32;

Expand Down
3 changes: 2 additions & 1 deletion src/core/hip_call.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::{sys, HipResult, Result};
use super::{HipResult, Result};
use crate::sys;

#[macro_export]
macro_rules! hip_call {
Expand Down
2 changes: 1 addition & 1 deletion src/core/init.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::sys;
use crate::sys;
use crate::{HipResult, Result};
use semver::Version;
use std::i32;
Expand Down
2 changes: 1 addition & 1 deletion src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ mod init;
mod memory;
mod result;
mod stream;
pub mod sys;

// use crate::sys::*;
// Re-export core functionality
pub use device::*;
pub use device_types::*;
Expand Down
198 changes: 198 additions & 0 deletions src/hipblas/handle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
use crate::sys;
use crate::{HipResult, Result};
use std::fmt;

/// A handle to a hipBLAS library context.
///
/// This handle is required for all hipBLAS library calls and encapsulates the
/// hipBLAS library context. The context includes the HIP device number and
/// stream used for all hipBLAS operations using this handle.
///
/// # Thread Safety
///
/// The handle is thread-safe and can be shared between threads. It implements
/// Send and Sync traits.
///
/// # Examples
///
/// ```
/// use hip_rs::BlasHandle;
///
/// let handle = BlasHandle::new().unwrap();
/// // Use handle for hipBLAS operations
/// ```
#[derive(Debug)]
pub struct BlasHandle {
handle: sys::hipblasHandle_t,
}

impl BlasHandle {
/// Creates a new hipBLAS library context.
///
/// # Returns
///
/// * `Ok(BlasHandle)` - A new handle for hipBLAS operations
/// * `Err(HipError)` - If handle creation fails
///
/// # Examples
///
/// ```
/// use hip_rs::BlasHandle;
///
/// let handle = BlasHandle::new().unwrap();
/// ```
pub fn new() -> Result<Self> {
let mut handle = std::ptr::null_mut();
unsafe {
let status = sys::hipblasCreate(&mut handle);
(Self { handle }, status).to_result()
}
}

/// Returns the raw hipBLAS handle.
///
/// # Safety
///
/// The returned handle should not be destroyed manually or used after
/// the BlasHandle is dropped.
pub fn handle(&self) -> sys::hipblasHandle_t {
self.handle
}
}

// Implement Drop to clean up the handle
impl Drop for BlasHandle {
fn drop(&mut self) {
if !self.handle.is_null() {
unsafe {
let status = sys::hipblasDestroy(self.handle);
if status != 0 {
log::error!("Failed to destroy hipBLAS handle: {}", status);
}
}
}
}
}

// Implement Display for better error messages
impl fmt::Display for BlasHandle {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "BlasHandle({:p})", self.handle)
}
}

// Implement Send and Sync as hipBLAS handles are thread-safe
unsafe impl Send for BlasHandle {}
unsafe impl Sync for BlasHandle {}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_handle_create() {
let handle = BlasHandle::new();
assert!(handle.is_ok(), "Failed to create BlasHandle");
let handle = handle.unwrap();
assert!(!handle.handle().is_null(), "Handle is null after creation");
}

#[test]
fn test_handle_drop() {
let handle = BlasHandle::new().unwrap();
drop(handle); // Should not panic or cause memory leaks
}

#[test]
fn test_multiple_handles() {
// Create multiple handles to ensure they don't interfere
let handle1 = BlasHandle::new().unwrap();
let handle2 = BlasHandle::new().unwrap();

assert!(!handle1.handle().is_null());
assert!(!handle2.handle().is_null());
assert_ne!(
handle1.handle(),
handle2.handle(),
"Handles should be unique"
);
}

#[test]
fn test_handle_clone_not_implemented() {
let handle = BlasHandle::new().unwrap();
// This should fail to compile if you try to uncomment it
// let _cloned = handle.clone();
}

#[test]
fn test_handle_send_sync() {
// Test that handle can be sent between threads
let handle = BlasHandle::new().unwrap();
let handle_ptr = handle.handle();

let handle = std::thread::spawn(move || {
assert!(!handle.handle().is_null());
handle
})
.join()
.unwrap();

assert_eq!(handle.handle(), handle_ptr);
}

#[test]
fn test_handle_concurrent_use() {
use std::sync::Arc;
use std::thread;

let handle = Arc::new(BlasHandle::new().unwrap());
let mut threads = vec![];

// Spawn multiple threads using the same handle
for _ in 0..4 {
let handle_clone = Arc::clone(&handle);
threads.push(thread::spawn(move || {
assert!(!handle_clone.handle().is_null());
}));
}

// Wait for all threads to complete
for thread in threads {
thread.join().unwrap();
}
}

#[test]
fn test_handle_in_closure() {
let handle = BlasHandle::new().unwrap();
let closure = || {
assert!(!handle.handle().is_null());
};
closure();
}

#[test]
fn test_handle_debug_format() {
let handle = BlasHandle::new().unwrap();
let debug_str = format!("{:?}", handle);
assert!(!debug_str.is_empty(), "Debug formatting failed");
println!("Debug format of BlasHandle: {}", debug_str);
}

#[test]
fn test_handle_memory_stress() {
// Create and destroy multiple handles in a loop
for _ in 0..100 {
let handle = BlasHandle::new().unwrap();
assert!(!handle.handle().is_null());
drop(handle);
}
}

#[test]
fn test_handle_null_check() {
let handle = BlasHandle::new().unwrap();
assert!(!handle.handle().is_null(), "Handle should not be null");
}
}
7 changes: 7 additions & 0 deletions src/hipblas/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod handle;
mod types;

use crate::sys;

pub use handle::*;
pub use types::*;
50 changes: 50 additions & 0 deletions src/hipblas/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use crate::sys;

#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Operation {
None = 0, // HIPBLAS_OP_N
Transpose = 1, // HIPBLAS_OP_T
Conjugate = 2, // HIPBLAS_OP_C
}

impl From<Operation> for sys::hipblasOperation_t {
fn from(op: Operation) -> Self {
op as sys::hipblasOperation_t
}
}

#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
Success = 0,
Handle = 1,
NotInitialized = 2,
InvalidValue = 3,
ArchMismatch = 4,
MappingError = 5,
ExecutionFailed = 6,
InternalError = 7,
NotSupported = 8,
MemoryError = 9,
AllocationFailed = 10,
}

impl From<sys::hipblasStatus_t> for Status {
fn from(status: sys::hipblasStatus_t) -> Self {
match status {
0 => Status::Success,
1 => Status::Handle,
2 => Status::NotInitialized,
3 => Status::InvalidValue,
4 => Status::ArchMismatch,
5 => Status::MappingError,
6 => Status::ExecutionFailed,
7 => Status::InternalError,
8 => Status::NotSupported,
9 => Status::MemoryError,
10 => Status::AllocationFailed,
_ => Status::InternalError,
}
}
}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#![allow(non_upper_case_globals)]
mod core;
mod hipblas;
mod sys;

pub use core::*;
pub use hipblas::*;
File renamed without changes.
1 change: 1 addition & 0 deletions src/core/sys/wrapper.h → src/sys/wrapper.h
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>