Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kschan: use eigen instead of sparse13 #2489

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ivoc/ocmatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ void OcFullMatrix::solv(Vect* in, Vect* out, bool use_lu) {
}
auto v1 = Vect2VEC(in);
auto v2 = Vect2VEC(out);
v2 = lu_->solve(v1);
v2 = lu_->solve(v1).eval();
}

double OcFullMatrix::det(int* e) {
Expand Down
137 changes: 53 additions & 84 deletions src/nrniv/kschan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ static KSChanList* channels;

extern char* hoc_symbol_units(Symbol*, const char*);
extern void nrn_mk_table_check();
extern spREAL* spGetElement(char*, int, int);

static Symbol* ksstate_sym;
static Symbol* ksgate_sym;
Expand Down Expand Up @@ -893,9 +892,6 @@ KSChan::KSChan(Object* obj, bool is_p) {
Sprintf(buf, "Chan%d", obj_->index);
name_ = buf;
ion_ = "NonSpecific";
mat_ = NULL;
elms_ = NULL;
diag_ = NULL;
gmax_deflt_ = 0.;
erev_deflt_ = 0.;
soffset_ = 4; // gmax, e, g, i before the first state in p array
Expand Down Expand Up @@ -1285,12 +1281,7 @@ void KSChan::free1() {
delete[] ligands_;
ligands_ = NULL;
}
if (mat_) {
spDestroy(mat_);
delete[] elms_;
delete[] diag_;
mat_ = NULL;
}
mat_.setZero();
ngate_ = 0;
nstate_ = 0;
ntrans_ = 0;
Expand Down Expand Up @@ -2122,118 +2113,96 @@ void KSChan::freesym(Symbol* s, Symbol* top) {
}

void KSChan::setupmat() {
int i, j, err;
// printf("KSChan::setupmat nksstate=%d\n", nksstate_);
if (mat_) {
spDestroy(mat_);
delete[] elms_;
delete[] diag_;
mat_ = NULL;
}
if (!nksstate_) {
return;
}
mat_ = spCreate(nksstate_, 0, &err);
if (err != spOKAY) {
hoc_execerror("Couldn't create sparse matrix", 0);
}
spFactor(mat_); // will fail but creates an internal vector needed by
// mulmat which might be called prior to initialization
// when switching to cvode active.
elms_ = new double*[4 * (ntrans_ - ivkstrans_)];
diag_ = new double*[nksstate_];
for (i = ivkstrans_, j = 0; i < ntrans_; ++i) {
int s, t;
s = trans_[i].src_ - nhhstate_ + 1;
t = trans_[i].target_ - nhhstate_ + 1;
elms_[j++] = spGetElement(mat_, s, s);
elms_[j++] = spGetElement(mat_, s, t);
elms_[j++] = spGetElement(mat_, t, t);
elms_[j++] = spGetElement(mat_, t, s);
}
for (i = 0; i < nksstate_; ++i) {
diag_[i] = spGetElement(mat_, i + 1, i + 1);
}
mat_.resize(nksstate_, nksstate_);
}

void KSChan::fillmat(double v, Datum* pd) {
int i, j;
double a, b;
spClear(mat_);
for (i = ivkstrans_, j = 0; i < iligtrans_; ++i) {
trans_[i].ab(v, a, b);
mat_.setZero();
int j = 0;
for (int i = ivkstrans_; i < iligtrans_; ++i, ++j) {
auto& trans = trans_[i];
double a, b;
trans.ab(v, a, b);
// printf("trans %d v=%g a=%g b=%g\n", i, v, a, b);
*elms_[j++] -= a;
*elms_[j++] += b;
*elms_[j++] -= b;
*elms_[j++] += a;
}
for (i = iligtrans_; i < ntrans_; ++i) {
a = trans_[i].alpha(pd);
b = trans_[i].beta();
*elms_[j++] -= a;
*elms_[j++] += b;
*elms_[j++] -= b;
*elms_[j++] += a;
auto s = trans.src_ - nhhstate_;
auto t = trans.target_ - nhhstate_;
mat_.coeffRef(s, s) -= a;
mat_.coeffRef(s, t) += b;
mat_.coeffRef(t, t) -= b;
mat_.coeffRef(t, s) += a;
}
for (int i = iligtrans_; i < ntrans_; ++i, ++j) {
auto& trans = trans_[i];
auto s = trans.src_ - nhhstate_;
auto t = trans.target_ - nhhstate_;
double a = trans.alpha(pd);
double b = trans.beta();
mat_.coeffRef(s, s) -= a;
mat_.coeffRef(s, t) += b;
mat_.coeffRef(t, t) -= b;
mat_.coeffRef(t, s) += a;
}
// printf("after fill\n");
// spPrint(mat_, 0, 1, 0);
}

void KSChan::mat_dt(double dt, Memb_list* ml, std::size_t instance, std::size_t offset) {
// y' = m*y this part add the dt for the form ynew/dt - yold/dt =m*ynew
// the matrix ends up as (m-1/dt)ynew = -1/dt*yold
int i;
double dt1 = -1. / dt;
for (int i = 0; i < nksstate_; ++i) {
*(diag_[i]) += dt1;
mat_.coeffRef(i, i) += dt1;
ml->data(instance, offset + i) *= dt1;
}
}

void KSChan::solvemat(Memb_list* ml, std::size_t instance, std::size_t offset) {
// spSolve seems to require that the parameters are contiguous, which
// they're not anymore in the real NEURON data structure
std::vector<double> s(nksstate_ + 1); // +1 so the pointer arithmetic to account for 1-based
// indexing is valid
Eigen::VectorXd s(nksstate_);
for (auto j = 0; j < nksstate_; ++j) {
s[j + 1] = ml->data(instance, offset + j);
s[j] = ml->data(instance, offset + j);
}
auto const e = spFactor(mat_);
if (e != spOKAY) {
switch (e) {
case spZERO_DIAG:
hoc_execerror("spFactor error:", "Zero Diagonal");
case spNO_MEMORY:
hoc_execerror("spFactor error:", "No Memory");
case spSINGULAR:
hoc_execerror("spFactor error:", "Singular");
}
mat_.makeCompressed();
lu_.compute(mat_);
switch (lu_.info()) {
case Eigen::Success:
// Everything fine, at least no warning
break;
case Eigen::NumericalIssue:
hoc_execerror(
"NumericalIssue: The matrix is not valid following what expect Eigen SparseLU",
nullptr);
break;
case Eigen::NoConvergence:
hoc_execerror("NoConvergence: The matrix did not converge", nullptr);
break;
case Eigen::InvalidInput:
hoc_execerror("InvalidInput: the inputs are invliad", nullptr);
break;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a default clause might be nice in case it's not exhaustive.

}
spSolve(mat_, s.data(), s.data());
s = lu_.solve(s);

// Propgate the solution back to the mechanism data
for (auto j = 0; j < nksstate_; ++j) {
ml->data(instance, offset + j) = s[j + 1];
ml->data(instance, offset + j) = s[j];
}
}

void KSChan::mulmat(Memb_list* ml,
std::size_t instance,
std::size_t offset_s,
std::size_t offset_ds) {
std::vector<double> s, ds;
s.resize(nksstate_ + 1); // +1 so the pointer arithmetic to account for 1-based indexing is
// valid
ds.resize(nksstate_ + 1);
Eigen::VectorXd s(nksstate_);
Eigen::VectorXd ds(nksstate_);
for (auto j = 0; j < nksstate_; ++j) {
s[j + 1] = ml->data(instance, offset_s + j);
ds[j + 1] = ml->data(instance, offset_ds + j);
s[j] = ml->data(instance, offset_s + j);
}
spMultiply(mat_, ds.data(), s.data());
ds = mat_ * s;
// Propagate the results
for (auto j = 0; j < nksstate_; ++j) {
ml->data(instance, offset_s + j) = s[j + 1];
ml->data(instance, offset_ds + j) = ds[j + 1];
ml->data(instance, offset_ds + j) = ds[j];
}
}

Expand Down
5 changes: 3 additions & 2 deletions src/nrniv/kschan.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "ivocvect.h"
#include "nrnunits.h"

#include "spmatrix.h"
#include <Eigen/Sparse>

// extern double dt;
extern double celsius;
Expand Down Expand Up @@ -468,7 +468,8 @@ class KSChan {
int cvode_ieq_;
Symbol* mechsym_; // the top level symbol (insert sym or new sym)
Symbol* rlsym_; // symbol with the range list (= mechsym_ when density)
char* mat_;
Eigen::SparseMatrix<double> mat_{};
Eigen::SparseLU<Eigen::SparseMatrix<double>> lu_{};
double** elms_;
double** diag_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these two still needed or can the be removed?

int dsize_; // size of prop->dparam
Expand Down
2 changes: 1 addition & 1 deletion test/hoctests/tests/test_kschan.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def test_2():
h.cvode_active(1)
# At least executes KSChan::mulmat
hrun(
"kchan without single cvode=True", t_tol=2e-7, v_tol=1e-11, v_tol_per_time=5e-7
"kchan without single cvode=True", t_tol=4e-7, v_tol=1e-11, v_tol_per_time=5e-7
)
h.cvode_active(0)

Expand Down
Loading