Skip to content

Commit

Permalink
saving progress
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed May 16, 2024
1 parent 13060e5 commit 3326cf3
Showing 1 changed file with 74 additions and 7 deletions.
81 changes: 74 additions & 7 deletions enzyme/Enzyme/Herbie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,91 @@ void runViaHerbie(const std::string &cmd) {
output.close();
}

std::string getHerbieOperator(const Instruction &I) {
switch (I.getOpcode()) {
case Instruction::FAdd:
return "+";
case Instruction::FSub:
return "-";
case Instruction::FMul:
return "*";
case Instruction::FDiv:
return "/";
default:
return "UnknownOp";
}
}

// Run (our choice of) floating point optimizations on function `F`.
// Return whether or not we change the function.
bool fpOptimize(llvm::Function &F) {
bool fpOptimize(Function &F) {
bool changed = false;
// 1) Identify subgraphs of the computation which can be entirely represented
// in herbie-style arithmetic
std::string herbieInput;
std::map<Value *, std::string> valueToSymbolMap;
std::map<std::string, Value *> symbolToValueMap;
std::set<std::string> arguments;
int symbolCounter = 0;

llvm::errs() << "Optimizing function " << F.getName().str() << "\n";
auto getNextSymbol = [&symbolCounter]() -> std::string {
return "v" + std::to_string(symbolCounter++);
};

// 1) Identify subgraphs of the computation which can be entirely represented
// in herbie-style arithmetic
// 2) Make the herbie FP-style expression by
// converting llvm instructions into herbie string (FPNode ....)
for (auto &BB : F) {
for (auto &I : BB) {
if (auto *op = dyn_cast<BinaryOperator>(&I)) {
if (op->getType()->isFloatingPointTy()) {
std::string lhs =
valueToSymbolMap.count(op->getOperand(0))
? valueToSymbolMap[op->getOperand(0)]
: (valueToSymbolMap[op->getOperand(0)] = getNextSymbol());
std::string rhs =
valueToSymbolMap.count(op->getOperand(1))
? valueToSymbolMap[op->getOperand(1)]
: (valueToSymbolMap[op->getOperand(1)] = getNextSymbol());

arguments.insert(lhs);
arguments.insert(rhs);

std::string symbol = getNextSymbol();
valueToSymbolMap[&I] = symbol;
symbolToValueMap[symbol] = &I;

std::string herbieNode = "(";
herbieNode += getHerbieOperator(I);
herbieNode += " ";
herbieNode += lhs;
herbieNode += " ";
herbieNode += rhs;
herbieNode += ")";
herbieInput += herbieNode;
}
}
}
}

// 3) run fancy opts
if (herbieInput.empty()) {
return changed;
}

// runViaHerbie()
std::string argumentsStr = "(";
for (const auto &arg : arguments) {
argumentsStr += arg + " ";
}
argumentsStr.pop_back();
argumentsStr += ")";

// 4) parse the output string solution from herbieland
herbieInput = "(FPCore " + argumentsStr + " " + herbieInput + ")";

llvm::errs() << "Herbie input:\n" << herbieInput << "\n";

// 3) run fancy opts
runViaHerbie(herbieInput);

// 4) parse the output string solution from herbieland
// 5) convert into a solution in llvm vals/instructions
return changed;
}
Expand Down

0 comments on commit 3326cf3

Please sign in to comment.