Skip to content

Commit

Permalink
adding a tail_call_info event that tells us whether an apply_rule
Browse files Browse the repository at this point in the history
function exited with a tail call
  • Loading branch information
theo25 committed Dec 10, 2024
1 parent 1d701f5 commit 35e6e56
Show file tree
Hide file tree
Showing 11 changed files with 187 additions and 0 deletions.
10 changes: 10 additions & 0 deletions bindings/python/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,16 @@ void bind_proof_trace(py::module_ &m) {
"function_name",
&llvm_pattern_matching_failure_event::get_function_name);

py::class_<
llvm_tail_call_info_event,
std::shared_ptr<llvm_tail_call_info_event>>(
proof_trace, "llvm_tail_call_info_event", step_event)
.def_property_readonly(
"callern_name",
&llvm_tail_call_info_event::get_caller_name)
.def_property_readonly(
"is_tail", &llvm_tail_call_info_event::is_tail);

py::class_<llvm_function_event, std::shared_ptr<llvm_function_event>>(
proof_trace, "llvm_function_event", step_event)
.def_property_readonly("name", &llvm_function_event::get_name)
Expand Down
3 changes: 3 additions & 0 deletions docs/proof-trace.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ event ::= hook
| side_cond_exit
| config
| pattern_matching_failure
| tail_call_info
arg ::= kore_term
Expand All @@ -60,6 +61,8 @@ rule ::= WORD(0x22) ordinal arity variable*
side_cond_entry ::= WORD(0xEE) ordinal arity variable*
side_cond_exit ::= WORD(0x33) ordinal boolean_result
tail_call_info ::= WORD(0x55) function_name boolean_result
config ::= WORD(0xFF) kore_term
string ::= <c-style null terminated string>
Expand Down
50 changes: 50 additions & 0 deletions include/kllvm/binary/ProofTraceParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ constexpr uint64_t rule_event_sentinel = detail::word(0x22);
constexpr uint64_t side_condition_event_sentinel = detail::word(0xEE);
constexpr uint64_t side_condition_end_sentinel = detail::word(0x33);
constexpr uint64_t pattern_matching_failure_sentinel = detail::word(0x44);
constexpr uint64_t tail_call_info_sentinel = detail::word(0x55);

