From 8c282e1b09a4c52495cc05817d92906d7cf42ebd Mon Sep 17 00:00:00 2001 From: Matthew Baker Date: Sat, 19 Oct 2024 21:06:51 -0400 Subject: [PATCH] BINDINGS/RUST: Initial Rust bindings --- .gitignore | 4 + bindings/rust/Cargo.toml | 12 ++ bindings/rust/build.rs | 45 ++++++++ bindings/rust/rustfmt.toml | 2 + bindings/rust/src/am.rs | 120 +++++++++++++++++++ bindings/rust/src/context.rs | 201 ++++++++++++++++++++++++++++++++ bindings/rust/src/ep.rs | 110 ++++++++++++++++++ bindings/rust/src/ffi.rs | 6 + bindings/rust/src/lib.rs | 208 +++++++++++++++++++++++++++++++++ bindings/rust/src/worker.rs | 216 +++++++++++++++++++++++++++++++++++ bindings/rust/wrapper.h | 1 + 11 files changed, 925 insertions(+) create mode 100644 bindings/rust/Cargo.toml create mode 100644 bindings/rust/build.rs create mode 100644 bindings/rust/rustfmt.toml create mode 100644 bindings/rust/src/am.rs create mode 100644 bindings/rust/src/context.rs create mode 100644 bindings/rust/src/ep.rs create mode 100644 bindings/rust/src/ffi.rs create mode 100644 bindings/rust/src/lib.rs create mode 100644 bindings/rust/src/worker.rs create mode 100644 bindings/rust/wrapper.h diff --git a/.gitignore b/.gitignore index ef13efcace9..8234f162181 100644 --- a/.gitignore +++ b/.gitignore @@ -118,3 +118,7 @@ go.work go.work.sum bindings/go/pkg/ bindings/go/go.sum +*~ +\#* +bindings/rust/target +bindings/rust/Cargo.lock diff --git a/bindings/rust/Cargo.toml b/bindings/rust/Cargo.toml new file mode 100644 index 00000000000..7a7193eb8b3 --- /dev/null +++ b/bindings/rust/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "ucx-sys" +version = "0.1.0" +edition = "2021" +description = "Low level Rust bindings" +license = "BSD" + +[build-dependencies] +bindgen = "0.70.1" + +[dependencies.bitflags] +version = "2.6.0" diff --git a/bindings/rust/build.rs b/bindings/rust/build.rs new file mode 100644 index 00000000000..5fe0cec84fd --- /dev/null +++ b/bindings/rust/build.rs @@ -0,0 +1,45 @@ +use std::env; +use std::path::PathBuf; + +fn main() { + // Tell cargo to look for shared libraries in the specified directory + println!("cargo:rustc-link-search=../../src/ucp/.libs/"); + + // Tell cargo to tell rustc to link the system bzip2 + // shared library. + println!("cargo:rustc-link-lib=ucp"); + + // The bindgen::Builder is the main entry point + // to bindgen, and lets you build up options for + // the resulting bindings. + let bindings = bindgen::Builder::default() + // Some of the UCX detailed examples in comments can confuse the + // bindgen parser and it will make bad code instead of comments + .generate_comments(false) + // ucs_status_t is defined as a packed enum and that will lead to + // badness without the flag which tells bindgen to repeat that + // trick with the rust enums + .rustified_enum(".*") + .clang_arg("-I../../src/ucp/api/") + .clang_arg("-I../../") + .clang_arg("-I../../src/") + // Annotate ucs_status_t and ucs_status_ptr_t as #[must_use] + .must_use_type("ucs_status_t") + .must_use_type("ucs_status_ptr_t") + // The input header we would like to generate + // bindings for. + .header("wrapper.h") + // Tell cargo to invalidate the built crate whenever any of the + // included header files changed. + .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) + // Finish the builder and generate the bindings. + .generate() + // Unwrap the Result and panic on failure. + .expect("Unable to generate bindings"); + + // Write the bindings to the $OUT_DIR/bindings.rs file. + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + bindings + .write_to_file(out_path.join("bindings.rs")) + .expect("Couldn't write bindings!"); +} diff --git a/bindings/rust/rustfmt.toml b/bindings/rust/rustfmt.toml new file mode 100644 index 00000000000..83f0002f0f1 --- /dev/null +++ b/bindings/rust/rustfmt.toml @@ -0,0 +1,2 @@ +hard_tabs = false +tab_spaces = 4 diff --git a/bindings/rust/src/am.rs b/bindings/rust/src/am.rs new file mode 100644 index 00000000000..ef9508a00f9 --- /dev/null +++ b/bindings/rust/src/am.rs @@ -0,0 +1,120 @@ +use crate::ep::Ep; +use crate::ffi::*; +use crate::status_ptr_to_result; +use crate::status_to_result; +use crate::worker::Worker; +use crate::Request; +use crate::RequestParam; +use bitflags::bitflags; + +type AmRecvCb = unsafe extern "C" fn( + arg: *mut ::std::os::raw::c_void, + header: *const ::std::os::raw::c_void, + header_length: usize, + data: *mut ::std::os::raw::c_void, + length: usize, + param: *const ucp_am_recv_param_t, +) -> ucs_status_t; + +impl Worker<'_> { + #[inline] + pub fn am_register(&self, am_param: &HandlerParams) -> Result<(), ucs_status_t> { + status_to_result(unsafe { ucp_worker_set_am_recv_handler(self.handle, &am_param.handle) }) + } +} + +impl Ep<'_> { + #[inline] + pub fn am_send( + &self, + id: u32, + header: &[u8], + data: &[u8], + params: &RequestParam, + ) -> Result, ucs_status_t> { + status_ptr_to_result(unsafe { + ucp_am_send_nbx( + self.handle, + id, + header.as_ptr() as _, + header.len(), + data.as_ptr() as _, + data.len(), + ¶ms.handle, + ) + }) + } +} + +bitflags! { + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] + pub struct CbFlags: u32 { + const WholeMsg = ucp_am_cb_flags::UCP_AM_FLAG_WHOLE_MSG as u32; + const PersistentData = ucp_am_cb_flags::UCP_AM_FLAG_PERSISTENT_DATA as u32; + } +} + +#[derive(Debug, Clone)] +pub struct HandlerParamsBuilder { + uninit_handle: std::mem::MaybeUninit, + flags: u64, +} + +impl HandlerParamsBuilder { + #[inline] + pub fn new() -> HandlerParamsBuilder { + let uninit_params = std::mem::MaybeUninit::::uninit(); + HandlerParamsBuilder { + uninit_handle: uninit_params, + flags: 0, + } + } + + #[inline] + pub fn id(&mut self, id: u32) -> &mut HandlerParamsBuilder { + self.flags |= ucp_am_handler_param_field::UCP_AM_HANDLER_PARAM_FIELD_ID as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.id = id; + self + } + + #[inline] + pub fn flags(&mut self, flags: CbFlags) -> &mut HandlerParamsBuilder { + self.flags |= ucp_am_handler_param_field::UCP_AM_HANDLER_PARAM_FIELD_FLAGS as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.flags = flags.bits(); + self + } + + #[inline] + pub fn cb(&mut self, cb: AmRecvCb) -> &mut HandlerParamsBuilder { + self.flags |= ucp_am_handler_param_field::UCP_AM_HANDLER_PARAM_FIELD_CB as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.cb = Some(cb); + self + } + + #[inline] + pub fn arg(&mut self, arg: *mut std::os::raw::c_void) -> &mut HandlerParamsBuilder { + self.flags |= ucp_am_handler_param_field::UCP_AM_HANDLER_PARAM_FIELD_ARG as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.arg = arg; + self + } + + #[inline] + pub fn build(&mut self) -> HandlerParams { + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.field_mask = self.flags; + + let handler_param = HandlerParams { + handle: unsafe { self.uninit_handle.assume_init() }, + }; + + handler_param + } +} + +pub struct HandlerParams { + pub(crate) handle: ucp_am_handler_param_t, +} diff --git a/bindings/rust/src/context.rs b/bindings/rust/src/context.rs new file mode 100644 index 00000000000..73730b6d77b --- /dev/null +++ b/bindings/rust/src/context.rs @@ -0,0 +1,201 @@ +use crate::ffi::*; +use crate::status_to_result; +use crate::worker; +use crate::worker::Worker; +use bitflags::bitflags; +use std::ffi::CString; + +type RequestInitCb = unsafe extern "C" fn(request: *mut ::std::os::raw::c_void); +type RequestCleanUpCb = unsafe extern "C" fn(request: *mut ::std::os::raw::c_void); + +bitflags! { + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] + pub struct Flags: u64 { + const Tag = ucp_feature::UCP_FEATURE_TAG as u64; + const Rma = ucp_feature::UCP_FEATURE_RMA as u64; + const Amo32 = ucp_feature::UCP_FEATURE_AMO32 as u64; + const Amo64 = ucp_feature::UCP_FEATURE_AMO64 as u64; + const Wakeup = ucp_feature::UCP_FEATURE_WAKEUP as u64; + const Stream = ucp_feature::UCP_FEATURE_STREAM as u64; + const Am = ucp_feature::UCP_FEATURE_AM as u64; + const ExportedMemH = ucp_feature::UCP_FEATURE_EXPORTED_MEMH as u64; + } +} + +pub struct Config { + handle: *mut ucp_config_t, +} + +impl Config { + pub fn read(name: &str, file: &str) -> Result<*mut ucp_config_t, ucs_status_t> { + let mut config: *mut ucp_config_t = std::ptr::null_mut(); + let c_name = CString::new(name).unwrap(); + let c_file = CString::new(file).unwrap(); + status_to_result(unsafe { ucp_config_read(c_name.as_ptr(), c_file.as_ptr(), &mut config) }) + .unwrap(); + return Ok(config); + } +} + +impl Default for Config { + fn default() -> Self { + let config = Config::read("", "").unwrap(); + Config { handle: config } + } +} + +impl Drop for Config { + fn drop(&mut self) { + unsafe { ucp_config_release(self.handle) }; + } +} + +#[derive(Debug, Clone)] +pub struct ParamsBuilder { + uninit_handle: std::mem::MaybeUninit, + field_mask: u64, + name: Option, +} + +#[derive(Debug, Clone)] +pub struct Params { + handle: ucp_params_t, + name: Option, +} + +// This builder wraps up the unsafe parts of building the ucp_param_t struct. On construction +// it makes a zero filled ucp_params_t which Rust considers uninitialized. Each call on the builder +// will fill in the fields of the struct and add the mask for that field. On the final build() +// it will fill in the final value of the features field_mask and proclame the rest of the struct +// as initialized. This is Rust safe because all of the other fields are guaranteed to not be used +// by the library since the proper feature flag is not set. + +impl ParamsBuilder { + pub fn new() -> ParamsBuilder { + let uninit_params = std::mem::MaybeUninit::::uninit(); + ParamsBuilder { + uninit_handle: uninit_params, + field_mask: 0, + name: None, + } + } + + pub fn features(&mut self, features: Flags) -> &mut ParamsBuilder { + self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_FEATURES as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.features = features.bits(); + self + } + + pub fn request_size(&mut self, size: usize) -> &mut ParamsBuilder { + self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_REQUEST_SIZE as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.request_size = size; + self + } + + pub fn request_init(&mut self, cb: RequestInitCb) -> &mut ParamsBuilder { + self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_REQUEST_INIT as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + + params.request_init = Some(cb); + self + } + + pub fn request_cleanup(&mut self, cb: RequestCleanUpCb) -> &mut ParamsBuilder { + self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_REQUEST_CLEANUP as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.request_cleanup = Some(cb); + self + } + + pub fn tag_sender_mask(&mut self, mask: u64) -> &mut ParamsBuilder { + self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_TAG_SENDER_MASK as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.tag_sender_mask = mask; + self + } + + pub fn mt_workers_shared(&mut self, shared: i32) -> &mut ParamsBuilder { + self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_MT_WORKERS_SHARED as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.mt_workers_shared = shared; + self + } + + pub fn estimated_num_eps(&mut self, num_eps: usize) -> &mut ParamsBuilder { + self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_ESTIMATED_NUM_EPS as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.estimated_num_eps = num_eps; + self + } + + pub fn estimated_num_ppn(&mut self, num_ppn: usize) -> &mut ParamsBuilder { + self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_ESTIMATED_NUM_PPN as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.estimated_num_ppn = num_ppn; + self + } + + pub fn name(&mut self, name: &str) -> &mut ParamsBuilder { + self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_NAME as u64; + let name_cs = CString::new(name).unwrap(); + self.name = Some(name_cs); + self + } + + pub fn build(&mut self) -> Params { + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.field_mask = self.field_mask; + + let mut ucp_param = Params { + name: None, + handle: unsafe { self.uninit_handle.assume_init() }, + }; + + if self.name.is_some() { + let new_name = self.name.clone().unwrap(); + ucp_param.handle.name = new_name.as_ptr(); + ucp_param.name = Some(new_name); + } + + ucp_param + } +} + +impl Context { + pub fn new(config: &Config, params: &Params) -> Result { + let mut context: ucp_context_h = std::ptr::null_mut(); + + let result = status_to_result(unsafe { + ucp_init_version( + UCP_API_MAJOR, + UCP_API_MINOR, + ¶ms.handle, + config.handle, + &mut context, + ) + }); + match result { + Ok(()) => Ok(Context { handle: context }), + Err(ucs_status_t) => Err(ucs_status_t), + } + } + + pub fn worker_create<'a>( + &'a self, + params: &'a worker::Params, + ) -> Result, ucs_status_t> { + Worker::new(self, params) + } +} + +pub struct Context { + pub(crate) handle: ucp_context_h, +} + +impl Drop for Context { + fn drop(&mut self) { + unsafe { ucp_cleanup(self.handle) }; + } +} diff --git a/bindings/rust/src/ep.rs b/bindings/rust/src/ep.rs new file mode 100644 index 00000000000..6c5b3bdf00b --- /dev/null +++ b/bindings/rust/src/ep.rs @@ -0,0 +1,110 @@ +use crate::ffi::*; +use crate::status_ptr_to_result; +use crate::status_to_result; +use crate::worker::RemoteWorkerAddress; +use crate::worker::Worker; +use crate::worker::WorkerAddress; +use bitflags::bitflags; +use std::ffi::CString; +use std::ptr::NonNull; + +pub struct Ep<'a> { + pub(crate) handle: ucp_ep_h, + worker: &'a Worker<'a>, +} + +impl Ep<'_> { + pub fn new<'a>(ep_params: &Params, worker: &'a Worker<'a>) -> Result, ucs_status_t> { + let mut ep: ucp_ep_h = std::ptr::null_mut(); + let result = + status_to_result(unsafe { ucp_ep_create(worker.handle, &ep_params.handle, &mut ep) }); + match result { + Ok(()) => Ok(Ep { + handle: ep, + worker: worker, + }), + Err(ucs_status_t) => Err(ucs_status_t), + } + } +} + +impl Drop for Ep<'_> { + fn drop(&mut self) { + let param: ucp_request_param_t = unsafe { std::mem::zeroed() }; + let result = + status_ptr_to_result(unsafe { ucp_ep_close_nbx(self.handle, ¶m) }).unwrap(); + if result.is_some() { + unsafe { ucp_request_free(result.unwrap().handle.as_mut()) }; + } + } +} + +bitflags! { + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] + pub struct UcpEpFields: u64 { + const None = ucp_err_handling_mode_t::UCP_ERR_HANDLING_MODE_NONE as u64; + const Peer = ucp_err_handling_mode_t::UCP_ERR_HANDLING_MODE_PEER as u64; + } +} + +bitflags! { + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] + pub struct ParamsFlags: u64 { + const ClientServer = ucp_ep_params_flags_field::UCP_EP_PARAMS_FLAGS_CLIENT_SERVER as u64; + const NoLoopback = ucp_ep_params_flags_field::UCP_EP_PARAMS_FLAGS_NO_LOOPBACK as u64; + const SendClientId = ucp_ep_params_flags_field::UCP_EP_PARAMS_FLAGS_SEND_CLIENT_ID as u64; + } +} + +#[derive(Debug, Clone)] +pub struct Params { + pub(crate) handle: ucp_ep_params_t, + name: Option, +} + +#[derive(Debug, Clone)] +pub struct ParamsBuilder { + uninit_handle: std::mem::MaybeUninit, + field_mask: u64, + name: Option, +} + +impl ParamsBuilder { + pub fn new() -> ParamsBuilder { + let uninit_params = std::mem::MaybeUninit::::uninit(); + ParamsBuilder { + uninit_handle: uninit_params, + field_mask: 0, + name: None, + } + } + + pub fn local_address(&mut self, worker_address: &WorkerAddress) -> &mut ParamsBuilder { + self.field_mask |= ucp_ep_params_field::UCP_EP_PARAM_FIELD_REMOTE_ADDRESS as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.address = worker_address.handle; + self + } + + pub fn name(&mut self, name: &str) -> &mut ParamsBuilder { + self.field_mask |= ucp_ep_params_field::UCP_EP_PARAM_FIELD_NAME as u64; + let name_cs = CString::new(name).unwrap(); + self.name = Some(name_cs); + self + } + + pub fn build(&mut self) -> Params { + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.field_mask = self.field_mask; + let mut ep_param = Params { + handle: unsafe { self.uninit_handle.assume_init() }, + name: None, + }; + if self.name.is_some() { + let new_name = self.name.clone().unwrap(); + ep_param.handle.name = new_name.as_ptr(); + ep_param.name = Some(new_name); + } + ep_param + } +} diff --git a/bindings/rust/src/ffi.rs b/bindings/rust/src/ffi.rs new file mode 100644 index 00000000000..cd503e4b5d7 --- /dev/null +++ b/bindings/rust/src/ffi.rs @@ -0,0 +1,6 @@ +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] +#![allow(dead_code)] + +include!(concat!(env!("OUT_DIR"), "/bindings.rs")); diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs new file mode 100644 index 00000000000..905be4dfcfe --- /dev/null +++ b/bindings/rust/src/lib.rs @@ -0,0 +1,208 @@ +#![allow(unused_imports)] + +mod ffi; +use crate::ffi::*; + +pub mod am; +pub mod context; +pub mod ep; +pub mod worker; + +use std::ffi::CString; +use std::ptr::NonNull; + +// UCX request backed by a ucs_status_ptr_t that is non-null and not an error, thus is a request pointer +pub struct Request { + pub(crate) handle: NonNull<::std::os::raw::c_void>, +} + +impl Drop for Request { + fn drop(&mut self) { + unsafe { ucp_request_free(self.handle.as_ptr()) }; + } +} + +impl Request { + // new assumes that the type has already been error checked. + #[inline] + pub fn new(request_handle: *mut std::os::raw::c_void) -> Option { + let request = NonNull::<::std::os::raw::c_void>::new(request_handle); + match request { + None => None, + Some(x) => Some(Request { handle: x }), + } + } + + // check an outstanding request. Returns an error if the request had an error, returns false if the request is not completed, returns true if the request is completed + #[inline] + pub fn check_status(&self) -> Result { + let status = unsafe { ucp_request_check_status(self.handle.as_ptr()) }; + if status as usize >= ucs_status_t::UCS_ERR_LAST as usize { + return Err(unsafe { std::mem::transmute(status as i8) }); + } + Ok(status == ucs_status_t::UCS_OK) + } +} + +// In UCX we usually use a ucs_status_ptr_t to represent the status of a nonblocking operation +// in this the possible outcomes can be UCS_OK, where the application can reuse all the input +// parameters immediately, a pointer that can be queried for the status of the underlying +// nonblocking operation, or an error. Rust APIs operate similarly, except it uses the Rust +// type system to express this. It will have a Result type that either contains an Ok() type +// or an Err() type. It also has an Option() type that basically is the equivalent of a nullable +// pointer, except Rust will force the user to be sure to check the Option(). + +// This helper function will automatically translate the ucs_status_ptr_t into a Result that +// either is an empty Ok() as the equivilent to UCS_OK, a Ok(Request) that represents getting +// back a pointer or an Err(ucs_status_t) that indicates an error. Compile test shows that this +// produces extremely efficient assembly + +#[inline] +pub fn status_ptr_to_result(ptr: ucs_status_ptr_t) -> Result, ucs_status_t> { + // This is equivlent to the UCS_PTR_IS_ERR() macro. + if ptr as usize >= ucs_status_t::UCS_ERR_LAST as usize { + // The transmute() function is how you access C style memory magic. This function will + // take the intput pointer and then translate it into i8 and then rust will turn the i8 + // into the proper ucs_status_t. + return Err(unsafe { std::mem::transmute(ptr as i8) }); + } + Ok(Request::new(ptr)) +} + +#[inline] +pub fn status_to_result(status: ucs_status_t) -> Result<(), ucs_status_t> { + if (status as i8) < 0 { + return Err(status); + } + Ok(()) +} + +pub struct RequestParam { + pub(crate) handle: ucp_request_param_t, +} + +#[derive(Debug, Copy, Clone)] +pub struct RequestParamBuilder { + uninit_handle: std::mem::MaybeUninit, + field_mask: u32, +} + +impl RequestParamBuilder { + pub fn new() -> RequestParamBuilder { + let uninit_params = std::mem::MaybeUninit::::uninit(); + RequestParamBuilder { + uninit_handle: uninit_params, + field_mask: 0, + } + } + + pub fn force_imm_cmpl(&mut self) -> &mut RequestParamBuilder { + if self.field_mask & ucp_op_attr_t::UCP_OP_ATTR_FLAG_NO_IMM_CMPL as u32 != 0 { + panic!("Requesting UCP_OP_ATTR_FLAG_FORCE_IMM_CMPL while UCP_OP_ATTR_FLAG_NO_IMM_CMPL is also set"); + } + self.field_mask |= ucp_op_attr_t::UCP_OP_ATTR_FLAG_FORCE_IMM_CMPL as u32; + self + } + + pub fn no_imm_cmpl(&mut self) -> &mut RequestParamBuilder { + if self.field_mask & ucp_op_attr_t::UCP_OP_ATTR_FLAG_FORCE_IMM_CMPL as u32 != 0 { + panic!("Requesting UCP_OP_ATTR_FLAG_NO_IMM_CMPL while UCP_OP_ATTR_FLAG_FORCE_IMM_CMPL is also set"); + } + self.field_mask |= ucp_op_attr_t::UCP_OP_ATTR_FLAG_NO_IMM_CMPL as u32; + self + } + + pub fn build(&mut self) -> RequestParam { + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.op_attr_mask = self.field_mask; + + let ucp_param = RequestParam { + handle: unsafe { self.uninit_handle.assume_init() }, + }; + + ucp_param + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::context; + use crate::context::Context; + use crate::ep; + use crate::worker; + + const TEST_AM_ID: u32 = 5; + + extern "C" fn init(_request: *mut ::std::os::raw::c_void) {} + + extern "C" fn cleanup(_request: *mut ::std::os::raw::c_void) {} + + unsafe extern "C" fn am_cb( + arg: *mut ::std::os::raw::c_void, + header: *const ::std::os::raw::c_void, + header_length: usize, + _data: *mut ::std::os::raw::c_void, + _length: usize, + _param: *const ucp_am_recv_param_t, + ) -> ucs_status_t { + let message = std::slice::from_raw_parts_mut(arg as *mut i8, 1); + let in_data = std::slice::from_raw_parts(header as *const i8, header_length); + message[0] = in_data[0]; + ucs_status_t::UCS_OK + } + + #[test] + fn it_works() { + let mut message = vec![0]; + let features = context::Flags::Am + | context::Flags::Rma + | context::Flags::Amo32 + | context::Flags::Amo64; + let params = context::ParamsBuilder::new() + .features(features) + .mt_workers_shared(1) + .request_init(init) + .request_cleanup(cleanup) + .request_size(8) + .name("My Awesome Test") + .tag_sender_mask(std::u64::MAX) + .estimated_num_eps(4) + .estimated_num_ppn(2) + .build(); + let context = Context::new(&context::Config::default(), ¶ms).unwrap(); + + let worker_features = worker::ParamsBuilder::new() + .thread_mode(ucs_thread_mode_t::UCS_THREAD_MODE_MULTI) + .build(); + let worker = context.worker_create(&worker_features).unwrap(); + + let am_params = am::HandlerParamsBuilder::new() + .id(TEST_AM_ID) + .cb(am_cb) + .arg(message.as_mut_ptr() as *mut std::os::raw::c_void) + .build(); + worker.am_register(&am_params).unwrap(); + + let addr = worker.pack_address().unwrap(); + let ep_param = ep::ParamsBuilder::new().local_address(&addr).build(); + let ep = worker.create_ep(&ep_param).unwrap(); + + let tag = vec![32]; + let am_flags = RequestParamBuilder::new() + //.force_imm_cmpl() // uncomment this line to see the the compile time error checker in action + .no_imm_cmpl() + .build(); + + let req = ep + .am_send(TEST_AM_ID, tag.as_slice(), b"", &am_flags) + .unwrap(); + if req.is_some() { + let req = req.unwrap(); + while !req.check_status().unwrap() { + worker.progress(); + } + } + assert_eq!(message[0], tag[0]); + } +} diff --git a/bindings/rust/src/worker.rs b/bindings/rust/src/worker.rs new file mode 100644 index 00000000000..b44df2b00f8 --- /dev/null +++ b/bindings/rust/src/worker.rs @@ -0,0 +1,216 @@ +use crate::context::Context; +use crate::ep; +use crate::ep::Ep; +use crate::ffi::*; +use crate::status_to_result; +use bitflags::bitflags; +use std::ffi::CString; +use std::ptr::NonNull; + +pub struct Worker<'a> { + pub(crate) handle: ucp_worker_h, + #[allow(dead_code)] + parent: &'a Context, +} + +impl Drop for Worker<'_> { + fn drop(&mut self) { + unsafe { ucp_worker_destroy(self.handle) }; + } +} + +impl Worker<'_> { + pub(crate) fn new<'a>( + context: &'a Context, + params: &'a Params, + ) -> Result, ucs_status_t> { + let mut worker: ucp_worker_h = std::ptr::null_mut(); + + let result = status_to_result(unsafe { + ucp_worker_create(context.handle, ¶ms.handle, &mut worker) + }); + match result { + Ok(()) => Ok(Worker { + handle: worker, + parent: context, + }), + Err(ucs_status_t) => Err(ucs_status_t), + } + } + + pub fn pack_address(&self) -> Result { + let mut address: *mut ucp_address_t = std::ptr::null_mut(); + let mut size: usize = 0; + + let result = status_to_result(unsafe { + ucp_worker_get_address(self.handle, &mut address, &mut size) + }); + match result { + Ok(()) => Ok(WorkerAddress { + handle: address, + parent: self, + size: size, + }), + Err(ucs_status_t) => Err(ucs_status_t), + } + } + + #[inline] + pub fn progress(&self) -> bool { + let progress = unsafe { ucp_worker_progress(self.handle) }; + progress > 0 + } + + pub fn create_ep(&self, ep_params: &ep::Params) -> Result { + return Ep::new(&ep_params, &self); + } +} + +pub struct RemoteWorkerAddress { + address: Vec, +} + +impl RemoteWorkerAddress { + pub fn new(address: Vec) -> RemoteWorkerAddress { + RemoteWorkerAddress { address: address } + } + + pub fn get_handle(&self) -> (*const ucp_address_t, usize) { + ( + self.address.as_ptr() as *const ucp_address_t, + self.address.len(), + ) + } +} + +pub struct WorkerAddress<'a> { + pub(crate) handle: *const ucp_address_t, + size: usize, + parent: &'a Worker<'a>, +} + +impl WorkerAddress<'_> { + pub fn to_bytes(&self) -> Vec { + unsafe { std::slice::from_raw_parts(self.handle as *const u8, self.size) }.to_vec() + } +} + +impl Drop for WorkerAddress<'_> { + fn drop(&mut self) { + unsafe { + ucp_worker_release_address(self.parent.handle, self.handle as *mut ucp_address_t) + }; + } +} + +bitflags! { + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] + pub struct UcpWorkerFlags: u64 { + const Flags = ucp_worker_flags_t::UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK as u64; + } +} + +impl ParamsBuilder { + pub fn new() -> ParamsBuilder { + let uninit_params = std::mem::MaybeUninit::::uninit(); + ParamsBuilder { + uninit_handle: uninit_params, + field_mask: 0, + name: None, + } + } + + pub fn thread_mode(&mut self, thread_mode: ucs_thread_mode_t) -> &mut ParamsBuilder { + self.field_mask |= ucp_worker_params_field::UCP_WORKER_PARAM_FIELD_THREAD_MODE as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.thread_mode = thread_mode; + self + } + + pub fn cpu_set(&mut self, cpu_set: ucs_cpu_set_t) -> &mut ParamsBuilder { + self.field_mask |= ucp_worker_params_field::UCP_WORKER_PARAM_FIELD_CPU_MASK as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.cpu_mask = cpu_set; + self + } + + pub fn events(&mut self, events: u32) -> &mut ParamsBuilder { + self.field_mask |= ucp_worker_params_field::UCP_WORKER_PARAM_FIELD_EVENTS as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.events = events; + self + } + + pub fn user_data(&mut self, data: *mut std::ffi::c_void) -> &mut ParamsBuilder { + self.field_mask |= ucp_worker_params_field::UCP_WORKER_PARAM_FIELD_USER_DATA as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.user_data = data; + self + } + + pub fn event_fd(&mut self, event_fd: i32) -> &mut ParamsBuilder { + self.field_mask |= ucp_worker_params_field::UCP_WORKER_PARAM_FIELD_EVENT_FD as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.event_fd = event_fd; + self + } + + pub fn flags(&mut self, flags: UcpWorkerFlags) -> &mut ParamsBuilder { + self.field_mask |= ucp_worker_params_field::UCP_WORKER_PARAM_FIELD_FLAGS as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.flags = flags.bits(); + self + } + + pub fn name(&mut self, name: &str) -> &mut ParamsBuilder { + self.field_mask |= ucp_worker_params_field::UCP_WORKER_PARAM_FIELD_NAME as u64; + let name_cs = CString::new(name).unwrap(); + self.name = Some(name_cs); + self + } + + pub fn am_alignment(&mut self, am_alignment: usize) -> &mut ParamsBuilder { + self.field_mask |= ucp_worker_params_field::UCP_WORKER_PARAM_FIELD_AM_ALIGNMENT as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.am_alignment = am_alignment; + self + } + + pub fn client_id(&mut self, client_id: u64) -> &mut ParamsBuilder { + self.field_mask |= ucp_worker_params_field::UCP_WORKER_PARAM_FIELD_CLIENT_ID as u64; + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.client_id = client_id; + self + } + + pub fn build(&mut self) -> Params { + let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() }; + params.field_mask = self.field_mask; + + let mut ucp_param = Params { + name: None, + handle: unsafe { self.uninit_handle.assume_init() }, + }; + + if self.name.is_some() { + let new_name = self.name.clone().unwrap(); + ucp_param.handle.name = new_name.as_ptr(); + ucp_param.name = Some(new_name); + } + + ucp_param + } +} + +#[derive(Debug, Clone)] +pub struct ParamsBuilder { + uninit_handle: std::mem::MaybeUninit, + field_mask: u64, + name: Option, +} + +#[derive(Debug, Clone)] +pub struct Params { + pub(crate) handle: ucp_worker_params_t, + name: Option, +} diff --git a/bindings/rust/wrapper.h b/bindings/rust/wrapper.h new file mode 100644 index 00000000000..de21508129c --- /dev/null +++ b/bindings/rust/wrapper.h @@ -0,0 +1 @@ +#include