Skip to content

Commit

Permalink
fix: MPI problems
Browse files Browse the repository at this point in the history
  • Loading branch information
HPDell committed Aug 2, 2024
1 parent 8117316 commit 2440424
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 15 deletions.
49 changes: 34 additions & 15 deletions src/gwmodelpp/GWRBasic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,7 @@ void gwm::GWRBasic::fTestBase()
for (uword i = 0; i < mX.n_cols; i++)
{
vec diagB = (this->*mCalcDiagBFunction)(i);
double g1 = diagB(0);
double g2 = diagB(1);
double g1 = diagB(0), g2 = diagB(1);
double numdf = g1 * g1 / g2;
FTestResult f3i;
f3i.s = (vk2(i) / g1) / sigma21;
Expand Down Expand Up @@ -564,7 +563,9 @@ arma::vec GWRBasic::calcDiagBBase(arma::uword i)
return { DBL_MAX, DBL_MAX };
}
}
return (this->*mCalcDiagBCoreFunction)(i, c);
vec diagB = (this->*mCalcDiagBCoreFunction)(i, c);
diagB = 1.0 / nDp * diagB;
return { sum(diagB), sum(diagB % diagB) };
}

vec GWRBasic::calcDiagBCoreSerial(uword i, const vec& c)
Expand All @@ -584,11 +585,11 @@ vec GWRBasic::calcDiagBCoreSerial(uword i, const vec& c)
}
catch (const std::exception& e)
{
return { DBL_MAX, DBL_MAX };
diagB.fill(DBL_MAX);
return diagB;
}
}
diagB = 1.0 / nDp * diagB;
return { sum(diagB), sum(diagB % diagB) };
return diagB;
}

