-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* move bindings to src/sys/. Add hipblas/ * add hipBLAS link to build.rs
- Loading branch information
1 parent
82c975f
commit 29fc0ed
Showing
12 changed files
with
269 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
use super::sys; | ||
use crate::sys; | ||
use crate::{HipResult, Result}; | ||
use semver::Version; | ||
use std::i32; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
use crate::sys; | ||
use crate::{HipResult, Result}; | ||
use std::fmt; | ||
|
||
/// A handle to a hipBLAS library context. | ||
/// | ||
/// This handle is required for all hipBLAS library calls and encapsulates the | ||
/// hipBLAS library context. The context includes the HIP device number and | ||
/// stream used for all hipBLAS operations using this handle. | ||
/// | ||
/// # Thread Safety | ||
/// | ||
/// The handle is thread-safe and can be shared between threads. It implements | ||
/// Send and Sync traits. | ||
/// | ||
/// # Examples | ||
/// | ||
/// ``` | ||
/// use hip_rs::BlasHandle; | ||
/// | ||
/// let handle = BlasHandle::new().unwrap(); | ||
/// // Use handle for hipBLAS operations | ||
/// ``` | ||
#[derive(Debug)] | ||
pub struct BlasHandle { | ||
handle: sys::hipblasHandle_t, | ||
} | ||
|
||
impl BlasHandle { | ||
/// Creates a new hipBLAS library context. | ||
/// | ||
/// # Returns | ||
/// | ||
/// * `Ok(BlasHandle)` - A new handle for hipBLAS operations | ||
/// * `Err(HipError)` - If handle creation fails | ||
/// | ||
/// # Examples | ||
/// | ||
/// ``` | ||
/// use hip_rs::BlasHandle; | ||
/// | ||
/// let handle = BlasHandle::new().unwrap(); | ||
/// ``` | ||
pub fn new() -> Result<Self> { | ||
let mut handle = std::ptr::null_mut(); | ||
unsafe { | ||
let status = sys::hipblasCreate(&mut handle); | ||
(Self { handle }, status).to_result() | ||
} | ||
} | ||
|
||
/// Returns the raw hipBLAS handle. | ||
/// | ||
/// # Safety | ||
/// | ||
/// The returned handle should not be destroyed manually or used after | ||
/// the BlasHandle is dropped. | ||
pub fn handle(&self) -> sys::hipblasHandle_t { | ||
self.handle | ||
} | ||
} | ||
|
||
// Implement Drop to clean up the handle | ||
impl Drop for BlasHandle { | ||
fn drop(&mut self) { | ||
if !self.handle.is_null() { | ||
unsafe { | ||
let status = sys::hipblasDestroy(self.handle); | ||
if status != 0 { | ||
log::error!("Failed to destroy hipBLAS handle: {}", status); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Implement Display for better error messages | ||
impl fmt::Display for BlasHandle { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
write!(f, "BlasHandle({:p})", self.handle) | ||
} | ||
} | ||
|
||
// Implement Send and Sync as hipBLAS handles are thread-safe | ||
unsafe impl Send for BlasHandle {} | ||
unsafe impl Sync for BlasHandle {} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[test] | ||
fn test_handle_create() { | ||
let handle = BlasHandle::new(); | ||
assert!(handle.is_ok(), "Failed to create BlasHandle"); | ||
let handle = handle.unwrap(); | ||
assert!(!handle.handle().is_null(), "Handle is null after creation"); | ||
} | ||
|
||
#[test] | ||
fn test_handle_drop() { | ||
let handle = BlasHandle::new().unwrap(); | ||
drop(handle); // Should not panic or cause memory leaks | ||
} | ||
|
||
#[test] | ||
fn test_multiple_handles() { | ||
// Create multiple handles to ensure they don't interfere | ||
let handle1 = BlasHandle::new().unwrap(); | ||
let handle2 = BlasHandle::new().unwrap(); | ||
|
||
assert!(!handle1.handle().is_null()); | ||
assert!(!handle2.handle().is_null()); | ||
assert_ne!( | ||
handle1.handle(), | ||
handle2.handle(), | ||
"Handles should be unique" | ||
); | ||
} | ||
|
||
#[test] | ||
fn test_handle_clone_not_implemented() { | ||
let handle = BlasHandle::new().unwrap(); | ||
// This should fail to compile if you try to uncomment it | ||
// let _cloned = handle.clone(); | ||
} | ||
|
||
#[test] | ||
fn test_handle_send_sync() { | ||
// Test that handle can be sent between threads | ||
let handle = BlasHandle::new().unwrap(); | ||
let handle_ptr = handle.handle(); | ||
|
||
let handle = std::thread::spawn(move || { | ||
assert!(!handle.handle().is_null()); | ||
handle | ||
}) | ||
.join() | ||
.unwrap(); | ||
|
||
assert_eq!(handle.handle(), handle_ptr); | ||
} | ||
|
||
#[test] | ||
fn test_handle_concurrent_use() { | ||
use std::sync::Arc; | ||
use std::thread; | ||
|
||
let handle = Arc::new(BlasHandle::new().unwrap()); | ||
let mut threads = vec![]; | ||
|
||
// Spawn multiple threads using the same handle | ||
for _ in 0..4 { | ||
let handle_clone = Arc::clone(&handle); | ||
threads.push(thread::spawn(move || { | ||
assert!(!handle_clone.handle().is_null()); | ||
})); | ||
} | ||
|
||
// Wait for all threads to complete | ||
for thread in threads { | ||
thread.join().unwrap(); | ||
} | ||
} | ||
|
||
#[test] | ||
fn test_handle_in_closure() { | ||
let handle = BlasHandle::new().unwrap(); | ||
let closure = || { | ||
assert!(!handle.handle().is_null()); | ||
}; | ||
closure(); | ||
} | ||
|
||
#[test] | ||
fn test_handle_debug_format() { | ||
let handle = BlasHandle::new().unwrap(); | ||
let debug_str = format!("{:?}", handle); | ||
assert!(!debug_str.is_empty(), "Debug formatting failed"); | ||
println!("Debug format of BlasHandle: {}", debug_str); | ||
} | ||
|
||
#[test] | ||
fn test_handle_memory_stress() { | ||
// Create and destroy multiple handles in a loop | ||
for _ in 0..100 { | ||
let handle = BlasHandle::new().unwrap(); | ||
assert!(!handle.handle().is_null()); | ||
drop(handle); | ||
} | ||
} | ||
|
||
#[test] | ||
fn test_handle_null_check() { | ||
let handle = BlasHandle::new().unwrap(); | ||
assert!(!handle.handle().is_null(), "Handle should not be null"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
mod handle; | ||
mod types; | ||
|
||
use crate::sys; | ||
|
||
pub use handle::*; | ||
pub use types::*; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
use crate::sys; | ||
|
||
#[repr(u32)] | ||
#[derive(Debug, Clone, Copy, PartialEq, Eq)] | ||
pub enum Operation { | ||
None = 0, // HIPBLAS_OP_N | ||
Transpose = 1, // HIPBLAS_OP_T | ||
Conjugate = 2, // HIPBLAS_OP_C | ||
} | ||
|
||
impl From<Operation> for sys::hipblasOperation_t { | ||
fn from(op: Operation) -> Self { | ||
op as sys::hipblasOperation_t | ||
} | ||
} | ||
|
||
#[repr(u32)] | ||
#[derive(Debug, Clone, Copy, PartialEq, Eq)] | ||
pub enum Status { | ||
Success = 0, | ||
Handle = 1, | ||
NotInitialized = 2, | ||
InvalidValue = 3, | ||
ArchMismatch = 4, | ||
MappingError = 5, | ||
ExecutionFailed = 6, | ||
InternalError = 7, | ||
NotSupported = 8, | ||
MemoryError = 9, | ||
AllocationFailed = 10, | ||
} | ||
|
||
impl From<sys::hipblasStatus_t> for Status { | ||
fn from(status: sys::hipblasStatus_t) -> Self { | ||
match status { | ||
0 => Status::Success, | ||
1 => Status::Handle, | ||
2 => Status::NotInitialized, | ||
3 => Status::InvalidValue, | ||
4 => Status::ArchMismatch, | ||
5 => Status::MappingError, | ||
6 => Status::ExecutionFailed, | ||
7 => Status::InternalError, | ||
8 => Status::NotSupported, | ||
9 => Status::MemoryError, | ||
10 => Status::AllocationFailed, | ||
_ => Status::InternalError, | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,7 @@ | ||
#![allow(non_upper_case_globals)] | ||
mod core; | ||
mod hipblas; | ||
mod sys; | ||
|
||
pub use core::*; | ||
pub use hipblas::*; |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
#include <hip/hip_runtime.h> | ||
#include <hipblas/hipblas.h> |