Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BINDINGS/RUST: Initial Rust bindings #10291

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,7 @@ go.work
go.work.sum
bindings/go/pkg/
bindings/go/go.sum
*~
\#*
bindings/rust/target
bindings/rust/Cargo.lock
12 changes: 12 additions & 0 deletions bindings/rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "ucx-sys"
version = "0.1.0"
edition = "2021"
description = "Low level Rust bindings"
license = "BSD"

[build-dependencies]
bindgen = "0.70.1"

[dependencies.bitflags]
version = "2.6.0"
45 changes: 45 additions & 0 deletions bindings/rust/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use std::env;
use std::path::PathBuf;

fn main() {
// Tell cargo to look for shared libraries in the specified directory
println!("cargo:rustc-link-search=../../src/ucp/.libs/");

// Tell cargo to tell rustc to link the system bzip2
// shared library.
println!("cargo:rustc-link-lib=ucp");

// The bindgen::Builder is the main entry point
// to bindgen, and lets you build up options for
// the resulting bindings.
let bindings = bindgen::Builder::default()
// Some of the UCX detailed examples in comments can confuse the
// bindgen parser and it will make bad code instead of comments
.generate_comments(false)
// ucs_status_t is defined as a packed enum and that will lead to
// badness without the flag which tells bindgen to repeat that
// trick with the rust enums
.rustified_enum(".*")
.clang_arg("-I../../src/ucp/api/")
.clang_arg("-I../../")
.clang_arg("-I../../src/")
// Annotate ucs_status_t and ucs_status_ptr_t as #[must_use]
.must_use_type("ucs_status_t")
.must_use_type("ucs_status_ptr_t")
// The input header we would like to generate
// bindings for.
.header("wrapper.h")
// Tell cargo to invalidate the built crate whenever any of the
// included header files changed.
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
// Finish the builder and generate the bindings.
.generate()
// Unwrap the Result and panic on failure.
.expect("Unable to generate bindings");

// Write the bindings to the $OUT_DIR/bindings.rs file.
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
bindings
.write_to_file(out_path.join("bindings.rs"))
.expect("Couldn't write bindings!");
}
2 changes: 2 additions & 0 deletions bindings/rust/rustfmt.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
hard_tabs = false
tab_spaces = 4
120 changes: 120 additions & 0 deletions bindings/rust/src/am.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use crate::ep::Ep;
use crate::ffi::*;
use crate::status_ptr_to_result;
use crate::status_to_result;
use crate::worker::Worker;
use crate::Request;
use crate::RequestParam;
use bitflags::bitflags;

type AmRecvCb = unsafe extern "C" fn(
arg: *mut ::std::os::raw::c_void,
header: *const ::std::os::raw::c_void,
header_length: usize,
data: *mut ::std::os::raw::c_void,
length: usize,
param: *const ucp_am_recv_param_t,
) -> ucs_status_t;

impl Worker<'_> {
#[inline]
pub fn am_register(&self, am_param: &HandlerParams) -> Result<(), ucs_status_t> {
status_to_result(unsafe { ucp_worker_set_am_recv_handler(self.handle, &am_param.handle) })
}
}

impl Ep<'_> {
#[inline]
pub fn am_send(
&self,
id: u32,
header: &[u8],
data: &[u8],
params: &RequestParam,
) -> Result<Option<Request>, ucs_status_t> {
status_ptr_to_result(unsafe {
ucp_am_send_nbx(
self.handle,
id,
header.as_ptr() as _,
header.len(),
data.as_ptr() as _,
data.len(),
&params.handle,
)
})
}
}

bitflags! {
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CbFlags: u32 {
const WholeMsg = ucp_am_cb_flags::UCP_AM_FLAG_WHOLE_MSG as u32;
const PersistentData = ucp_am_cb_flags::UCP_AM_FLAG_PERSISTENT_DATA as u32;
}
}

#[derive(Debug, Clone)]
pub struct HandlerParamsBuilder {
uninit_handle: std::mem::MaybeUninit<ucp_am_handler_param_t>,
flags: u64,
}

impl HandlerParamsBuilder {
#[inline]
pub fn new() -> HandlerParamsBuilder {
let uninit_params = std::mem::MaybeUninit::<ucp_am_handler_param_t>::uninit();
HandlerParamsBuilder {
uninit_handle: uninit_params,
flags: 0,
}
}

#[inline]
pub fn id(&mut self, id: u32) -> &mut HandlerParamsBuilder {
self.flags |= ucp_am_handler_param_field::UCP_AM_HANDLER_PARAM_FIELD_ID as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.id = id;
self
}

#[inline]
pub fn flags(&mut self, flags: CbFlags) -> &mut HandlerParamsBuilder {
self.flags |= ucp_am_handler_param_field::UCP_AM_HANDLER_PARAM_FIELD_FLAGS as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.flags = flags.bits();
self
}

#[inline]
pub fn cb(&mut self, cb: AmRecvCb) -> &mut HandlerParamsBuilder {
self.flags |= ucp_am_handler_param_field::UCP_AM_HANDLER_PARAM_FIELD_CB as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.cb = Some(cb);
self
}

#[inline]
pub fn arg(&mut self, arg: *mut std::os::raw::c_void) -> &mut HandlerParamsBuilder {
self.flags |= ucp_am_handler_param_field::UCP_AM_HANDLER_PARAM_FIELD_ARG as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.arg = arg;
self
}

#[inline]
pub fn build(&mut self) -> HandlerParams {
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.field_mask = self.flags;

let handler_param = HandlerParams {
handle: unsafe { self.uninit_handle.assume_init() },
};

handler_param
}
}