class llvm_step_event : public std::enable_shared_from_this<llvm_step_event> {
public:
Expand Down Expand Up @@ -172,6 +173,31 @@ class llvm_pattern_matching_failure_event : public llvm_step_event {
const override;
};

class llvm_tail_call_info_event : public llvm_step_event {
private:
std::string caller_name_;
bool is_tail_;

llvm_tail_call_info_event(std::string caller_name, bool is_tail)
: caller_name_(std::move(caller_name))
, is_tail_(is_tail) { }

public:
static sptr<llvm_tail_call_info_event>
create(std::string caller_name, bool is_tail) {
return sptr<llvm_tail_call_info_event>(
new llvm_tail_call_info_event(std::move(caller_name), is_tail));
}

[[nodiscard]] std::string const &get_caller_name() const {
return caller_name_;
}
[[nodiscard]] bool is_tail() const { return is_tail_; }

void print(std::ostream &out, bool expand_terms, unsigned indent = 0U)
const override;
};

class llvm_event;

class llvm_function_event : public llvm_step_event {
Expand Down Expand Up @@ -599,6 +625,27 @@ class proof_trace_parser {
return event;
}

sptr<llvm_tail_call_info_event> static parse_tail_call_info(
proof_trace_buffer &buffer) {
if (!buffer.check_word(tail_call_info_sentinel)) {
return nullptr;
}

std::string caller_name;
if (!buffer.read_string(caller_name)) {
return nullptr;
}

bool is_tail = false;
if (!buffer.read_bool(is_tail)) {
return nullptr;
}

auto event = llvm_tail_call_info_event::create(caller_name, is_tail);

return event;
}

bool parse_argument(proof_trace_buffer &buffer, llvm_event &event) {
if (buffer.eof() || buffer.peek() != '\x7F') {
return false;
Expand Down Expand Up @@ -634,6 +681,9 @@ class proof_trace_parser {
case pattern_matching_failure_sentinel:
return parse_pattern_matching_failure(buffer);

case tail_call_info_sentinel:
return parse_tail_call_info(buffer);

default: return nullptr;
}
}
Expand Down
15 changes: 15 additions & 0 deletions include/kllvm/codegen/ProofEvent.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class proof_event {
*/
std::pair<llvm::BasicBlock *, llvm::BasicBlock *>
proof_branch(std::string const &label, llvm::BasicBlock *insert_at_end);
std::pair<llvm::BasicBlock *, llvm::BasicBlock *>
proof_branch(std::string const &label, llvm::Instruction *insert_before);

/*
* Set up a standard event prelude by creating a pair of basic blocks for the
Expand All @@ -42,6 +44,8 @@ class proof_event {
*/
std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
event_prelude(std::string const &label, llvm::BasicBlock *insert_at_end);
std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
event_prelude(std::string const &label, llvm::Instruction *insert_before);

/*
* Set up a check of whether a new proof hint chunk should be started. The
Expand Down Expand Up @@ -172,6 +176,13 @@ class proof_event {
llvm::Value *proof_writer, std::string const &function_name,
llvm::BasicBlock *insert_at_end);

/*
* Emit a call to the `tail_call_info` API of the specified `proof_writer`.
*/
llvm::CallInst *emit_write_tail_call_info(
llvm::Value *proof_writer, std::string const &caller_name,
bool is_tail, llvm::BasicBlock *insert_at_end);

/*
* Emit a call to the `start_new_chunk` API of the specified `proof_writer`.
*/
Expand Down Expand Up @@ -228,6 +239,10 @@ class proof_event {
[[nodiscard]] llvm::BasicBlock *pattern_matching_failure(
kore_composite_pattern const &pattern, llvm::BasicBlock *current_block);

[[nodiscard]] llvm::BasicBlock *tail_call_info(
std::string const &caller_name, bool is_tail,
llvm::Instruction *insert_before, llvm::BasicBlock *current_block);

proof_event(kore_definition *definition, llvm::Module *module)
: definition_(definition)
, module_(module)
Expand Down
2 changes: 2 additions & 0 deletions include/runtime/header.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ void write_side_condition_event_post_to_proof_trace(
void *proof_writer, uint64_t ordinal, bool side_cond_result);
void write_pattern_matching_failure_to_proof_trace(
void *proof_writer, char const *function_name);
void write_tail_call_info_to_proof_trace(
void *proof_writer, char const *caller_name, bool is_tail);
void write_configuration_to_proof_trace(
void *proof_writer, block *config, bool is_initial);
void start_new_chunk_in_proof_trace(void *proof_writer);
Expand Down
7 changes: 7 additions & 0 deletions include/runtime/proof_trace_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class proof_trace_writer {
side_condition_event_post(uint64_t ordinal, bool side_cond_result)
= 0;
virtual void pattern_matching_failure(char const *function_name) = 0;
virtual void tail_call_info(char const *caller_name, bool is_tail) = 0;
virtual void configuration(block *config, bool is_initial) = 0;
virtual void start_new_chunk() = 0;
virtual void end_of_trace() = 0;
Expand Down Expand Up @@ -163,6 +164,12 @@ class proof_trace_file_writer : public proof_trace_writer {
write_null_terminated_string(function_name);
}

void tail_call_info(char const *caller_name, bool is_tail) override {
write_uint64(kllvm::tail_call_info_sentinel);
write_null_terminated_string(caller_name);
write_bool(is_tail);
}

void configuration(block *config, bool is_initial) override {
write_uint64(kllvm::config_sentinel);
serialize_configuration_to_proof_trace(file_, config, 0);
Expand Down
8 changes: 8 additions & 0 deletions lib/binary/ProofTraceParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ void llvm_pattern_matching_failure_event::print(
"{}pattern matching failure: {}\n", indent, function_name_);
}

void llvm_tail_call_info_event::print(
std::ostream &out, bool expand_terms, unsigned ind) const {
std::string indent(ind * indent_size, ' ');
out << fmt::format(
"{}tail_call_info: {} {}\n", indent, caller_name_,
(is_tail_ ? "tail" : "notail"));
}

void llvm_function_event::print(
std::ostream &out, bool expand_terms, unsigned ind) const {
std::string indent(ind * indent_size, ' ');
Expand Down
11 changes: 11 additions & 0 deletions lib/codegen/CreateTerm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,18 @@ bool make_function(
if (call->getCallingConv() == llvm::CallingConv::Tail
&& can_tail_call(call->getType())) {
call->setTailCallKind(llvm::CallInst::TCK_MustTail);
current_block =
proof_event(definition, module)
.tail_call_info(name, true, call, current_block);
} else {
current_block =
proof_event(definition, module)
.tail_call_info(name, false, nullptr, current_block);
}
} else {
current_block =
proof_event(definition, module)
.tail_call_info(name, false, nullptr, current_block);
}
}
auto *ret
Expand Down
1 change: 1 addition & 0 deletions lib/codegen/Decision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ void leaf_node::codegen(decision *d) {
d->current_block_
= proof_event(d->definition_, d->module_)
.rewrite_event_pre(axiom, arity, vars, subst, d->current_block_);
// maybe report here as part of the rule event whether a tail call happened

if (d->profile_matching_) {
llvm::CallInst::Create(
Expand Down
74 changes: 74 additions & 0 deletions lib/codegen/ProofEvent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,27 @@ llvm::CallInst *proof_event::emit_write_pattern_matching_failure(
return b.CreateCall(func, {proof_writer, var_function_name});
}

llvm::CallInst *proof_event::emit_write_tail_call_info(
llvm::Value *proof_writer, std::string const &caller_name,
bool is_tail, llvm::BasicBlock *insert_at_end) {
auto b = llvm::IRBuilder(insert_at_end);

auto *void_ty = llvm::Type::getVoidTy(ctx_);
auto *i8_ptr_ty = llvm::PointerType::getUnqual(ctx_);
auto *i8_ty = llvm::Type::getInt64Ty(ctx_);

auto *func_ty
= llvm::FunctionType::get(void_ty, {i8_ptr_ty, i8_ptr_ty, i8_ty}, false);

auto *func = get_or_insert_function(
module_, "write_tail_call_info_to_proof_trace", func_ty);

auto *var_caller_name
= b.CreateGlobalStringPtr(caller_name, "", 0, module_);
auto *var_is_tail = llvm::ConstantInt::get(i8_ty, is_tail);
return b.CreateCall(func, {proof_writer, var_caller_name, var_is_tail});
}

llvm::CallInst *proof_event::emit_start_new_chunk(
llvm::Value *proof_writer, llvm::BasicBlock *insert_at_end) {
auto b = llvm::IRBuilder(insert_at_end);
Expand Down Expand Up @@ -372,13 +393,41 @@ std::pair<llvm::BasicBlock *, llvm::BasicBlock *> proof_event::proof_branch(
return {true_block, merge_block};
}

std::pair<llvm::BasicBlock *, llvm::BasicBlock *> proof_event::proof_branch(
std::string const &label, llvm::Instruction *insert_before) {
auto *i1_ty = llvm::Type::getInt1Ty(ctx_);

auto *proof_output_flag = module_->getOrInsertGlobal("proof_output", i1_ty);
auto *proof_output = new llvm::LoadInst(
i1_ty, proof_output_flag, "proof_output", insert_before);

auto *f = insert_before->getParent()->getParent();
auto *true_block
= llvm::BasicBlock::Create(ctx_, fmt::format("if_{}", label), f);
auto *merge_block
= llvm::BasicBlock::Create(ctx_, fmt::format("tail_{}", label), f);

emit_no_op(merge_block);

llvm::BranchInst::Create(
true_block, merge_block, proof_output, insert_before);
return {true_block, merge_block};
}

std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
proof_event::event_prelude(
std::string const &label, llvm::BasicBlock *insert_at_end) {
auto [true_block, merge_block] = proof_branch(label, insert_at_end);
return {true_block, merge_block, emit_get_proof_trace_writer(true_block)};
}

std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
proof_event::event_prelude(
std::string const &label, llvm::Instruction *insert_before) {
auto [true_block, merge_block] = proof_branch(label, insert_before);
return {true_block, merge_block, emit_get_proof_trace_writer(true_block)};
}

llvm::BasicBlock *proof_event::check_for_emit_new_chunk(
llvm::BasicBlock *insert_at_end, llvm::BasicBlock *merge_block) {
auto *f = insert_at_end->getParent();
Expand Down Expand Up @@ -695,4 +744,29 @@ llvm::BasicBlock *proof_event::pattern_matching_failure(
return merge_block;
}

llvm::BasicBlock *proof_event::tail_call_info(
std::string const &caller_name, bool is_tail,
llvm::Instruction *insert_before, llvm::BasicBlock *current_block) {

if (!proof_hint_instrumentation) {
return current_block;
}

std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *> prelude;
if (is_tail) {
assert(insert_before);
prelude = event_prelude("tail_call_info", insert_before);
} else {
prelude = event_prelude("tail_call_info", current_block);
}

auto [true_block, merge_block, proof_writer] = prelude;

emit_write_tail_call_info(proof_writer, caller_name, is_tail, true_block);

llvm::BranchInst::Create(merge_block, true_block);

return merge_block;
}

} // namespace kllvm
6 changes: 6 additions & 0 deletions runtime/util/ConfigurationSerializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,12 @@ void write_pattern_matching_failure_to_proof_trace(
->pattern_matching_failure(function_name);
}

void write_tail_call_info_to_proof_trace(
void *proof_writer, char const *caller_name, bool is_tail) {
static_cast<proof_trace_writer *>(proof_writer)
->tail_call_info(caller_name, is_tail);
}

void write_configuration_to_proof_trace(
void *proof_writer, block *config, bool is_initial) {
static_cast<proof_trace_writer *>(proof_writer)
Expand Down

0 comments on commit 35e6e56

Please sign in to comment.