From fa355e946baecf09e54608a68c1cfb20c85cd438 Mon Sep 17 00:00:00 2001 From: Khyber Sen Date: Thu, 21 Nov 2024 18:57:10 -0800 Subject: [PATCH] rewriter: support the 6 x86 param regs for post condition functions This is done by pushing the param regs onto the stack in the prologue (if there's a post condition function), and then popping them before we call the post condition function. Next we'll work on supporting more than the 6 param regs, as well as passing the return value. --- tools/rewriter/GenCallAsm.cpp | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/tools/rewriter/GenCallAsm.cpp b/tools/rewriter/GenCallAsm.cpp index 50eadcf72..9e137de2b 100644 --- a/tools/rewriter/GenCallAsm.cpp +++ b/tools/rewriter/GenCallAsm.cpp @@ -445,7 +445,7 @@ static AsmWriter get_asmwriter(bool as_macro) { return {.ss = {}, .terminator = terminator}; } -static void emit_prologue(AsmWriter &aw, uint32_t caller_pkey, uint32_t target_pkey, Arch arch) { +static void emit_prologue(AsmWriter &aw, uint32_t caller_pkey, uint32_t target_pkey, Arch arch, bool save_param_regs) { if (arch == Arch::X86) { // Save the old frame pointer and set the frame pointer for the call gate add_asm_line(aw, "pushq %rbp"); @@ -456,6 +456,15 @@ static void emit_prologue(AsmWriter &aw, uint32_t caller_pkey, uint32_t target_p for (auto &r : x86_preserved_registers) { add_asm_line(aw, "pushq %"s + r); } + if (save_param_regs) { + // Push first 6 registers for args onto the stack, too, + // so that we can store them for the post condition function. + add_comment_line(aw, "Save param regs for post condition call"); + for (auto &r : x86_int_param_reg_order) { + add_asm_line(aw, "pushq %"s + r); + } + } + // TODO this for arm, too } else if (arch == Arch::Aarch64) { // Frame pointer and link register need to be saved first, to make backtraces work add_asm_line(aw, "stp x29, x30, [sp, #-16]!"); @@ -816,6 +825,13 @@ static void emit_set_return_pkru(AsmWriter &aw, uint32_t caller_pkey, Arch arch) static void emit_post_condition_fn_call(AsmWriter &aw, Arch arch, std::string_view target_post_condition_name) { llvm::errs() << "emitting post condition call to " << target_post_condition_name << "\n"; + if (arch == Arch::X86) { + add_comment_line(aw, "Restore param regs for post condition call"); + for (auto it = x86_int_param_reg_order.rbegin(); it != x86_int_param_reg_order.rend(); ++it) { + auto &r = *it; + add_asm_line(aw, "popq %"s + r); + } + } add_comment_line(aw, "Align stack"); add_asm_line(aw, "subq $8, %rsp"); add_comment_line(aw, "Call post condition function"); @@ -990,6 +1006,14 @@ std::string emit_asm_wrapper(AbiSignature &sig, stack_alignment = (compartment_stack_space + 8) % 16; } + // For now, we hardcode the existence and name of the post-condition function. + // The name is `${target_name}_post_condition`, + // and we only do this for `dav1d_get_picture`. + std::optional target_post_condition_name = std::nullopt; + if (target_name && *target_name == "dav1d_get_picture") { + target_post_condition_name = *target_name + "_post_condition"; + } + add_comment_line(aw, "Wrapper for "s + sig_string(sig, target_name) + ":"); add_asm_line(aw, ".text"); if (kind != WrapperKind::PointerToStatic) { @@ -1001,7 +1025,7 @@ std::string emit_asm_wrapper(AbiSignature &sig, add_asm_line(aw, ".type "s + wrapper_name + ", @function"); add_asm_line(aw, wrapper_name + ":"); - emit_prologue(aw, caller_pkey, target_pkey, arch); + emit_prologue(aw, caller_pkey, target_pkey, arch, target_post_condition_name.has_value()); if (arch == Arch::X86) { add_raw_line(aw, llvm::formatv("ASSERT_PKRU({0:x8}) \"\\n\"", ~((0b11 << (2 * caller_pkey)) | 0b11))); @@ -1041,13 +1065,6 @@ std::string emit_asm_wrapper(AbiSignature &sig, // Call the post-condition function, if one was specified. // The call happens in the caller's compartment. - // For now, we hardcode the existence and name of the post-condition function. - // The name is `${target_name}_post_condition`, - // and we only do this for `dav1d_get_picture`. - std::optional target_post_condition_name = std::nullopt; - if (target_name && *target_name == "dav1d_get_picture") { - target_post_condition_name = *target_name + "_post_condition"; - } if (target_post_condition_name) { emit_post_condition_fn_call(aw, arch, *target_post_condition_name); }