-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
the first *simple* end to end test with logger
- Loading branch information
Showing
6 changed files
with
441 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
ENZYME_PATH ?= /home/brant/Enzyme/build/Enzyme/ClangEnzyme-15.so | ||
LLVM_PATH ?= /home/brant/llvms/llvm15/build/bin | ||
CXX = $(LLVM_PATH)/clang++ | ||
|
||
CXXFLAGS = -I/home/brant/include \ | ||
-L/home/brant/lib \ | ||
-I /usr/include/c++/11 \ | ||
-I /usr/include/x86_64-linux-gnu/c++/11 \ | ||
-L /usr/lib/gcc/x86_64-linux-gnu/11 \ | ||
-fno-exceptions \ | ||
-fpass-plugin=$(ENZYME_PATH) \ | ||
-Xclang -load -Xclang $(ENZYME_PATH) \ | ||
-lmpfr -O3 -ffast-math -fuse-ld=lld | ||
|
||
FPOPTFLAGS += -mllvm --enzyme-enable-fpopt \ | ||
-mllvm --enzyme-print-herbie \ | ||
-mllvm --enzyme-print-fpopt \ | ||
-mllvm --fpopt-log-path=enzyme.txt \ | ||
-mllvm --fpopt-target-func-regex=Pendulum | ||
|
||
SRC ?= example.c | ||
LOGGER ?= fp-logger.cpp | ||
EXE ?= example-logged.exe | ||
|
||
.PHONY: all clean | ||
|
||
all: $(EXE) | ||
|
||
example.cpp: $(SRC) | ||
python3 fpopt-original-driver-generator.py $(SRC) "Pendulum" | ||
|
||
example-logged.cpp: $(SRC) | ||
python3 fpopt-logged-driver-generator.py $(SRC) "Pendulum" | ||
|
||
example.exe: example.cpp | ||
$(CXX) -Wall -O3 example.cpp $(CXXFLAGS) -o $@ | ||
|
||
example-logged.exe: example-logged.cpp $(LOGGER) | ||
$(CXX) -Wall -O3 $(LOGGER) example-logged.cpp $(CXXFLAGS) -o $@ | ||
|
||
example-fpopt.exe: example.cpp | ||
$(CXX) example.cpp $(CXXFLAGS) $(FPOPTFLAGS) -o $@ | ||
|
||
clean: | ||
rm -f $(EXE) example-logged.cpp example.cpp | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#include <math.h> | ||
#include <stdint.h> | ||
#define TRUE 1 | ||
#define FALSE 0 | ||
|
||
// ## PRE t0: -2, 2 | ||
// ## PRE w0: -5, 5 | ||
// ## PRE N: 1000, 1000 | ||
double Pendulum(double t0, double w0, double N) { | ||
double h = 0.01; | ||
double L = 2.0; | ||
double m = 1.5; | ||
double g = 9.80665; | ||
double t = t0; | ||
double w = w0; | ||
double n = 0.0; | ||
int tmp = n < N; | ||
while (tmp) { | ||
double k1w = (-g / L) * sin(t); | ||
double k2t = w + ((h / 2.0) * k1w); | ||
double t_1 = t + (h * k2t); | ||
double k2w = (-g / L) * sin((t + ((h / 2.0) * w))); | ||
double w_2 = w + (h * k2w); | ||
double n_3 = n + 1.0; | ||
t = t_1; | ||
w = w_2; | ||
n = n_3; | ||
tmp = n < N; | ||
} | ||
return t; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
#include <cassert> | ||
#include <fstream> | ||
#include <iomanip> | ||
#include <iostream> | ||
#include <limits> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <vector> | ||
|
||
#include "fp-logger.hpp" | ||
|
||
class ValueInfo { | ||
public: | ||
double minRes = std::numeric_limits<double>::max(); | ||
double maxRes = std::numeric_limits<double>::lowest(); | ||
std::vector<double> minOperands; | ||
std::vector<double> maxOperands; | ||
unsigned executions = 0; | ||
|
||
void update(double res, const double *operands, unsigned numOperands) { | ||
minRes = std::min(minRes, res); | ||
maxRes = std::max(maxRes, res); | ||
if (minOperands.empty()) { | ||
minOperands.resize(numOperands, std::numeric_limits<double>::max()); | ||
maxOperands.resize(numOperands, std::numeric_limits<double>::lowest()); | ||
} | ||
for (unsigned i = 0; i < numOperands; ++i) { | ||
minOperands[i] = std::min(minOperands[i], operands[i]); | ||
maxOperands[i] = std::max(maxOperands[i], operands[i]); | ||
} | ||
++executions; | ||
} | ||
}; | ||
|
||
class ErrorInfo { | ||
public: | ||
double minErr = std::numeric_limits<double>::max(); | ||
double maxErr = std::numeric_limits<double>::lowest(); | ||
|
||
void update(double err) { | ||
minErr = std::min(minErr, err); | ||
maxErr = std::max(maxErr, err); | ||
} | ||
}; | ||
|
||
class GradInfo { | ||
public: | ||
double grad = 0.0; | ||
|
||
void update(double grad) { this->grad = grad; } | ||
}; | ||
|
||
class Logger { | ||
private: | ||
std::unordered_map<std::string, ValueInfo> valueInfo; | ||
std::unordered_map<std::string, ErrorInfo> errorInfo; | ||
std::unordered_map<std::string, GradInfo> gradInfo; | ||
|
||
public: | ||
void updateValue(const std::string &id, double res, unsigned numOperands, | ||
const double *operands) { | ||
auto &info = valueInfo.emplace(id, ValueInfo()).first->second; | ||
info.update(res, operands, numOperands); | ||
} | ||
|
||
void updateError(const std::string &id, double err) { | ||
auto &info = errorInfo.emplace(id, ErrorInfo()).first->second; | ||
info.update(err); | ||
} | ||
|
||
void updateGrad(const std::string &id, double grad) { | ||
auto &info = gradInfo.emplace(id, GradInfo()).first->second; | ||
info.update(grad); | ||
} | ||
|
||
void print() const { | ||
std::cout << std::scientific | ||
<< std::setprecision(std::numeric_limits<double>::max_digits10); | ||
|
||
for (const auto &pair : valueInfo) { | ||
const auto &id = pair.first; | ||
const auto &info = pair.second; | ||
std::cout << "Value:" << id << "\n"; | ||
std::cout << "\tMinRes = " << info.minRes << "\n"; | ||
std::cout << "\tMaxRes = " << info.maxRes << "\n"; | ||
std::cout << "\tExecutions = " << info.executions << "\n"; | ||
for (unsigned i = 0; i < info.minOperands.size(); ++i) { | ||
std::cout << "\tOperand[" << i << "] = [" << info.minOperands[i] << ", " | ||
<< info.maxOperands[i] << "]\n"; | ||
} | ||
} | ||
|
||
for (const auto &pair : errorInfo) { | ||
const auto &id = pair.first; | ||
const auto &info = pair.second; | ||
std::cout << "Error:" << id << "\n"; | ||
std::cout << "\tMinErr = " << info.minErr << "\n"; | ||
std::cout << "\tMaxErr = " << info.maxErr << "\n"; | ||
} | ||
|
||
for (const auto &pair : gradInfo) { | ||
const auto &id = pair.first; | ||
const auto &info = pair.second; | ||
std::cout << "Grad:" << id << "\n"; | ||
std::cout << "\tGrad = " << info.grad << "\n"; | ||
} | ||
} | ||
}; | ||
|
||
Logger *logger = nullptr; | ||
|
||
void initializeLogger() { logger = new Logger(); } | ||
|
||
void destroyLogger() { | ||
delete logger; | ||
logger = nullptr; | ||
} | ||
|
||
void printLogger() { logger->print(); } | ||
|
||
void enzymeLogError(const char *id, double err) { | ||
assert(logger && "Logger is not initialized"); | ||
logger->updateError(id, err); | ||
} | ||
|
||
void enzymeLogGrad(const char *id, double grad) { | ||
assert(logger && "Logger is not initialized"); | ||
logger->updateGrad(id, grad); | ||
} | ||
|
||
void enzymeLogValue(const char *id, double res, unsigned numOperands, | ||
double *operands) { | ||
assert(logger && "Logger is not initialized"); | ||
logger->updateValue(id, res, numOperands, operands); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
void initializeLogger(); | ||
void destroyLogger(); | ||
void printLogger(); | ||
|
||
void enzymeLogError(const char *id, double err); | ||
void enzymeLogGrad(const char *id, double grad); | ||
void enzymeLogValue(const char *id, double res, unsigned numOperands, | ||
double *operands); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import os | ||
import sys | ||
import re | ||
import random | ||
import numpy as np | ||
|
||
num_samples_per_func = 100 | ||
default_regex = "ex\\d+" | ||
|
||
|
||
def parse_bound(bound): | ||
if "/" in bound: | ||
numerator, denominator = map(float, bound.split("/")) | ||
return numerator / denominator | ||
return float(bound) | ||
|
||
|
||
def parse_c_file(filepath, func_regex): | ||
with open(filepath, "r") as file: | ||
content = file.read() | ||
|
||
pattern = re.compile(rf"(?s)(// ## PRE(?:.*?\n)+?)\s*([\w\s\*]+?)\s+({func_regex})\s*\(([^)]*)\)") | ||
|
||
matches = pattern.findall(content) | ||
|
||
if not matches: | ||
exit(f"No functions found with the regex: {func_regex}") | ||
|
||
functions = [] | ||
|
||
for comments, return_type, func_name, params in matches: | ||
param_comments = re.findall(r"// ## PRE (\w+):\s*([-+.\d/]+),\s*([-+.\d/]+)", comments) | ||
bounds = { | ||
name: { | ||
"min": parse_bound(min_val), | ||
"max": parse_bound(max_val), | ||
} | ||
for name, min_val, max_val in param_comments | ||
} | ||
params = [param.strip() for param in params.split(",") if param.strip()] | ||
functions.append((func_name, bounds, params, return_type.strip())) | ||
|
||
return functions | ||
|
||
|
||
def create_driver_function(functions): | ||
driver_code = ["int main() {"] | ||
driver_code.append(" initializeLogger();") | ||
driver_code.append(" volatile double res;") | ||
|
||
for func_name, bounds, params, return_type in functions: | ||
print(f"Generating driver code for {func_name}") | ||
for _ in range(num_samples_per_func): | ||
call_params = [] | ||
for param in params: | ||
param_tokens = param.strip().split() | ||
if len(param_tokens) >= 2: | ||
param_name = param_tokens[-1] | ||
else: | ||
exit(f"Cannot parse parameter: {param}") | ||
try: | ||
min_val = bounds[param_name]["min"] | ||
max_val = bounds[param_name]["max"] | ||
except KeyError: | ||
exit( | ||
f"WARNING: Bounds not found for {param_name} in function {func_name}, manually specify the bounds." | ||
) | ||
random_value = np.random.uniform(min_val, max_val) | ||
call_params.append(str(random_value)) | ||
driver_code.append(f" res = __enzyme_autodiff<{return_type}>((void *) {func_name}, {', '.join(call_params)});") | ||
|
||
driver_code.append(" printLogger();") | ||
driver_code.append(" destroyLogger();") | ||
driver_code.append(" return 0;") | ||
driver_code.append("}") | ||
return "\n".join(driver_code) | ||
|
||
|
||
def main(): | ||
if len(sys.argv) < 2: | ||
exit("Usage: script.py <filepath> [function_regex]") | ||
|
||
filepath = sys.argv[1] | ||
func_regex = sys.argv[2] if len(sys.argv) > 2 else default_regex | ||
|
||
if len(sys.argv) <= 2: | ||
print(f"WARNING: No regex provided for target function names. Using default regex: {default_regex}") | ||
|
||
functions = parse_c_file(filepath, func_regex) | ||
driver_code = create_driver_function(functions) | ||
new_filepath = os.path.splitext(filepath)[0] + "-logged.cpp" | ||
|
||
with open(filepath, "r") as original_file: | ||
original_content = original_file.read() | ||
|
||
code_to_insert = """#include "fp-logger.hpp" | ||
void thisIsNeverCalledAndJustForTheLinker() { | ||
enzymeLogError("", 0.0); | ||
enzymeLogGrad("", 0.0); | ||
enzymeLogValue("", 0.0, 2, nullptr); | ||
} | ||
int enzyme_dup; | ||
int enzyme_dupnoneed; | ||
int enzyme_out; | ||
int enzyme_const; | ||
template <typename return_type, typename... T> | ||
return_type __enzyme_autodiff(void *, T...);""" | ||
|
||
with open(new_filepath, "w") as new_file: | ||
new_file.write(original_content) | ||
new_file.write("\n\n" + code_to_insert + "\n\n" + driver_code) | ||
|
||
print(f"Driver function appended to the new file: {new_filepath}") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.