Skip to content

Commit

Permalink
improve lstm output, makefiles
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Nov 15, 2024
1 parent 9f67bd1 commit 88feab6
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 74 deletions.
108 changes: 61 additions & 47 deletions enzyme/benchmarks/ReverseMode/adbench/gmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ int main(const int argc, const char* argv[]) {
//}
}

for (size_t i = 0; i < 5; i++)
{

struct GMMInput input;
Expand Down Expand Up @@ -349,6 +350,65 @@ int main(const int argc, const char* argv[]) {
test_suite["tools"].push_back(enzyme);
}
}

{

struct GMMInput input;
read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
input.alphas, input.means, input.icf, input.x,
input.wishart, params.replicate_point);

size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;

struct GMMOutput result = {0, std::vector<double>(Jcols)};
{
struct timeval start, end;
gettimeofday(&start, NULL);
calculate_jacobian<rust_unsafe_dgmm_objective>(input, result);
gettimeofday(&end, NULL);
printf("Enzyme unsafe rust combined %0.6f\n", tdiff(&start, &end));
json enzyme;
enzyme["name"] = "Rust unsafe Enzyme combined";
enzyme["runtime"] = tdiff(&start, &end);
for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
i++) {
printf("%f ", result.gradient[i]);
enzyme["result"].push_back(result.gradient[i]);
}
printf("\n");
test_suite["tools"].push_back(enzyme);
}
}

for (size_t i = 0; i < 5; i++)
{

struct GMMInput input;
read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
input.alphas, input.means, input.icf, input.x,
input.wishart, params.replicate_point);

size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;

struct GMMOutput result = {0, std::vector<double>(Jcols)};
{
struct timeval start, end;
gettimeofday(&start, NULL);
calculate_jacobian<rust_dgmm_objective>(input, result);
gettimeofday(&end, NULL);
printf("Enzyme rust combined %0.6f\n", tdiff(&start, &end));
json enzyme;
enzyme["name"] = "Rust Enzyme combined";
enzyme["runtime"] = tdiff(&start, &end);
for (unsigned i = result.gradient.size() - 5;
i < result.gradient.size(); i++) {
printf("%f ", result.gradient[i]);
enzyme["result"].push_back(result.gradient[i]);
}
printf("\n");
test_suite["tools"].push_back(enzyme);
}
}

{

Expand Down Expand Up @@ -401,36 +461,6 @@ int main(const int argc, const char* argv[]) {
primal["result"].push_back(res);
test_suite["tools"].push_back(primal);
}
{
struct timeval start, end;
gettimeofday(&start, NULL);
calculate_jacobian<rust_unsafe_dgmm_objective>(input, result);
gettimeofday(&end, NULL);
printf("Enzyme unsafe rust combined %0.6f\n", tdiff(&start, &end));
json enzyme;
enzyme["name"] = "Rust unsafe Enzyme combined";
enzyme["runtime"] = tdiff(&start, &end);
for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
i++) {
printf("%f ", result.gradient[i]);
enzyme["result"].push_back(result.gradient[i]);
}
printf("\n");
test_suite["tools"].push_back(enzyme);
}
}

{

struct GMMInput input;
read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
input.alphas, input.means, input.icf, input.x,
input.wishart, params.replicate_point);

size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;

struct GMMOutput result = {0, std::vector<double>(Jcols)};

{
struct timeval start, end;
gettimeofday(&start, NULL);
Expand All @@ -443,24 +473,8 @@ int main(const int argc, const char* argv[]) {
primal["result"].push_back(res);
test_suite["tools"].push_back(primal);
}
{
struct timeval start, end;
gettimeofday(&start, NULL);
calculate_jacobian<rust_dgmm_objective>(input, result);
gettimeofday(&end, NULL);
printf("Enzyme rust combined %0.6f\n", tdiff(&start, &end));
json enzyme;
enzyme["name"] = "Rust Enzyme combined";
enzyme["runtime"] = tdiff(&start, &end);
for (unsigned i = result.gradient.size() - 5;
i < result.gradient.size(); i++) {
printf("%f ", result.gradient[i]);
enzyme["result"].push_back(result.gradient[i]);
}
printf("\n");
test_suite["tools"].push_back(enzyme);
}
}