pub struct HandlerParams {
pub(crate) handle: ucp_am_handler_param_t,
}
201 changes: 201 additions & 0 deletions bindings/rust/src/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
use crate::ffi::*;
use crate::status_to_result;
use crate::worker;
use crate::worker::Worker;
use bitflags::bitflags;
use std::ffi::CString;

type RequestInitCb = unsafe extern "C" fn(request: *mut ::std::os::raw::c_void);
type RequestCleanUpCb = unsafe extern "C" fn(request: *mut ::std::os::raw::c_void);

bitflags! {
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Flags: u64 {
const Tag = ucp_feature::UCP_FEATURE_TAG as u64;
const Rma = ucp_feature::UCP_FEATURE_RMA as u64;
const Amo32 = ucp_feature::UCP_FEATURE_AMO32 as u64;
const Amo64 = ucp_feature::UCP_FEATURE_AMO64 as u64;
const Wakeup = ucp_feature::UCP_FEATURE_WAKEUP as u64;
const Stream = ucp_feature::UCP_FEATURE_STREAM as u64;
const Am = ucp_feature::UCP_FEATURE_AM as u64;
const ExportedMemH = ucp_feature::UCP_FEATURE_EXPORTED_MEMH as u64;
}
}

pub struct Config {
handle: *mut ucp_config_t,
}

impl Config {
pub fn read(name: &str, file: &str) -> Result<*mut ucp_config_t, ucs_status_t> {
let mut config: *mut ucp_config_t = std::ptr::null_mut();
let c_name = CString::new(name).unwrap();
let c_file = CString::new(file).unwrap();
status_to_result(unsafe { ucp_config_read(c_name.as_ptr(), c_file.as_ptr(), &mut config) })
.unwrap();
return Ok(config);
}
}

impl Default for Config {
fn default() -> Self {
let config = Config::read("", "").unwrap();
Config { handle: config }
}
}

impl Drop for Config {
fn drop(&mut self) {
unsafe { ucp_config_release(self.handle) };
}
}

#[derive(Debug, Clone)]
pub struct ParamsBuilder {
uninit_handle: std::mem::MaybeUninit<ucp_params_t>,
field_mask: u64,
name: Option<CString>,
}

#[derive(Debug, Clone)]
pub struct Params {
handle: ucp_params_t,
name: Option<CString>,
}

// This builder wraps up the unsafe parts of building the ucp_param_t struct. On construction
// it makes a zero filled ucp_params_t which Rust considers uninitialized. Each call on the builder
// will fill in the fields of the struct and add the mask for that field. On the final build()
// it will fill in the final value of the features field_mask and proclame the rest of the struct
// as initialized. This is Rust safe because all of the other fields are guaranteed to not be used
// by the library since the proper feature flag is not set.

impl ParamsBuilder {
pub fn new() -> ParamsBuilder {
let uninit_params = std::mem::MaybeUninit::<ucp_params_t>::uninit();
ParamsBuilder {
uninit_handle: uninit_params,
field_mask: 0,
name: None,
}
}

pub fn features(&mut self, features: Flags) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_FEATURES as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.features = features.bits();
self
}

pub fn request_size(&mut self, size: usize) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_REQUEST_SIZE as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.request_size = size;
self
}

pub fn request_init(&mut self, cb: RequestInitCb) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_REQUEST_INIT as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };

params.request_init = Some(cb);
self
}

pub fn request_cleanup(&mut self, cb: RequestCleanUpCb) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_REQUEST_CLEANUP as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.request_cleanup = Some(cb);
self
}

pub fn tag_sender_mask(&mut self, mask: u64) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_TAG_SENDER_MASK as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.tag_sender_mask = mask;
self
}

pub fn mt_workers_shared(&mut self, shared: i32) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_MT_WORKERS_SHARED as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.mt_workers_shared = shared;
self
}

pub fn estimated_num_eps(&mut self, num_eps: usize) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_ESTIMATED_NUM_EPS as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.estimated_num_eps = num_eps;
self
}

pub fn estimated_num_ppn(&mut self, num_ppn: usize) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_ESTIMATED_NUM_PPN as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.estimated_num_ppn = num_ppn;
self
}

pub fn name(&mut self, name: &str) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_NAME as u64;
let name_cs = CString::new(name).unwrap();
self.name = Some(name_cs);
self
}

pub fn build(&mut self) -> Params {
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.field_mask = self.field_mask;

let mut ucp_param = Params {
name: None,
handle: unsafe { self.uninit_handle.assume_init() },
};

if self.name.is_some() {
let new_name = self.name.clone().unwrap();
ucp_param.handle.name = new_name.as_ptr();
ucp_param.name = Some(new_name);
}

ucp_param
}
}

impl Context {
pub fn new(config: &Config, params: &Params) -> Result<Context, ucs_status_t> {
let mut context: ucp_context_h = std::ptr::null_mut();

let result = status_to_result(unsafe {
ucp_init_version(
UCP_API_MAJOR,
UCP_API_MINOR,
&params.handle,
config.handle,
&mut context,
)
});
match result {
Ok(()) => Ok(Context { handle: context }),
Err(ucs_status_t) => Err(ucs_status_t),
}
}

pub fn worker_create<'a>(
&'a self,
params: &'a worker::Params,
) -> Result<Worker<'a>, ucs_status_t> {
Worker::new(self, params)
}
}

pub struct Context {
pub(crate) handle: ucp_context_h,
}

impl Drop for Context {
fn drop(&mut self) {
unsafe { ucp_cleanup(self.handle) };
}
}
Loading
Loading