Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/introducing the # summon()] macro #8

Merged
merged 5 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 168 additions & 35 deletions rivets-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@ use lazy_regex::regex;
use proc_macro::{self, Diagnostic, Level, Span, TokenStream};
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use std::sync::{atomic::AtomicBool, LazyLock, Mutex};
use syn::{parse_macro_input, Abi, DeriveInput, Error, Expr, FnArg, Ident, ItemFn, Variant};

static IS_FINALIZED: AtomicBool = AtomicBool::new(false);
static MANGLED_NAMES: LazyLock<Mutex<Vec<(String, String)>>> = LazyLock::new(|| Mutex::new(vec![]));
static CPP_IMPORTS: LazyLock<Mutex<Vec<(String, String)>>> = LazyLock::new(|| Mutex::new(vec![]));

macro_rules! derive_error {
($string: tt) => {
Error::new(proc_macro2::Span::call_site(), $string)
Expand All @@ -19,6 +24,14 @@ macro_rules! derive_error {
};
}

macro_rules! check_finalized {
() => {
if IS_FINALIZED.load(std::sync::atomic::Ordering::Relaxed) {
panic!("The rivets library has already been finalized!");
}
};
}

fn failure(callback: proc_macro2::TokenStream, error_message: &str) -> TokenStream {
Diagnostic::spanned(Span::call_site(), Level::Error, error_message).emit();
callback.into()
Expand All @@ -42,8 +55,6 @@ fn determine_calling_convention(input: &ItemFn, unmangled_name: &str) -> Result<
}
}

static mut MANGLED_NAMES: Vec<(String, String)> = vec![];

/// A procedural macro for detouring a C++ compiled function.
///
/// The argument to the macro is the mangled name of the C++ function to detour.
Expand All @@ -58,6 +69,8 @@ static mut MANGLED_NAMES: Vec<(String, String)> = vec![];
///
/// This macro cannot hook into the middle of a C++ function. It can only hook into the beginning or end of a function.
///
/// This macro cannot hook into a function that has been inlined by the compiler. Prominent examples of this include `lua_gettop`.
///
/// Exposes an `unsafe` `back` function that can be called in order to resume control flow to the original C++ function.
///
/// Internally uses the `retour` crate to create a static detour for the function and thus inherits the safety guarantees of that crate.
Expand Down Expand Up @@ -88,6 +101,8 @@ static mut MANGLED_NAMES: Vec<(String, String)> = vec![];
/// See the `pdb2hpp` module for a tool that can generate the correct FFI types for C++ functions.
#[proc_macro_attribute]
pub fn detour(attr: TokenStream, item: TokenStream) -> TokenStream {
check_finalized!();

let mangled_name = attr.to_string();
let unmangled_name =
rivets_shared::demangle(&mangled_name).unwrap_or_else(|| mangled_name.clone());
Expand Down Expand Up @@ -144,61 +159,179 @@ pub fn detour(attr: TokenStream, item: TokenStream) -> TokenStream {
#callback

pub unsafe fn hook(address: u64) -> Result<(), rivets::retour::Error> {
let compiled_function: #cpp_function_header = std::mem::transmute(address);
let compiled_function: #cpp_function_header = std::mem::transmute(address); // todo: rust documentation recommends casting this to a raw function pointer. address as *const _
Detour.initialize(compiled_function, #name)?.enable()?;
Ok(())
}
}
};

unsafe {
MANGLED_NAMES.push((mangled_name.clone(), format!("{name}")));
}
MANGLED_NAMES
.lock()
.expect("Failed to lock mangled names")
.push((mangled_name.clone(), name.to_string()));

Diagnostic::spanned(Span::call_site(), Level::Note, unmangled_name.clone()).emit();

result.into()
}