#ifdef ENABLE_OPENMP
Expand Down Expand Up @@ -808,20 +809,21 @@ double GWRBasic::calcTrQtQCoreOmp()
vec GWRBasic::calcDiagBCoreOmp(uword i, const vec& c)
{
arma::uword nDp = mX.n_rows;
vec diagB(nDp, fill::zeros);
mat diagB_all(nDp, mOmpThreadNum, fill::zeros);
std::pair<uword, uword> workRange = mWorkRange.value_or(make_pair(0, nDp));
int flag = true;
#pragma omp parallel for num_threads(mOmpThreadNum)
for (arma::uword k = workRange.first; k < workRange.second; k++)
{
if (flag) {
int thread = omp_get_thread_num();
vec w = mSpatialWeight.weightVector(k);
mat xtw = trans(mX.each_col() % w);
try
{
mat C = trans(xtw) * inv_sympd(xtw * mX);
vec b = C.col(i);
diagB += (b % b - (1.0 / nDp) * (b % c));
diagB_all.col(thread) += (b % b - (1.0 / nDp) * (b % c));
}
catch (const std::exception& e)
{
Expand All @@ -831,10 +833,10 @@ vec GWRBasic::calcDiagBCoreOmp(uword i, const vec& c)
}
if (!flag)
{
return { DBL_MAX, DBL_MAX };
diagB_all.fill(DBL_MAX);
}
diagB = 1.0 / nDp * diagB;
return { sum(diagB), sum(diagB % diagB) };
vec diagB = sum(diagB_all, 1);
return diagB;
}
#endif

Expand Down Expand Up @@ -1297,10 +1299,27 @@ double GWRBasic::calcTrQtQMpi()

// MPI variant of the diagB computation for coefficient column `i`.
// Returns a 2-element vector {sum(diagB), sum(diagB % diagB)}, or {DBL_MAX, DBL_MAX}
// when building `c` throws.
//
// NOTE(review): this span comes from a diff view that carries no +/- markers. The hunk
// header (-1297,10 +1299,27) implies that the first four body statements below are the
// pre-commit implementation and that everything from `arma::uword nDp` onwards replaces
// them; the two versions are not meant to coexist in one function body.
vec GWRBasic::calcDiagBMpi(uword i)
{
// --- pre-commit body: reduces two doubles from calcDiagBBase(i) into a default-constructed `diagB` ---
vec diagBi = calcDiagBBase(i);
vec diagB;
MPI_Allreduce(diagBi.memptr(), diagB.memptr(), 2, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
return diagB;
// --- post-commit body ---
arma::uword nDp = mX.n_rows;
// `c` accumulates column `i` of C over every row index j in [0, nDp); the loop is not
// restricted to a work range, so each process that runs it visits all j.
vec c(nDp, fill::zeros);
for (arma::uword j = 0; j < nDp; j++)
{
vec w = mSpatialWeight.weightVector(j);
// xtw = (X with each column multiplied element-wise by w), transposed.
mat xtw = trans(mX.each_col() % w);
try
{
// C = xtw^T * (xtw * X)^{-1}; inv_sympd is the call guarded by the try block.
mat C = trans(xtw) * inv_sympd(xtw * mX);
c += C.col(i);
}
catch (const std::exception& e)
{
// Any std::exception here aborts the whole computation with a sentinel pair.
return { DBL_MAX, DBL_MAX };
}
}
// Per-process vector from the configured core implementation.
// NOTE(review): presumably this covers only this process's share of rows (the OpenMP core
// shown in this diff iterates over mWorkRange) -- confirm against the serial core.
vec diagBi = (this->*mCalcDiagBCoreFunction)(i, c);
// Receive buffer is sized to nDp before the element-wise MPI_SUM over MPI_COMM_WORLD.
vec diagB(nDp, arma::fill::zeros);
MPI_Allreduce(diagBi.memptr(), diagB.memptr(), int(nDp), MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
// Scaling by 1/nDp is applied once, after the reduction.
diagB = 1.0 / double(nDp) * diagB;
return { sum(diagB), sum(diagB % diagB) };
}
#endif // ENABLE_MPI

Expand Down
78 changes: 78 additions & 0 deletions test/mpi/testGWRBasicMpi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ using namespace std;
using namespace arma;
using namespace gwm;

/// Flattens an F-test result into a fixed-size array so individual fields can be
/// checked by index.
///
/// @param f  F-test result to unpack.
/// @return   The fields of `f` in the order { s, df1, df2, p }.
array<double, 4> convFTestArray(GWRBasic::FTestResult f)
{
    array<double, 4> flattened{};
    flattened[0] = f.s;
    flattened[1] = f.df1;
    flattened[2] = f.df2;
    flattened[3] = f.p;
    return flattened;
}

TEST_CASE("BasicGWR: LondonHP")
{
int iProcess, nProcess;
Expand Down Expand Up @@ -303,6 +308,79 @@ TEST_CASE("BasicGWR: LondonHP")
}
}

// F-test regression check: fixed adaptive Gaussian bandwidth (36), no bandwidth or
// variable selection. The section is generated once per entry of `parallel_list`;
// every process runs fit(), but only process 0 inspects the F-test results.
SECTION("F test | adaptive bandwidth | no bandwidth optimization | no variable optimization") {
    auto parallel = GENERATE_REF(values(parallel_list));
    INFO("Parallel:" << ParallelTypeDict.at(parallel));

    CRSDistance distance(false);
    BandwidthWeight bandwidth(36, true, BandwidthWeight::Gaussian);
    SpatialWeight spatial(&bandwidth, &distance);

    GWRBasic algorithm;
    algorithm.setCoords(londonhp100_coord);
    algorithm.setDependentVariable(y);
    algorithm.setIndependentVariables(x);
    algorithm.setSpatialWeight(spatial);
    algorithm.setParallelType(parallel);
    algorithm.setWorkerId(iProcess);
    algorithm.setWorkerNum(nProcess);
#ifdef ENABLE_OPENMP
    if (parallel == ParallelType::OpenMP)
    {
        // omp_get_num_threads() returns 1 when called outside a parallel region, which
        // would pin the OpenMP path to a single thread and leave its per-thread
        // accumulation untested. omp_get_max_threads() reports the team size a
        // subsequent parallel region would actually get.
        algorithm.setOmpThreadNum(omp_get_max_threads());
    }
#endif // ENABLE_OPENMP
#ifdef ENABLE_CUDA
    if (parallel == ParallelType::CUDA)
    {
        algorithm.setGPUId(0);
        algorithm.setGroupSize(64);
    }
#endif // ENABLE_CUDA
    algorithm.setIsDoFtest(true);
    REQUIRE_NOTHROW(algorithm.fit());
    if (iProcess == 0)
    {
        // Each result is flattened to { s, df1, df2, p } by convFTestArray.
        auto results = algorithm.fTestResults();
        auto f1 = convFTestArray(get<0>(results));
        auto f2 = convFTestArray(get<1>(results));
        vector<array<double, 4>> f3;
        for (auto &&i : get<2>(results))
        {
            f3.push_back(convFTestArray(i));
        }
        auto f4 = convFTestArray(get<3>(results));
        REQUIRE_THAT(f1[0], Catch::Matchers::WithinAbs(0.9342, 1e-3));
        REQUIRE_THAT(f1[1], Catch::Matchers::WithinAbs(93.3300, 1e-3));
        REQUIRE_THAT(f1[2], Catch::Matchers::WithinAbs(96.0000, 1e-3));
        REQUIRE_THAT(f1[3], Catch::Matchers::WithinAbs(0.3710, 1e-3));
        REQUIRE_THAT(f2[0], Catch::Matchers::WithinAbs(1.9762, 1e-3));
        REQUIRE_THAT(f2[1], Catch::Matchers::WithinAbs(13.1571, 1e-3));
        REQUIRE_THAT(f2[2], Catch::Matchers::WithinAbs(96.0000, 1e-3));
        REQUIRE_THAT(f2[3], Catch::Matchers::WithinAbs(0.0303, 1e-3));
        REQUIRE_THAT(f4[0], Catch::Matchers::WithinAbs(0.8752, 1e-3));
        REQUIRE_THAT(f4[1], Catch::Matchers::WithinAbs(89.9377, 1e-3));
        REQUIRE_THAT(f4[2], Catch::Matchers::WithinAbs(96.0000, 1e-3));
        REQUIRE_THAT(f4[3], Catch::Matchers::WithinAbs(0.2619, 1e-3));
        // f3 holds one entry per element of get<2>(results); the first four are checked.
        REQUIRE_THAT(f3[0][0], Catch::Matchers::WithinAbs(0.4655, 1e-3));
        REQUIRE_THAT(f3[0][1], Catch::Matchers::WithinAbs(27.1298, 1e-3));
        REQUIRE_THAT(f3[0][2], Catch::Matchers::WithinAbs(93.3300, 1e-3));
        REQUIRE_THAT(f3[0][3], Catch::Matchers::WithinAbs(0.9872, 1e-3));
        REQUIRE_THAT(f3[1][0], Catch::Matchers::WithinAbs(0.5022, 1e-3));
        REQUIRE_THAT(f3[1][1], Catch::Matchers::WithinAbs(18.3315, 1e-3));
        REQUIRE_THAT(f3[1][2], Catch::Matchers::WithinAbs(93.3300, 1e-3));
        REQUIRE_THAT(f3[1][3], Catch::Matchers::WithinAbs(0.9526, 1e-3));
        REQUIRE_THAT(f3[2][0], Catch::Matchers::WithinAbs(0.6415, 1e-3));
        REQUIRE_THAT(f3[2][1], Catch::Matchers::WithinAbs(29.6827, 1e-3));
        REQUIRE_THAT(f3[2][2], Catch::Matchers::WithinAbs(93.3300, 1e-3));
        REQUIRE_THAT(f3[2][3], Catch::Matchers::WithinAbs(0.9151, 1e-3));
        REQUIRE_THAT(f3[3][0], Catch::Matchers::WithinAbs(0.3019, 1e-3));
        REQUIRE_THAT(f3[3][1], Catch::Matchers::WithinAbs(24.7164, 1e-3));
        REQUIRE_THAT(f3[3][2], Catch::Matchers::WithinAbs(93.3300, 1e-3));
        REQUIRE_THAT(f3[3][3], Catch::Matchers::WithinAbs(0.9994, 1e-3));
    }
}

}

int main(int argc, char *argv[])
Expand Down

0 comments on commit 2440424

Please sign in to comment.