Skip to content

Commit

Permalink
Feat/introducing the #[summon()] macro (#8)
Browse files Browse the repository at this point in the history
* feat: Introducing the #[summon()] macro

* fix: Implement lazy mutexes to remove 4 instances of unsafe

* feat: Implement atomic bools to ensure rivets::finalize!() is called correctly.

* chore: cargo fmt

* refactor: rename summon macro to import macro
  • Loading branch information
notnotmelon authored Aug 19, 2024
1 parent 9d3b7d4 commit d5060bb
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 35 deletions.
203 changes: 168 additions & 35 deletions rivets-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@ use lazy_regex::regex;
use proc_macro::{self, Diagnostic, Level, Span, TokenStream};
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use std::sync::{atomic::AtomicBool, LazyLock, Mutex};
use syn::{parse_macro_input, Abi, DeriveInput, Error, Expr, FnArg, Ident, ItemFn, Variant};

// Set to `true` by `finalize`; `check_finalized!` panics once it is set.
static IS_FINALIZED: AtomicBool = AtomicBool::new(false);
// `(mangled C++ name, Rust identifier)` pairs recorded by `detour` and read back by `get_hooks`.
static MANGLED_NAMES: LazyLock<Mutex<Vec<(String, String)>>> = LazyLock::new(|| Mutex::new(vec![]));
// `(mangled C++ name, Rust identifier)` pairs recorded by `import` and read back by `get_imports`.
static CPP_IMPORTS: LazyLock<Mutex<Vec<(String, String)>>> = LazyLock::new(|| Mutex::new(vec![]));

macro_rules! derive_error {
($string: tt) => {
Error::new(proc_macro2::Span::call_site(), $string)
Expand All @@ -19,6 +24,14 @@ macro_rules! derive_error {
};
}

// Guard invoked at the start of `detour`, `import` and `finalize`.
// Panics if `IS_FINALIZED` has already been set, i.e. a rivets macro is being
// expanded after `finalize` has run.
macro_rules! check_finalized {
    () => {
        assert!(
            !IS_FINALIZED.load(std::sync::atomic::Ordering::Relaxed),
            "The rivets library has already been finalized!"
        );
    };
}

fn failure(callback: proc_macro2::TokenStream, error_message: &str) -> TokenStream {
Diagnostic::spanned(Span::call_site(), Level::Error, error_message).emit();
callback.into()
Expand All @@ -42,8 +55,6 @@ fn determine_calling_convention(input: &ItemFn, unmangled_name: &str) -> Result<
}
}

static mut MANGLED_NAMES: Vec<(String, String)> = vec![];

/// A procedural macro for detouring a C++ compiled function.
///
/// The argument to the macro is the mangled name of the C++ function to detour.
Expand All @@ -58,6 +69,8 @@ static mut MANGLED_NAMES: Vec<(String, String)> = vec![];
///
/// This macro cannot hook into the middle of a C++ function. It can only hook into the beginning or end of a function.
///
/// This macro cannot hook into a function that has been inlined by the compiler. Prominent examples of this include `lua_gettop`.
///
/// Exposes an `unsafe` `back` function that can be called in order to resume control flow to the original C++ function.
///
/// Internally uses the `retour` crate to create a static detour for the function and thus inherits the safety guarantees of that crate.
Expand Down Expand Up @@ -88,6 +101,8 @@ static mut MANGLED_NAMES: Vec<(String, String)> = vec![];
/// See the `pdb2hpp` module for a tool that can generate the correct FFI types for C++ functions.
#[proc_macro_attribute]
pub fn detour(attr: TokenStream, item: TokenStream) -> TokenStream {
check_finalized!();

let mangled_name = attr.to_string();
let unmangled_name =
rivets_shared::demangle(&mangled_name).unwrap_or_else(|| mangled_name.clone());
Expand Down Expand Up @@ -144,61 +159,179 @@ pub fn detour(attr: TokenStream, item: TokenStream) -> TokenStream {
#callback

pub unsafe fn hook(address: u64) -> Result<(), rivets::retour::Error> {
let compiled_function: #cpp_function_header = std::mem::transmute(address);
let compiled_function: #cpp_function_header = std::mem::transmute(address); // todo: rust documentation recommends casting this to a raw function pointer. address as *const _
Detour.initialize(compiled_function, #name)?.enable()?;
Ok(())
}
}
};

unsafe {
MANGLED_NAMES.push((mangled_name.clone(), format!("{name}")));
}
MANGLED_NAMES
.lock()
.expect("Failed to lock mangled names")
.push((mangled_name.clone(), name.to_string()));

Diagnostic::spanned(Span::call_site(), Level::Note, unmangled_name.clone()).emit();

result.into()
}

