Skip to content

Commit

Permalink
BLAS: symv (#2088)
Browse files Browse the repository at this point in the history
* BLAS: symv

* fix

* tests

* Reverse test fix

* Forward tests

* fix
  • Loading branch information
wsmoses authored Sep 27, 2024
1 parent 662740f commit 5215862
Show file tree
Hide file tree
Showing 6 changed files with 606 additions and 4 deletions.
61 changes: 61 additions & 0 deletions enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,67 @@ def symm: CallBlasPattern<(Op $layout, $side, $uplo, $m, $n, $alpha, $A, $lda, $
)
>;



def syr2: CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $x, $incx, $y, $incy, $A, $lda),
["A"],
[cblas_layout, uplo, len, fp, vinc<["n"]>, vinc<["n"]>, mld<["uplo", "n", "n"]>],
[
/*alpha*/ (AssertingInactiveArg),
/*x*/ (AssertingInactiveArg),
/*y*/ (AssertingInactiveArg),
/*A*/ (AssertingInactiveArg)
]
>;


def symv: CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $A, $lda, $x, $incx, $beta, $y, $incy),
["y"],
[cblas_layout, uplo, len, fp, mld<["uplo", "n", "n"]>, vinc<["n"]>, fp, vinc<["n"]>],
[
/*alpha*/ (Seq<["Ax", "vector", "n"], [], 1>
(BlasCall<"symv"> $layout, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $n, $n), $x, Constant<"0.0">, use<"Ax">, ConstantInt<1>),
(BlasCall<"dot"> $n, (Shadow $y), use<"Ax">, ConstantInt<1>)
),
/*A*/ (Seq<["tmp", "vector", "n"], [], 1>
// Save the diagonal as we shouldn't add syr2 into it
(BlasCall<"copy"> $n, (First (Shadow $A)), (Add $lda, ConstantInt<1>), use<"tmp">, ConstantInt<1>),

(BlasCall<"syr2">
$layout,
$uplo,
$n,
$alpha,
$x,
(Shadow $y),
(Shadow $A)
),
(BlasCall<"copy"> $n, use<"tmp">, ConstantInt<1>, (First (Shadow $A)), (Add $lda, ConstantInt<1>))
),
/*x*/ (BlasCall<"symv"> $layout, $uplo, $n, $alpha, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $y), Constant<"1.0">, (Shadow $x)),
/*beta*/ (BlasCall<"dot"> $n, (Shadow $y), input<"y">),
/*y*/ (BlasCall<"scal"> $n, $beta, (Shadow $y))
],
// FWD: dy = dalpha A x + alpha dA x + alpha A dx + dbeta y + beta dy

(Seq<[], ["beta1"], 1>
// dbeta y
(BlasCall<"axpy"> $n, (Shadow $beta), $y, (Shadow $y)),

// alpha A dx (optional + beta dy)
(BlasCall<"symv"> $layout, $uplo, $n, $alpha, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $x), (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $y)),

// alpha dA x (optional + beta dy)
(BlasCall<"symv"> $layout, $uplo, $n, $alpha, (Shadow $A), $x, (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $y)),

// dalpha A x (optional + beta dy)
(BlasCall<"symv"> $layout, $uplo, $n, (Shadow $alpha), $A, (ld $A, Char<"N">, $lda, $n, $n), $x, (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $y)),

// (beta dy)
(FirstUse<"beta1"> (BlasCall<"scal"> $n, $beta, (Shadow $y)))
)
>;

def syr2k : CallBlasPattern<(Op $layout, $uplo, $trans, $n, $k, $alpha, $A, $lda, $B, $ldb, $beta, $C, $ldc),
["C"],
[cblas_layout, uplo, trans, len, len, fp, mld<["trans", "n", "k"]>, mld<["trans", "n", "k"]>, fp, mld<["n", "n"]>],
Expand Down
7 changes: 4 additions & 3 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2953,9 +2953,10 @@ llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in)
#endif
{
const char *extractable[] = {
"dot", "scal", "axpy", "gemv", "gemm", "spmv", "syrk",
"nrm2", "trmm", "trmv", "symm", "potrf", "potrs", "copy",
"spmv", "syr2k", "potrs", "getrf", "getrs", "trtrs", "getri"};
"dot", "scal", "axpy", "gemv", "gemm", "spmv", "syrk", "nrm2",
"trmm", "trmv", "symm", "potrf", "potrs", "copy", "spmv", "syr2k",
"potrs", "getrf", "getrs", "trtrs", "getri", "symv",
};
const char *floatType[] = {"s", "d", "c", "z"};
const char *prefixes[] = {"" /*Fortran*/, "cblas_"};
const char *suffixes[] = {"", "_", "64_", "_64_"};
Expand Down
141 changes: 141 additions & 0 deletions enzyme/test/Integration/ForwardMode/blas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ void my_dgemv(char layout, char trans, int M, int N, double alpha,
cblas_dgemv(layout, trans, M, N, alpha, A, lda, X, incx, beta, Y, incy);
}

