Skip to content

Commit

Permalink
Merge branch 'ChASE-v1.3-newIO' into 'master'
Browse files Browse the repository at this point in the history
Update the estimation of bounds for QR

See merge request SLai/ChASE!28
  • Loading branch information
brunowu committed Apr 4, 2023
2 parents 670dad3 + ee7b8e3 commit 2f0babf
Show file tree
Hide file tree
Showing 8 changed files with 462 additions and 16 deletions.
22 changes: 22 additions & 0 deletions ChASE-MPI/blas_fortran.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,28 @@ extern "C"
const BlasInt* lda, const dcomplex* b,
const BlasInt* ldb);

void FC_GLOBAL(sgesvd, SGESVD)(const char *jobu, const char *jobvt,
const BlasInt* m, const BlasInt* n,
float *A, const BlasInt* lda, float *S,
float *U, const BlasInt *ldu, float *Vt,
const BlasInt *ldvt, float *work,
const BlasInt *lwork, float *rwork, BlasInt *info );
void FC_GLOBAL(dgesvd, DGESVD)(const char *jobu, const char *jobvt,
const BlasInt* m, const BlasInt* n,
double *A, const BlasInt* lda, double *S,
double *U, const BlasInt *ldu, double *Vt,
const BlasInt *ldvt, double *work,
const BlasInt *lwork, double *rwork, BlasInt *info );
void FC_GLOBAL(cgesvd, CGESVD)(const char *jobu, const char *jobvt, const BlasInt* m,
const BlasInt* n, scomplex *A, const BlasInt* lda,
float *S, scomplex *U, const BlasInt *ldu, scomplex *Vt,
const BlasInt *ldvt, scomplex *work, const BlasInt *lwork,
float *rwork, BlasInt *info );
void FC_GLOBAL(zgesvd, ZGESVD)(const char *jobu, const char *jobvt, const BlasInt* m,
const BlasInt* n, dcomplex *A, const BlasInt* lda, double *S,
dcomplex *U, const BlasInt *ldu, dcomplex *Vt,
const BlasInt *ldvt, dcomplex *work, const BlasInt *lwork,
double *rwork, BlasInt *info );
} // extern "C"
} // namespace mpi
} // namespace chase
4 changes: 4 additions & 0 deletions ChASE-MPI/blas_templates.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ void t_trsm(const char side, const char uplo, const char trans, const char diag,
const T* a, const std::size_t lda, const T* b,
const std::size_t ldb);

template<typename T>
void t_gesvd(const char jobu, const char jobvt, const std::size_t m, const std::size_t n,
T *A, const std::size_t lda, Base<T> *S, T *U, const std::size_t ldu, T *Vt,
const std::size_t ldvt);
// scalapack
// BLACS
void t_descinit(std::size_t* desc, std::size_t* m, std::size_t* n,
Expand Down
117 changes: 117 additions & 0 deletions ChASE-MPI/blas_templates.inc
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,123 @@ void t_trsm(const char side, const char uplo, const char trans, const char diag,
(&side, &uplo, &trans, &diag, &m_, &n_, alpha, a, &lda_, b, &ldb_);
}

template<>
void t_gesvd(const char jobu, const char jobvt, const std::size_t m, const std::size_t n, float *A,
const std::size_t lda, float *S, float *U, const std::size_t ldu, float *Vt, const std::size_t ldvt){
using T = std::remove_reference<decltype((A[0]))>::type;
BlasInt m_ = m;
BlasInt n_ = n;
BlasInt lda_ = lda;
BlasInt ldu_ = ldu;
BlasInt ldvt_ = ldvt;

T* work;
Base<T> *rwork = new Base<T>[5 * std::min(m, n)];
T numwork;
BlasInt lwork, info;

lwork = -1;
FC_GLOBAL(sgesvd, SGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, &numwork, &lwork, rwork, &info);
assert(info == 0);


lwork = static_cast<std::size_t>((numwork));
auto ptr = std::unique_ptr<T[]>{new T[lwork]};
work = ptr.get();

FC_GLOBAL(sgesvd, SGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, work, &lwork, rwork, &info);
assert(info == 0);
}

template<>
void t_gesvd(const char jobu, const char jobvt, const std::size_t m, const std::size_t n, double *A,
const std::size_t lda, double *S, double *U, const std::size_t ldu, double *Vt, const std::size_t ldvt){
using T = std::remove_reference<decltype((A[0]))>::type;
BlasInt m_ = m;
BlasInt n_ = n;
BlasInt lda_ = lda;
BlasInt ldu_ = ldu;
BlasInt ldvt_ = ldvt;

T* work;
Base<T> *rwork = new Base<T>[5 * std::min(m, n)];
T numwork;
BlasInt lwork, info;

lwork = -1;
FC_GLOBAL(dgesvd, DGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, &numwork, &lwork, rwork, &info);
assert(info == 0);


lwork = static_cast<std::size_t>((numwork));
auto ptr = std::unique_ptr<T[]>{new T[lwork]};
work = ptr.get();

FC_GLOBAL(dgesvd, DGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, work, &lwork, rwork, &info);
assert(info == 0);

}

template<>
void t_gesvd(const char jobu, const char jobvt, const std::size_t m, const std::size_t n, std::complex<float> *A,
const std::size_t lda, float *S, std::complex<float> *U, const std::size_t ldu, std::complex<float> *Vt, const std::size_t ldvt){
using T = std::remove_reference<decltype((A[0]))>::type;
BlasInt m_ = m;
BlasInt n_ = n;
BlasInt lda_ = lda;
BlasInt ldu_ = ldu;
BlasInt ldvt_ = ldvt;

T* work;
Base<T> *rwork = new Base<T>[5 * std::min(m, n)];
T numwork;
BlasInt lwork, info;

lwork = -1;
FC_GLOBAL(cgesvd, CGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, &numwork, &lwork, rwork, &info);
assert(info == 0);


lwork = static_cast<std::size_t>(real(numwork));
auto ptr = std::unique_ptr<T[]>{new T[lwork]};
work = ptr.get();

FC_GLOBAL(cgesvd, CGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, work, &lwork, rwork, &info);
assert(info == 0);
}

template<>
void t_gesvd(const char jobu, const char jobvt, const std::size_t m, const std::size_t n, std::complex<double> *A,
const std::size_t lda, double *S, std::complex<double> *U, const std::size_t ldu, std::complex<double> *Vt, const std::size_t ldvt){
using T = std::remove_reference<decltype((A[0]))>::type;
BlasInt m_ = m;
BlasInt n_ = n;
BlasInt lda_ = lda;
BlasInt ldu_ = ldu;
BlasInt ldvt_ = ldvt;

T* work;
Base<T> *rwork = new Base<T>[5 * std::min(m, n)];
T numwork;
BlasInt lwork, info;

lwork = -1;
FC_GLOBAL(zgesvd, ZGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, &numwork, &lwork, rwork, &info);
assert(info == 0);


lwork = static_cast<std::size_t>(real(numwork));
auto ptr = std::unique_ptr<T[]>{new T[lwork]};
work = ptr.get();

FC_GLOBAL(zgesvd, ZGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, work, &lwork, rwork, &info);
assert(info == 0);

delete[] rwork;

}


#if defined(HAS_SCALAPACK)
// SCALAPACK
void t_descinit(std::size_t* desc, std::size_t* m, std::size_t* n,
Expand Down
Loading

0 comments on commit 2f0babf

Please sign in to comment.