/// A procedural macro for importing a C++ compiled function into the rust scope.
/// This macro is useful in the case where you need to directly call any C++ function from rust.
///
/// # Arguments
/// * `mangled_name` - The mangled name of the C++ function to import.
/// * `dll_name` (optional) - Argument for the name of the DLL to import the function from. If not provided, factorio.exe will be used.
///
/// Note that most Factorio libraries (such as allegro and lua) are statically linked. In this case, the `dll_name` argument is not needed.
///
/// # Examples
/// ```
/// // Summons the lua_gettop function from the compiled lua library.
/// // lua_gettop is compiled without name mangling, so calling convention (in this case, extern "C") must be manually provided.
/// #[import(lua_gettop)]
/// extern "C" fn lua_gettop(lua_state: *mut luastate::lua_State) -> i64 {}
///
/// // Calls the lua_gettop function with correct arguments.
/// fn my_func(*mut luastate::lua_State) {
/// let top = unsafe { lua_gettop(lua_state) };
/// println!("Lua stack top: {top}");
/// }
/// ```
///
/// # Safety
/// The arguments and return type of the imported function must be exactly matching FFI types.
/// All structs, classes, enums, and union arguments must have a corresponding `#[repr(C)]` attribute and must also have the correct offsets and sizes.
/// Alternatively, the user can use the `rivets::Opaque` type to represent any arbitrary FFI data if you do not intend to interact with the data.
/// See the `pdb2hpp` module for a tool that can generate the correct FFI types for C++ functions.
///
/// The user must also ensure that the calling convention is correct.
/// Rivets attempts to automatically parse this information from the mangled name however
/// - If the calling convention is not one of cdecl, stdcall, fastcall, thiscall, or vectorcall, the user must specify the calling convention manually.
/// - If the calling convention is not present in the mangled name, the user must specify the calling convention manually.
/// - In rare cases the function may use a non-standard calling convention. In this case, the user must manually populate the required stack and registers via inline assembly.
///
/// Calling any imported function repersents calling into the C++ compiled codebase and thus is inherently unsafe.
#[proc_macro_attribute]
pub fn import(attr: TokenStream, item: TokenStream) -> TokenStream {
check_finalized!();

let mangled_name = attr.to_string();
let unmangled_name =
rivets_shared::demangle(&mangled_name).unwrap_or_else(|| mangled_name.clone());

let input = parse_macro_input!(item as ItemFn);

let calling_convention = match determine_calling_convention(&input, &unmangled_name) {
Ok(calling_convention) => Some(calling_convention),
Err(e) => return failure(quote! { #input }, &e.to_string()),
};

let arg_types = input.sig.inputs.iter().map(|arg| match arg {
FnArg::Receiver(_) => {
quote! {compile_error!("Summoned functions cannot use the self parameter.")}
}
FnArg::Typed(pat) => {
let ty = &pat.ty;
quote! { #ty }
}
});

let return_type = &input.sig.output;
let vis = &input.vis;
let attr = &input.attrs;
let attr = quote! { #(#attr)* };

let name = &input.sig.ident;
let function_type =
quote! { #attr #vis unsafe #calling_convention fn(#(#arg_types),*) #return_type };

CPP_IMPORTS
.lock()
.expect("Failed to lock cpp imports")
.push((mangled_name.clone(), name.to_string()));

Diagnostic::spanned(Span::call_site(), Level::Note, unmangled_name.clone()).emit();

quote! {
#[allow(non_upper_case_globals)]
static mut #name: rivets::UnsafeSummonedFunction<#function_type> = rivets::UnsafeSummonedFunction::Uninitialized;
}.into()
}

fn get_hooks() -> Vec<proc_macro2::TokenStream> {
MANGLED_NAMES
.lock()
.expect("Failed to lock mangled names")
.iter()
.map(|(mangled_name, module_name)| {
let module_name = Ident::new(module_name, proc_macro2::Span::call_site());
quote! {
hooks.push(
rivets::RivetsHook {
mangled_name: #mangled_name.into(),
hook: #module_name::hook
}
);
}
})
.collect()
}

fn get_imports() -> Vec<proc_macro2::TokenStream> {
CPP_IMPORTS.lock().expect("Failed to lock cpp imports")
.iter()
.map(|(mangled_name, rust_name)| {
let rust_name = Ident::new(rust_name, proc_macro2::Span::call_site());
quote! {
let Some(address) = symbol_cache.get_function_address(base_address, #mangled_name)
else {
panic!(
"Failed to find address for the following mangled function inside the PDB: {}", #mangled_name
);
};
let function = unsafe {
std::mem::transmute(address) // todo: rust documentation recommends casting this to a raw function pointer. address as *const _
};
unsafe { #rust_name = rivets::UnsafeSummonedFunction::Function(function); }
}
})
.collect()
}

/// A procedural macro for finalizing the rivets library.
/// This macro should be called once at the end of the `main.rs` file.
/// It will finalize the rivets library and inject all of the detours.
#[proc_macro]
pub fn finalize(_: TokenStream) -> TokenStream {
let injects = unsafe { MANGLED_NAMES.clone() };
let injects = injects.iter().map(|(mangled_name, name)| {
let name = Ident::new(name, proc_macro2::Span::call_site());
quote! {
hooks.push(
rivets::RivetsHook {
mangled_name: #mangled_name.into(),
hook: #name::hook
}
);
}
});
check_finalized!();
IS_FINALIZED.store(true, std::sync::atomic::Ordering::Relaxed);

quote! {
rivets::dll_syringe::payload_procedure! {
fn rivets_finalize(symbol_cache: rivets::SymbolCache) -> Option<String> {
let base_address = match symbol_cache.get_module_base_address() {
Ok(base_address) => base_address,
Err(e) => return Some(format!("{e}")),
};
let hooks = get_hooks();
let imports = get_imports();

let mut hooks: Vec<rivets::RivetsHook> = Vec::new();
#(#injects)*
for hook in &hooks {
let inject_result = unsafe { symbol_cache.inject(base_address, hook) };
if inject_result.is_err() {
return Some(format!("{inject_result:?}"));
}
let finalize = quote! {
fn rivets_finalize(symbol_cache: rivets::SymbolCache) -> Option<String> {
let base_address = match symbol_cache.get_module_base_address() {
Ok(base_address) => base_address,
Err(e) => return Some(format!("{e}")),
};

#(#imports)*

let mut hooks: Vec<rivets::RivetsHook> = Vec::new();
#(#hooks)*
for hook in &hooks {
let inject_result = unsafe { symbol_cache.inject(base_address, hook) };
if inject_result.is_err() {
return Some(format!("{inject_result:?}"));
}
None
}
None
}
}
.into()
};

quote! { rivets::dll_syringe::payload_procedure! { #finalize } }.into()
}

#[derive(FromDeriveInput)]
Expand Down
25 changes: 25 additions & 0 deletions rivets-shared/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
use std::collections::HashMap;
use std::ffi::{CStr, CString};
use std::fs::File;
use std::ops::Deref;
use std::path::Path;
use undname::Flags;
use windows::core::PCSTR;
Expand Down Expand Up @@ -154,3 +155,27 @@ impl SymbolCache {
Ok((hook.hook)(address)?)
}
}

/// Represents a function that has been imported from a C++ compiled DLL.
/// Invariant: If the function is not initialized, it is UB to dereference it.
/// The rivets::finalize!() macro should be used to ensure that the function is initialized.
pub enum UnsafeSummonedFunction<T>
where
T: 'static + Sized,
{
Function(T),
Uninitialized,
}

impl<T> Deref for UnsafeSummonedFunction<T> {
type Target = T;

#[inline]
#[track_caller]
fn deref(&self) -> &Self::Target {
match self {
Self::Function(x) => x,
Self::Uninitialized => unsafe { std::hint::unreachable_unchecked() },
}
}
}