test_suite["llvm-version"] = __clang_version__;
test_suite["mode"] = "ReverseMode";
test_suite["batch-size"] = 1;
Expand Down
36 changes: 12 additions & 24 deletions enzyme/benchmarks/ReverseMode/adbench/lstm.h
Original file line number Diff line number Diff line change
Expand Up @@ -474,17 +474,14 @@ int main(const int argc, const char* argv[]) {
{
struct timeval start, end;
gettimeofday(&start, NULL);
calculate_mayalias_primal(input);
double res = calculate_mayalias_primal(input);
gettimeofday(&end, NULL);
printf("C++ mayalias primal %0.6f\n", tdiff(&start, &end));
json enzyme;
enzyme["name"] = "C++ mayalias primal";
enzyme["runtime"] = tdiff(&start, &end);
for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
i++) {
printf("%f ", result.gradient[i]);
enzyme["result"].push_back(result.gradient[i]);
}
printf("%f ", res);
enzyme["result"].push_back(res);
test_suite["tools"].push_back(enzyme);

printf("\n");
Expand All @@ -507,17 +504,14 @@ int main(const int argc, const char* argv[]) {
{
struct timeval start, end;
gettimeofday(&start, NULL);
calculate_restrict_primal(input);
double res = calculate_restrict_primal(input);
gettimeofday(&end, NULL);
printf("C++ restrict primal %0.6f\n", tdiff(&start, &end));
json enzyme;
enzyme["name"] = "C++ restrict primal";
enzyme["runtime"] = tdiff(&start, &end);
for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
i++) {
printf("%f ", result.gradient[i]);
enzyme["result"].push_back(result.gradient[i]);
}
printf("%f ", res);
enzyme["result"].push_back(res);
test_suite["tools"].push_back(enzyme);

printf("\n");
Expand All @@ -540,17 +534,14 @@ int main(const int argc, const char* argv[]) {
{
struct timeval start, end;
gettimeofday(&start, NULL);
calculate_unsafe_primal(input);
double res =calculate_unsafe_primal(input);
gettimeofday(&end, NULL);
printf("Enzyme (unsafe Rust) primal %0.6f\n", tdiff(&start, &end));
json enzyme;
enzyme["name"] = "Enzyme (unsafe Rust) primal";
enzyme["runtime"] = tdiff(&start, &end);
for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
i++) {
printf("%f ", result.gradient[i]);
enzyme["result"].push_back(result.gradient[i]);
}
printf("%f ", res);
enzyme["result"].push_back(res);
test_suite["tools"].push_back(enzyme);

printf("\n");
Expand All @@ -573,17 +564,14 @@ int main(const int argc, const char* argv[]) {
{
struct timeval start, end;
gettimeofday(&start, NULL);
calculate_safe_primal(input);
double res = calculate_safe_primal(input);
gettimeofday(&end, NULL);
printf("Enzyme (safe Rust) primal %0.6f\n", tdiff(&start, &end));
json enzyme;
enzyme["name"] = "Enzyme (safe Rust) primal";
enzyme["runtime"] = tdiff(&start, &end);
for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
i++) {
printf("%f ", result.gradient[i]);
enzyme["result"].push_back(result.gradient[i]);
}
printf("%f ", res);
enzyme["result"].push_back(res);
test_suite["tools"].push_back(enzyme);

printf("\n");
Expand Down
2 changes: 1 addition & 1 deletion enzyme/benchmarks/ReverseMode/gmm/Makefile.make
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ gmm.o: gmm.cpp $(dir)/benchmarks/ReverseMode/gmm/target/release/libgmmrs.a
clang++ $(LOADCLANG) $(BENCH) -O3 -fno-math-errno $^ $(BENCHLINK) -lm -o $@

results.json: gmm.o
./$^
numactl -C 1 ./$^
4 changes: 2 additions & 2 deletions enzyme/benchmarks/ReverseMode/lstm/Makefile.make
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ clean:
cargo +enzyme clean

$(dir)/benchmarks/ReverseMode/lstm/target/release/liblstm.a: src/lib.rs Cargo.toml
RUSTFLAGS="-Z mutable-noalias=no" cargo +enzyme rustc --release --lib --crate-type=staticlib
RUSTFLAGS="-Z mutable-noalias=yes" cargo +enzyme rustc --release --lib --crate-type=staticlib

lstm.o: lstm.cpp $(dir)/benchmarks/ReverseMode/lstm/target/release/liblstm.a
clang++ $(LOADCLANG) $(BENCH) -O3 -fno-math-errno $^ $(BENCHLINK) -lm -o $@

results.json: lstm.o
./$^
numactl -C 1 ./$^

0 comments on commit 88feab6

Please sign in to comment.