diff --git a/tools/rewriter/GenCallAsm.cpp b/tools/rewriter/GenCallAsm.cpp index ee8901f6c..1b9d1b43f 100644 --- a/tools/rewriter/GenCallAsm.cpp +++ b/tools/rewriter/GenCallAsm.cpp @@ -979,7 +979,7 @@ std::string emit_asm_wrapper(AbiSignature &sig, add_comment_line(aw, "Wrapper for "s + sig_string(sig, target_name) + ":"); add_asm_line(aw, ".text"); - if (!as_macro) { + if (kind != WrapperKind::PointerToStatic) { add_asm_line(aw, ".global "s + wrapper_name); } else { add_asm_line(aw, ".local "s + wrapper_name); diff --git a/tools/rewriter/GenCallAsm.h b/tools/rewriter/GenCallAsm.h index 68fc8b9c0..88c004096 100644 --- a/tools/rewriter/GenCallAsm.h +++ b/tools/rewriter/GenCallAsm.h @@ -10,6 +10,8 @@ enum class WrapperKind { Direct, // Indirect call through a pointer sent to another compartment Pointer, + // Indirect call through a pointer sent to another compartment + PointerToStatic, // Indirect call through a pointer received from another compartment IndirectCallsite, }; diff --git a/tools/rewriter/SourceRewriter.cpp b/tools/rewriter/SourceRewriter.cpp index 70509ef85..465d835ce 100644 --- a/tools/rewriter/SourceRewriter.cpp +++ b/tools/rewriter/SourceRewriter.cpp @@ -911,6 +911,8 @@ class FnDecl : public RefactoringCallback { if (definition) { defined_fns[pkey].insert(fn_name); fn_pkeys[fn_name] = pkey; + Filename filename = get_expansion_filename(fn_node->getLocation(), sm); + fn_definitions[fn_name] = filename; } else { declared_fns[pkey].insert(fn_name); } @@ -920,6 +922,7 @@ class FnDecl : public RefactoringCallback { std::set declared_fns[MAX_PKEYS]; std::map abi_signatures; std::map fn_pkeys; + std::map fn_definitions; }; static void create_file(llvm::raw_fd_ostream *file[MAX_PKEYS], int i, const char *extension) { @@ -1418,8 +1421,20 @@ int main(int argc, const char **argv) { } std::cout << "Generating function pointer wrappers\n"; + std::string macros_defining_wrappers; + /* + * This loops over all non-static address-taken functions but we only want to + * define one call gate for each so we need to track them in case a function + * has its address taken in multiple places + */ + std::set generated_wrappers = {}; // Define wrappers for function pointers (i.e. those referenced by IA2_FN) for (const auto &[fn_name, opaque] : ptr_expr_pass.addr_taken_fns) { + if (generated_wrappers.contains(fn_name)) { + continue; + } + llvm::errs() << " inserting " << fn_name << " into set\n"; + generated_wrappers.insert(fn_name); /* * Declare these wrapper in the output header so that IA2_FN can reference * them. e.g. extern struct IA2_fnptr_ZTSFiiE __ia2_foo; @@ -1435,10 +1450,18 @@ int main(int argc, const char **argv) { AbiSignature c_abi_sig = fn_decl_pass.abi_signatures[fn_name]; std::string asm_wrapper = emit_asm_wrapper(c_abi_sig, wrapper_name, fn_name, - WrapperKind::Pointer, 0, target_pkey, Target); - wrapper_out << "asm(\n"; - wrapper_out << asm_wrapper; - wrapper_out << ");\n"; + WrapperKind::Pointer, 0, target_pkey, Target, + true /* as_macro */); + macros_defining_wrappers += "#define IA2_DEFINE_WRAPPER_"s + fn_name + " \\\n"; + macros_defining_wrappers += "asm(\\\n"; + macros_defining_wrappers += asm_wrapper; + macros_defining_wrappers += ");\n"; + + /* Invoke the macro we just defined in the source file defining the target function */ + auto filename = fn_decl_pass.fn_definitions[fn_name]; + std::ofstream source_file(filename, std::ios::app); + source_file << "IA2_DEFINE_WRAPPER(" << fn_name << ")\n"; + } else { header_out << "asm(\n"; header_out << " \".set " << wrapper_name << ", __real_" << fn_name << "\\n\"\n"; @@ -1449,7 +1472,6 @@ int main(int argc, const char **argv) { std::cout << "Generating function pointer wrappers for static functions\n"; // Define wrappers for pointers to static functions (also those referenced by // IA2_FN) - std::string static_wrappers; for (const auto &[filename, addr_taken_fns] : ptr_expr_pass.internal_addr_taken_fns) { @@ -1467,12 +1489,12 @@ int main(int argc, const char **argv) { AbiSignature c_abi_sig = fn_decl_pass.abi_signatures[fn_name]; std::string asm_wrapper = emit_asm_wrapper( - c_abi_sig, wrapper_name, fn_name, WrapperKind::Pointer, 0, + c_abi_sig, wrapper_name, fn_name, WrapperKind::PointerToStatic, 0, target_pkey, Target, true /* as_macro */); - static_wrappers += "#define IA2_DEFINE_WRAPPER_"s + fn_name + " \\\n"; - static_wrappers += "asm(\\\n"; - static_wrappers += asm_wrapper; - static_wrappers += ");\n"; + macros_defining_wrappers += "#define IA2_DEFINE_WRAPPER_"s + fn_name + " \\\n"; + macros_defining_wrappers += "asm(\\\n"; + macros_defining_wrappers += asm_wrapper; + macros_defining_wrappers += ");\n"; header_out << "extern " << opaque << " " << wrapper_name << ";\n"; @@ -1490,7 +1512,7 @@ int main(int argc, const char **argv) { } header_out << "asm(\"__libia2_abort:\\n\"\n" << " \"" << undef_insn << "\");\n"; - header_out << static_wrappers.c_str(); + header_out << macros_defining_wrappers.c_str(); for (int i = 0; i < num_pkeys; i++) { if (ld_args_out[i] != nullptr) {