/// A procedural macro for importing a C++ compiled function into the rust scope.
/// This macro is useful in the case where you need to directly call any C++ function from rust.
///
/// The annotated function is replaced by a `static mut` of type
/// `rivets::UnsafeSummonedFunction<unsafe extern ... fn(...)>` with the same name,
/// visibility and attributes. The function body is discarded.
///
/// # Arguments
/// * `mangled_name` - The mangled name of the C++ function to import.
/// * `dll_name` (optional) - Argument for the name of the DLL to import the function from. If not provided, factorio.exe will be used.
///
/// NOTE(review): the implementation below uses the entire attribute token stream as the
/// mangled name; `dll_name` does not appear to be parsed yet. Confirm before relying on it.
///
/// Note that most Factorio libraries (such as allegro and lua) are statically linked. In this case, the `dll_name` argument is not needed.
///
/// # Examples
/// ```
/// // Imports the lua_gettop function from the compiled lua library.
/// // lua_gettop is compiled without name mangling, so calling convention (in this case, extern "C") must be manually provided.
/// #[import(lua_gettop)]
/// extern "C" fn lua_gettop(lua_state: *mut luastate::lua_State) -> i64 {}
///
/// // Calls the lua_gettop function with correct arguments.
/// fn my_func(lua_state: *mut luastate::lua_State) {
///     let top = unsafe { lua_gettop(lua_state) };
///     println!("Lua stack top: {top}");
/// }
/// ```
///
/// # Panics
/// Panics during macro expansion if `finalize!` has already been expanded, or if the
/// internal import registry mutex is poisoned.
///
/// # Safety
/// The arguments and return type of the imported function must be exactly matching FFI types.
/// All structs, classes, enums, and union arguments must have a corresponding `#[repr(C)]` attribute and must also have the correct offsets and sizes.
/// Alternatively, the user can use the `rivets::Opaque` type to represent any arbitrary FFI data if you do not intend to interact with the data.
/// See the `pdb2hpp` module for a tool that can generate the correct FFI types for C++ functions.
///
/// The user must also ensure that the calling convention is correct.
/// Rivets attempts to automatically parse this information from the mangled name however
/// - If the calling convention is not one of cdecl, stdcall, fastcall, thiscall, or vectorcall, the user must specify the calling convention manually.
/// - If the calling convention is not present in the mangled name, the user must specify the calling convention manually.
/// - In rare cases the function may use a non-standard calling convention. In this case, the user must manually populate the required stack and registers via inline assembly.
///
/// Calling any imported function represents calling into the C++ compiled codebase and thus is inherently unsafe.
#[proc_macro_attribute]
pub fn import(attr: TokenStream, item: TokenStream) -> TokenStream {
    check_finalized!();

    // The whole attribute is taken verbatim as the mangled symbol name.
    let mangled_name = attr.to_string();
    // Fall back to the raw name when the symbol cannot be demangled (e.g. C symbols).
    let unmangled_name =
        rivets_shared::demangle(&mangled_name).unwrap_or_else(|| mangled_name.clone());

    let input = parse_macro_input!(item as ItemFn);

    let calling_convention = match determine_calling_convention(&input, &unmangled_name) {
        Ok(calling_convention) => calling_convention,
        // Re-emit the untouched item so the user only sees the diagnostic.
        Err(e) => return failure(quote! { #input }, &e.to_string()),
    };

    // Only the parameter *types* are needed to build the function-pointer type.
    let arg_types = input.sig.inputs.iter().map(|arg| match arg {
        FnArg::Receiver(_) => {
            quote! {compile_error!("Summoned functions cannot use the self parameter.")}
        }
        FnArg::Typed(pat) => {
            let ty = &pat.ty;
            quote! { #ty }
        }
    });

    let return_type = &input.sig.output;
    let vis = &input.vis;
    let attrs = &input.attrs;
    let name = &input.sig.ident;

    // Attributes and visibility are deliberately NOT part of the type: they are not
    // valid in type position and are applied to the generated static instead.
    let function_type = quote! { unsafe #calling_convention fn(#(#arg_types),*) #return_type };

    // Register the import so that `finalize!` can emit the address-resolution code.
    CPP_IMPORTS
        .lock()
        .expect("Failed to lock cpp imports")
        .push((mangled_name, name.to_string()));

    Diagnostic::spanned(Span::call_site(), Level::Note, unmangled_name).emit();

    quote! {
        #(#attrs)*
        #[allow(non_upper_case_globals)]
        #vis static mut #name: rivets::UnsafeSummonedFunction<#function_type> =
            rivets::UnsafeSummonedFunction::Uninitialized;
    }
    .into()
}

/// Builds one `hooks.push(rivets::RivetsHook { .. });` statement for every
/// `(mangled name, module name)` pair currently stored in `MANGLED_NAMES`.
///
/// Panics if the `MANGLED_NAMES` mutex is poisoned.
fn get_hooks() -> Vec<proc_macro2::TokenStream> {
    let registered = MANGLED_NAMES
        .lock()
        .expect("Failed to lock mangled names");

    let mut statements = Vec::new();
    for (mangled_name, module_name) in registered.iter() {
        // The stored name is turned back into an identifier so that
        // `<module>::hook` can be referenced in the generated code.
        let module_ident = Ident::new(module_name, proc_macro2::Span::call_site());
        statements.push(quote! {
            hooks.push(
                rivets::RivetsHook {
                    mangled_name: #mangled_name.into(),
                    hook: #module_ident::hook
                }
            );
        });
    }
    statements
}

/// Builds, for every `(mangled name, rust name)` pair stored in `CPP_IMPORTS`,
/// the statements that look the symbol up through `symbol_cache` and store the
/// resulting function in the static named `rust name`.
///
/// The generated code panics when the symbol cannot be found.
/// This function itself panics if the `CPP_IMPORTS` mutex is poisoned.
fn get_imports() -> Vec<proc_macro2::TokenStream> {
    let registered = CPP_IMPORTS.lock().expect("Failed to lock cpp imports");

    let mut statements = Vec::new();
    for (mangled_name, rust_name) in registered.iter() {
        let static_ident = Ident::new(rust_name, proc_macro2::Span::call_site());
        statements.push(quote! {
            let Some(address) = symbol_cache.get_function_address(base_address, #mangled_name)
            else {
                panic!(
                    "Failed to find address for the following mangled function inside the PDB: {}", #mangled_name
                );
            };
            let function = unsafe {
                std::mem::transmute(address) // todo: rust documentation recommends casting this to a raw function pointer. address as *const _
            };
            unsafe { #static_ident = rivets::UnsafeSummonedFunction::Function(function); }
        });
    }
    statements
}

/// A procedural macro for finalizing the rivets library.
/// This macro should be called once at the end of the `main.rs` file.
/// It will finalize the rivets library and inject all of the detours.
#[proc_macro]
pub fn finalize(_: TokenStream) -> TokenStream {
let injects = unsafe { MANGLED_NAMES.clone() };
let injects = injects.iter().map(|(mangled_name, name)| {
let name = Ident::new(name, proc_macro2::Span::call_site());
quote! {
hooks.push(
rivets::RivetsHook {
mangled_name: #mangled_name.into(),
hook: #name::hook
}
);
}
});
check_finalized!();
IS_FINALIZED.store(true, std::sync::atomic::Ordering::Relaxed);

quote! {
rivets::dll_syringe::payload_procedure! {
fn rivets_finalize(symbol_cache: rivets::SymbolCache) -> Option<String> {
let base_address = match symbol_cache.get_module_base_address() {
Ok(base_address) => base_address,
Err(e) => return Some(format!("{e}")),
};
let hooks = get_hooks();
let imports = get_imports();

let mut hooks: Vec<rivets::RivetsHook> = Vec::new();
#(#injects)*
for hook in &hooks {
let inject_result = unsafe { symbol_cache.inject(base_address, hook) };
if inject_result.is_err() {
return Some(format!("{inject_result:?}"));
}
let finalize = quote! {
fn rivets_finalize(symbol_cache: rivets::SymbolCache) -> Option<String> {
let base_address = match symbol_cache.get_module_base_address() {
Ok(base_address) => base_address,
Err(e) => return Some(format!("{e}")),
};

#(#imports)*

let mut hooks: Vec<rivets::RivetsHook> = Vec::new();
#(#hooks)*
for hook in &hooks {
let inject_result = unsafe { symbol_cache.inject(base_address, hook) };
if inject_result.is_err() {
return Some(format!("{inject_result:?}"));
}
None
}
None
}
}
.into()
};

quote! { rivets::dll_syringe::payload_procedure! { #finalize } }.into()
}

#[derive(FromDeriveInput)]
Expand Down
25 changes: 25 additions & 0 deletions rivets-shared/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
use std::collections::HashMap;
use std::ffi::{CStr, CString};
use std::fs::File;
use std::ops::Deref;
use std::path::Path;
use undname::Flags;
use windows::core::PCSTR;
Expand Down Expand Up @@ -154,3 +155,27 @@ impl SymbolCache {
Ok((hook.hook)(address)?)
}
}

/// A slot holding a function imported from a C++ compiled DLL.
///
/// Invariant: dereferencing a slot that is still [`Uninitialized`](Self::Uninitialized)
/// is undefined behavior. The `rivets::finalize!()` macro should be used to ensure
/// that the slot has been filled in before it is used.
pub enum UnsafeSummonedFunction<T: 'static + Sized> {
    /// The imported function, ready to be called.
    Function(T),
    /// No function has been stored yet.
    Uninitialized,
}

/// Gives direct access to the wrapped function so the slot can be used like the function itself.
impl<T> Deref for UnsafeSummonedFunction<T> {
    type Target = T;

    /// Returns the stored function.
    ///
    /// Dereferencing an `Uninitialized` slot is undefined behavior.
    #[inline]
    #[track_caller]
    fn deref(&self) -> &Self::Target {
        if let Self::Function(inner) = self {
            return inner;
        }
        // SAFETY: the type's documented invariant requires the slot to be
        // initialized before it is dereferenced, so this point is never reached
        // by a caller upholding that contract.
        unsafe { std::hint::unreachable_unchecked() }
    }
}

0 comments on commit d5060bb

Please sign in to comment.