diff --git a/runtime/libia2/include/ia2.h b/runtime/libia2/include/ia2.h index 5e32ab416..5f4cd9ac9 100644 --- a/runtime/libia2/include/ia2.h +++ b/runtime/libia2/include/ia2.h @@ -48,7 +48,9 @@ #define IA2_CALL(opaque, id) opaque #define IA2_CAST(func, ty) (ty) (void *) func #else -#define IA2_DEFINE_WRAPPER(func) IA2_DEFINE_WRAPPER_##func +#define IA2_DEFINE_WRAPPER(func) \ + typeof(func) func __attribute__((__used__)); \ + IA2_DEFINE_WRAPPER_##func #define IA2_SIGHANDLER(func) ia2_sighandler_##func /// Create a wrapped signal handler for `sa_sigaction` /// diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 346326141..7cb26b1f0 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -56,6 +56,7 @@ add_subdirectory(shared_data) add_subdirectory(global_fn_ptr) add_subdirectory(rewrite_macros) add_subdirectory(sighandler) +add_subdirectory(static_addr_taken) # The following tests are not supported on ARM64 yet if (NOT LIBIA2_AARCH64) diff --git a/tests/static_addr_taken/CMakeLists.txt b/tests/static_addr_taken/CMakeLists.txt new file mode 100644 index 000000000..cc966633a --- /dev/null +++ b/tests/static_addr_taken/CMakeLists.txt @@ -0,0 +1,13 @@ +define_shared_lib( + SRCS lib.c + PKEY 2 +) + +define_test( + SRCS main.c + NEEDS_LD_WRAP + PKEY 1 + CRITERION_TEST +) + +define_ia2_wrapper() diff --git a/tests/static_addr_taken/include/static_fns.h b/tests/static_addr_taken/include/static_fns.h new file mode 100644 index 000000000..8c9263562 --- /dev/null +++ b/tests/static_addr_taken/include/static_fns.h @@ -0,0 +1,10 @@ +#pragma once + +typedef void (*fn_ptr_ty)(void); + +static void inline_noop(void) { + printf("called %s defined in header\n", __func__); +} + +fn_ptr_ty *get_ptrs_in_main(void); +fn_ptr_ty *get_ptrs_in_lib(void); diff --git a/tests/static_addr_taken/lib.c b/tests/static_addr_taken/lib.c new file mode 100644 index 000000000..c10395d66 --- /dev/null +++ b/tests/static_addr_taken/lib.c @@ -0,0 +1,40 @@ +#include +#include +#include + +#include + +#define IA2_COMPARTMENT 2 +#include + +#include "static_fns.h" + +static void duplicate_noop(void) { + printf("called %s in library\n", __func__); +} + +static void identical_name(void) { + static int x = 4; + printf("%s in library read x = %d\n", __func__, x); +} + +static fn_ptr_ty ptrs[3] IA2_SHARED_DATA = { + inline_noop, duplicate_noop, identical_name +}; + +fn_ptr_ty *get_ptrs_in_lib(void) { + return ptrs; +} + +Test(static_addr_taken, call_ptrs_in_lib) { + for (int i = 0; i < 3; i++) { + ptrs[i](); + } +} + +Test(static_addr_taken, call_ptr_from_main) { + fn_ptr_ty *main_ptrs = get_ptrs_in_main(); + for (int i = 0; i < 3; i++) { + main_ptrs[i](); + } +} diff --git a/tests/static_addr_taken/main.c b/tests/static_addr_taken/main.c new file mode 100644 index 000000000..a5f1d5970 --- /dev/null +++ b/tests/static_addr_taken/main.c @@ -0,0 +1,42 @@ +#include +#include +#include + +#define IA2_DEFINE_TEST_HANDLER +#include + +INIT_RUNTIME(2); +#define IA2_COMPARTMENT 1 +#include + +#include "static_fns.h" + +static void duplicate_noop(void) { + printf("called %s in main binary\n", __func__); +} + +static void identical_name(void) { + static int x = 3; + printf("%s in main binary read x = %d\n", __func__, x); +} + +static fn_ptr_ty ptrs[3] IA2_SHARED_DATA = { + inline_noop, duplicate_noop, identical_name +}; + +fn_ptr_ty *get_ptrs_in_main(void) { + return ptrs; +} + +Test(static_addr_taken, call_ptrs_in_main) { + for (int i = 0; i < 3; i++) { + ptrs[i](); + } +} + +Test(static_addr_taken, call_ptr_from_lib) { + fn_ptr_ty *lib_ptrs = get_ptrs_in_lib(); + for (int i = 0; i < 3; i++) { + lib_ptrs[i](); + } +} diff --git a/tools/rewriter/SourceRewriter.cpp b/tools/rewriter/SourceRewriter.cpp index 4820d1004..6048a2f06 100644 --- a/tools/rewriter/SourceRewriter.cpp +++ b/tools/rewriter/SourceRewriter.cpp @@ -666,24 +666,6 @@ class FnPtrExpr : public RefactoringCallback { auto [it, new_fn] = internal_addr_taken_fns[filename].insert( std::make_pair(fn_name, new_type)); - - // TODO: Note that this only checks if a function is added to the - // internal_addr_taken_fns map. To make the rewriter idempotent we should - // check for an existing used attribute. - if (new_fn) { - auto decl_start = fn_decl->getBeginLoc(); - if (!decl_start.isFileID()) { - llvm::errs() << "Error: non-file loc for function " << fn_name << '\n'; - } else { - Replacement old_used_attr(sm, decl_start, 0, - llvm::StringRef("__attribute__((used)) ")); - Replacement used_attr = replace_new_file(filename, old_used_attr); - auto err = file_replacements[filename].add(used_attr); - if (err) { - llvm::errs() << "Error adding replacements: " << err << '\n'; - } - } - } } // This check must come after modifying the maps in this pass but before the @@ -701,6 +683,7 @@ class FnPtrExpr : public RefactoringCallback { return; } + // Add the IA2_FN annotation around the function pointer expression clang::CharSourceRange expansion_range = sm.getExpansionRange(loc); Replacement old_r{sm, expansion_range, new_expr}; Replacement r = replace_new_file(filename, old_r);