void my_dsymv(char layout, char uplo, int N, double alpha,
double *__restrict__ A, int lda, double *__restrict__ X, int incx,
double beta, double *__restrict__ Y, int incy) {
cblas_dsymv(layout, uplo, N, alpha, A, lda, X, incx, beta, Y, incy);
}

double my_ddot(int N, double *__restrict__ X, int incx, double *__restrict__ Y,
int incy) {
double res = cblas_ddot(N, X, incx, Y, incy);
Expand Down Expand Up @@ -413,6 +419,139 @@ static void gemvTests() {
}
}


static void symvTests() {
// N means normal matrix, T means transposed
for (char layout : {CblasRowMajor, CblasColMajor}) {
for (auto uplo : {'U', 'u', 'L', 'l'}) {
{

std::string Test = "SYMV active A, C ";
BlasInfo inputs[6] = {/*A*/ BlasInfo(A, layout, N, N, lda),
/*B*/ BlasInfo(B, N, incB),
/*C*/ BlasInfo(C, N, incC),
BlasInfo(),
BlasInfo(),
BlasInfo()};
init();
my_dsymv(layout, uplo, N, alpha, A, lda, B, incB, beta, C,
incC);

assert(calls.size() == 1);
assert(calls[0].inDerivative == false);
assert(calls[0].type == CallType::SYMV);
assert(calls[0].pout_arg1 == C);
assert(calls[0].pin_arg1 == A);
assert(calls[0].pin_arg2 == B);
assert(calls[0].farg1 == alpha);
assert(calls[0].farg2 == beta);
assert(calls[0].layout == layout);
assert(calls[0].uplo == uplo);
assert(calls[0].targ1 == UNUSED_TRANS);
assert(calls[0].targ2 == UNUSED_TRANS);
assert(calls[0].iarg1 == N);
assert(calls[0].iarg3 == UNUSED_INT);
assert(calls[0].iarg4 == lda);
assert(calls[0].iarg5 == incB);
assert(calls[0].iarg6 == incC);

// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);

init();
__enzyme_fwddiff<void>(
(void *)my_dsymv, enzyme_const, layout, enzyme_const, uplo,
enzyme_const, N, enzyme_const, alpha, enzyme_dup,
A, dA, enzyme_const, lda, enzyme_const, B, enzyme_const, incB,
enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC);
foundCalls = calls;
init();

cblas_dsymv(layout, uplo, N, alpha, dA, lda, B, incB, beta,
dC, incC);

my_dsymv(layout, uplo, N, alpha, A, lda, B, incB, beta, C,
incC);

// cblas_dscal(trans ? N : M, beta, dC, incC);

checkTest(Test);

// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);

// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);

Test = "SYMV active A, B, C ";

init();
__enzyme_fwddiff<void>(
(void *)my_dsymv, enzyme_const, layout, enzyme_const, uplo,
enzyme_const, N, enzyme_const, alpha, enzyme_dup,
A, dA, enzyme_const, lda, enzyme_dup, B, dB, enzyme_const, incB,
enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC);
foundCalls = calls;
init();

cblas_dsymv(layout, uplo, N, alpha, A, lda, dB, incB, beta,
dC, incC);

cblas_dsymv(layout, uplo, N, alpha, dA, lda, B, incB, 1.0, dC, incC);

// cblas_dscal(trans ? N : M, beta, dC, incC);

my_dsymv(layout, uplo, N, alpha, A, lda, B, incB, beta, C,
incC);

// NOT ACTIVE: cblas_dgemv(layout, trans, M, N, dalpha, A, lda, B,
// incB, 1.0, C, incC);

checkTest(Test);

// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);

// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);

Test = "SYMV active B, C ";

init();
__enzyme_fwddiff<void>(
(void *)my_dsymv, enzyme_const, layout, enzyme_const, uplo,
enzyme_const, N, enzyme_const, alpha, enzyme_const,
A, enzyme_const, lda, enzyme_dup, B, dB, enzyme_const, incB,
enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC);
foundCalls = calls;
init();

cblas_dsymv(layout, uplo, N, alpha, A, lda, dB, incB, beta,
dC, incC);

// cblas_dscal(trans ? N : M, beta, dC, incC);

my_dsymv(layout, uplo, N, alpha, A, lda, B, incB, beta, C,
incC);

// NOT ACTIVE: cblas_dgemv(layout, trans, M, N, dalpha, A, lda, B,
// incB, 1.0, C, incC);

checkTest(Test);

// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);

// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
}
}
}

static void gemmTests() {
// N means normal matrix, T means transposed
for (char layout : {CblasRowMajor, CblasColMajor}) {
Expand Down Expand Up @@ -826,4 +965,6 @@ int main() {
syrkTests();

potrfTests();

symvTests();
}
Loading

0 comments on commit 5215862

Please sign in to comment.