Skip to content

Commit

Permalink
Create proof trace event for tail call information (#1179)
Browse files Browse the repository at this point in the history
This PR adds a new event that provides information on the way a function
exits control, specifically whether it exits via a tail call or a
conventional return statement. This event is added to assist in
computing the call stack of the various simplifications from the proof
trace hint.
  • Loading branch information
theo25 authored Dec 12, 2024
1 parent e7ddbef commit 6df5ac5
Show file tree
Hide file tree
Showing 120 changed files with 12,053 additions and 33 deletions.
7 changes: 7 additions & 0 deletions bindings/python/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,13 @@ void bind_proof_trace(py::module_ &m) {
"function_name",
&llvm_pattern_matching_failure_event::get_function_name);

py::class_<
llvm_function_exit_event, std::shared_ptr<llvm_function_exit_event>>(
proof_trace, "llvm_function_exit_event", step_event)
.def_property_readonly(
"rule_ordinal", &llvm_function_exit_event::get_rule_ordinal)
.def_property_readonly("is_tail", &llvm_function_exit_event::is_tail);

py::class_<llvm_function_event, std::shared_ptr<llvm_function_event>>(
proof_trace, "llvm_function_event", step_event)
.def_property_readonly("name", &llvm_function_event::get_name)
Expand Down
3 changes: 3 additions & 0 deletions docs/proof-trace.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ event ::= hook
| side_cond_exit
| config
| pattern_matching_failure
| function_exit
arg ::= kore_term
Expand All @@ -60,6 +61,8 @@ rule ::= WORD(0x22) ordinal arity variable*
side_cond_entry ::= WORD(0xEE) ordinal arity variable*
side_cond_exit ::= WORD(0x33) ordinal boolean_result
function_exit ::= WORD(0x55) ordinal boolean_result
config ::= WORD(0xFF) kore_term
string ::= <c-style null terminated string>
Expand Down
47 changes: 47 additions & 0 deletions include/kllvm/binary/ProofTraceParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ constexpr uint64_t rule_event_sentinel = detail::word(0x22);
constexpr uint64_t side_condition_event_sentinel = detail::word(0xEE);
constexpr uint64_t side_condition_end_sentinel = detail::word(0x33);
constexpr uint64_t pattern_matching_failure_sentinel = detail::word(0x44);
constexpr uint64_t function_exit_sentinel = detail::word(0x55);

class llvm_step_event : public std::enable_shared_from_this<llvm_step_event> {
public:
Expand Down Expand Up @@ -172,6 +173,29 @@ class llvm_pattern_matching_failure_event : public llvm_step_event {
const override;
};

class llvm_function_exit_event : public llvm_step_event {
private:
uint64_t rule_ordinal_;
bool is_tail_;

llvm_function_exit_event(uint64_t rule_ordinal, bool is_tail)
: rule_ordinal_(rule_ordinal)
, is_tail_(is_tail) { }

public:
static sptr<llvm_function_exit_event>
create(uint64_t rule_ordinal, bool is_tail) {
return sptr<llvm_function_exit_event>(
new llvm_function_exit_event(rule_ordinal, is_tail));
}

[[nodiscard]] uint64_t get_rule_ordinal() const { return rule_ordinal_; }
[[nodiscard]] bool is_tail() const { return is_tail_; }

void print(std::ostream &out, bool expand_terms, unsigned indent = 0U)
const override;
};

class llvm_event;

class llvm_function_event : public llvm_step_event {
Expand Down Expand Up @@ -599,6 +623,27 @@ class proof_trace_parser {
return event;
}

sptr<llvm_function_exit_event> static parse_function_exit(
proof_trace_buffer &buffer) {
if (!buffer.check_word(function_exit_sentinel)) {
return nullptr;
}

uint64_t ordinal = 0;
if (!buffer.read_uint64(ordinal)) {
return nullptr;
}

bool is_tail = false;
if (!buffer.read_bool(is_tail)) {
return nullptr;
}

auto event = llvm_function_exit_event::create(ordinal, is_tail);

return event;
}

bool parse_argument(proof_trace_buffer &buffer, llvm_event &event) {
if (buffer.eof() || buffer.peek() != '\x7F') {
return false;
Expand Down Expand Up @@ -634,6 +679,8 @@ class proof_trace_parser {
case pattern_matching_failure_sentinel:
return parse_pattern_matching_failure(buffer);

case function_exit_sentinel: return parse_function_exit(buffer);

default: return nullptr;
}
}
Expand Down
108 changes: 103 additions & 5 deletions include/kllvm/codegen/ProofEvent.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
#include "kllvm/ast/AST.h"
#include "kllvm/codegen/Decision.h"
#include "kllvm/codegen/DecisionParser.h"
#include "kllvm/codegen/Options.h"
#include "kllvm/codegen/Util.h"

#include "llvm/IR/Instructions.h"

#include <fmt/format.h>

#include <map>
#include <tuple>

Expand All @@ -21,27 +24,58 @@ class proof_event {

/*
* Load the boolean flag that controls whether proof hint output is enabled or
* not, then create a branch at the end of this basic block depending on the
* result.
* not, then create a branch at the specified location depending on the
* result. The location can be before a given instruction or at the end of a
* given basic block.
*
* Returns a pair of blocks [proof enabled, merge]; the first of these is
* intended for self-contained behaviour only relevant in proof output mode,
* while the second is for the continuation of the interpreter's previous
* behaviour.
*/
template <typename Location>
std::pair<llvm::BasicBlock *, llvm::BasicBlock *>
proof_branch(std::string const &label, llvm::BasicBlock *insert_at_end);
proof_branch(std::string const &label, Location *insert_loc);

/*
* Return the parent function of the given location.
* Template specializations for llvm::Instruction and llvm::BasicBlock.
*/
template <typename Location>
llvm::Function *get_parent_function(Location *loc);

/*
* Return the parent basic block of the given location.
* Template specializations for llvm::Instruction and llvm::BasicBlock.
*/
template <typename Location>
llvm::BasicBlock *get_parent_block(Location *loc);

/*
* If the given location is an Instruction, this method moves the instruction
* to the merge block.
* If the given location is a BasicBlock, this method simply emits a no-op
* instruction to the merge block.
* Template specializations for llvm::Instruction and llvm::BasicBlock.
*/
template <typename Location>
void fix_insert_loc(Location *loc, llvm::BasicBlock *merge_block);

/*
* Set up a standard event prelude by creating a pair of basic blocks for the
* proof output and continuation, then loading the output filename from its
* global.
* global. The location for the prelude can be before a given instruction or
* at the end of a given basic block.
*
* Returns a triple [proof enabled, merge, proof_writer]; see `proofBranch`
* and `emitGetOutputFileName`.
*/
template <typename Location>
std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
event_prelude(std::string const &label, llvm::BasicBlock *insert_at_end);
event_prelude(std::string const &label, Location *insert_loc);

/*
* Set up a check of whether a new proof hint chunk should be started. The
Expand Down Expand Up @@ -172,6 +206,13 @@ class proof_event {
llvm::Value *proof_writer, std::string const &function_name,
llvm::BasicBlock *insert_at_end);

/*
* Emit a call to the `function_exit` API of the specified `proof_writer`.
*/
llvm::CallInst *emit_write_function_exit(
llvm::Value *proof_writer, uint64_t ordinal, bool is_tail,
llvm::BasicBlock *insert_at_end);

/*
* Emit a call to the `start_new_chunk` API of the specified `proof_writer`.
*/
Expand Down Expand Up @@ -228,6 +269,10 @@ class proof_event {
[[nodiscard]] llvm::BasicBlock *pattern_matching_failure(
kore_composite_pattern const &pattern, llvm::BasicBlock *current_block);

template <typename Location>
[[nodiscard]] llvm::BasicBlock *
function_exit(uint64_t ordinal, bool is_tail, Location *insert_loc);

proof_event(kore_definition *definition, llvm::Module *module)
: definition_(definition)
, module_(module)
Expand All @@ -236,4 +281,57 @@ class proof_event {

} // namespace kllvm

//===----------------------------------------------------------------------===//
// Implementation for method templates
//===----------------------------------------------------------------------===//

template <typename Location>
std::pair<llvm::BasicBlock *, llvm::BasicBlock *>
kllvm::proof_event::proof_branch(
std::string const &label, Location *insert_loc) {
auto *i1_ty = llvm::Type::getInt1Ty(ctx_);

auto *proof_output_flag = module_->getOrInsertGlobal("proof_output", i1_ty);
auto *proof_output = new llvm::LoadInst(
i1_ty, proof_output_flag, "proof_output", insert_loc);

auto *f = get_parent_function(insert_loc);
auto *true_block
= llvm::BasicBlock::Create(ctx_, fmt::format("if_{}", label), f);
auto *merge_block
= llvm::BasicBlock::Create(ctx_, fmt::format("tail_{}", label), f);

llvm::BranchInst::Create(true_block, merge_block, proof_output, insert_loc);

fix_insert_loc(insert_loc, merge_block);

return {true_block, merge_block};
}

template <typename Location>
std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
kllvm::proof_event::event_prelude(
std::string const &label, Location *insert_loc) {
auto [true_block, merge_block] = proof_branch(label, insert_loc);
return {true_block, merge_block, emit_get_proof_trace_writer(true_block)};
}

template <typename Location>
llvm::BasicBlock *kllvm::proof_event::function_exit(
uint64_t ordinal, bool is_tail, Location *insert_loc) {

if (!proof_hint_instrumentation) {
return get_parent_block(insert_loc);
}

auto [true_block, merge_block, proof_writer]
= event_prelude("function_exit", insert_loc);

emit_write_function_exit(proof_writer, ordinal, is_tail, true_block);

llvm::BranchInst::Create(merge_block, true_block);

return merge_block;
}

#endif // PROOF_EVENT_H
2 changes: 2 additions & 0 deletions include/runtime/header.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ void write_side_condition_event_post_to_proof_trace(
void *proof_writer, uint64_t ordinal, bool side_cond_result);
void write_pattern_matching_failure_to_proof_trace(
void *proof_writer, char const *function_name);
void write_function_exit_to_proof_trace(
void *proof_writer, uint64_t ordinal, bool is_tail);
void write_configuration_to_proof_trace(
void *proof_writer, block *config, bool is_initial);
void start_new_chunk_in_proof_trace(void *proof_writer);
Expand Down
23 changes: 23 additions & 0 deletions include/runtime/proof_trace_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class proof_trace_writer {
side_condition_event_post(uint64_t ordinal, bool side_cond_result)
= 0;
virtual void pattern_matching_failure(char const *function_name) = 0;
virtual void function_exit(uint64_t ordinal, bool is_tail) = 0;
virtual void configuration(block *config, bool is_initial) = 0;
virtual void start_new_chunk() = 0;
virtual void end_of_trace() = 0;
Expand Down Expand Up @@ -163,6 +164,12 @@ class proof_trace_file_writer : public proof_trace_writer {
write_null_terminated_string(function_name);
}

void function_exit(uint64_t ordinal, bool is_tail) override {
write_uint64(kllvm::function_exit_sentinel);
write_uint64(ordinal);
write_bool(is_tail);
}

void configuration(block *config, bool is_initial) override {
write_uint64(kllvm::config_sentinel);
serialize_configuration_to_proof_trace(file_, config, 0);
Expand Down Expand Up @@ -227,6 +234,15 @@ class proof_trace_callback_writer : public proof_trace_writer {
, result(result) { }
};

struct function_exit_construction {
uint64_t ordinal;
bool is_tail;

function_exit_construction(uint64_t ordinal, bool is_tail)
: ordinal(ordinal)
, is_tail(is_tail) { }
};

struct call_event_construction {
char const *hook_name;
char const *symbol_name;
Expand Down Expand Up @@ -281,6 +297,8 @@ class proof_trace_callback_writer : public proof_trace_writer {
side_condition_result_construction const &event) { }
virtual void pattern_matching_failure_callback(
pattern_matching_failure_construction const &event) { }
virtual void function_exit_callback(function_exit_construction const &event) {
}
virtual void configuration_event_callback(
kore_configuration_construction const &config, bool is_initial) { }

Expand Down Expand Up @@ -366,6 +384,11 @@ class proof_trace_callback_writer : public proof_trace_writer {
pattern_matching_failure_callback(pm_failure);
}

void function_exit(uint64_t ordinal, bool is_tail) override {
function_exit_construction function_exit(ordinal, is_tail);
function_exit_callback(function_exit);
}

void configuration(block *config, bool is_initial) override {
kore_configuration_construction configuration(config);
configuration_event_callback(configuration, is_initial);
Expand Down
8 changes: 8 additions & 0 deletions lib/binary/ProofTraceParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ void llvm_pattern_matching_failure_event::print(
"{}pattern matching failure: {}\n", indent, function_name_);
}

void llvm_function_exit_event::print(
std::ostream &out, bool expand_terms, unsigned ind) const {
std::string indent(ind * indent_size, ' ');
out << fmt::format(
"{}function exit: {} {}\n", indent, rule_ordinal_,
(is_tail_ ? "tail" : "notail"));
}

void llvm_function_event::print(
std::ostream &out, bool expand_terms, unsigned ind) const {
std::string indent(ind * indent_size, ' ');
Expand Down
21 changes: 21 additions & 0 deletions lib/codegen/CreateTerm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1174,6 +1174,7 @@ bool can_tail_call(llvm::Type *type) {
return int_type->getBitWidth() <= 192;
}

// NOLINTNEXTLINE(*-cognitive-complexity)
bool make_function(
std::string const &name, kore_pattern *pattern, kore_definition *definition,
llvm::Module *module, bool tailcc, bool big_step, bool apply,
Expand Down Expand Up @@ -1276,6 +1277,10 @@ bool make_function(
call->setTailCallKind(llvm::CallInst::TCK_MustTail);
retval = call;
} else {
size_t ordinal = 0;
if (apply) {
ordinal = std::stoll(name.substr(11));
}
if (auto *call = llvm::dyn_cast<llvm::CallInst>(retval)) {
// check that musttail requirements are met:
// 1. Call is in tail position (guaranteed)
Expand All @@ -1286,6 +1291,22 @@ bool make_function(
if (call->getCallingConv() == llvm::CallingConv::Tail
&& can_tail_call(call->getType())) {
call->setTailCallKind(llvm::CallInst::TCK_MustTail);
if (apply) {
current_block
= proof_event(definition, module)
.function_exit(
ordinal, true, llvm::dyn_cast<llvm::Instruction>(call));
}
} else {
if (apply) {
current_block = proof_event(definition, module)
.function_exit(ordinal, false, current_block);
}
}
} else {
if (apply) {
current_block = proof_event(definition, module)
.function_exit(ordinal, false, current_block);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions lib/codegen/Decision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ void leaf_node::codegen(decision *d) {
d->current_block_
= proof_event(d->definition_, d->module_)
.rewrite_event_pre(axiom, arity, vars, subst, d->current_block_);
// maybe report here as part of the rule event whether a tail call happened

if (d->profile_matching_) {
llvm::CallInst::Create(
Expand Down
Loading

0 comments on commit 6df5ac5

Please sign in to comment.