From 24404242c0387e66081bb9788c518d308808b869 Mon Sep 17 00:00:00 2001 From: HPDell Date: Fri, 2 Aug 2024 13:12:54 +0100 Subject: [PATCH] fix: MPI problems --- src/gwmodelpp/GWRBasic.cpp | 49 +++++++++++++++------- test/mpi/testGWRBasicMpi.cpp | 78 ++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 15 deletions(-) diff --git a/src/gwmodelpp/GWRBasic.cpp b/src/gwmodelpp/GWRBasic.cpp index 3906913..4c56b1d 100644 --- a/src/gwmodelpp/GWRBasic.cpp +++ b/src/gwmodelpp/GWRBasic.cpp @@ -465,8 +465,7 @@ void gwm::GWRBasic::fTestBase() for (uword i = 0; i < mX.n_cols; i++) { vec diagB = (this->*mCalcDiagBFunction)(i); - double g1 = diagB(0); - double g2 = diagB(1); + double g1 = diagB(0), g2 = diagB(1); double numdf = g1 * g1 / g2; FTestResult f3i; f3i.s = (vk2(i) / g1) / sigma21; @@ -564,7 +563,9 @@ arma::vec GWRBasic::calcDiagBBase(arma::uword i) return { DBL_MAX, DBL_MAX }; } } - return (this->*mCalcDiagBCoreFunction)(i, c); + vec diagB = (this->*mCalcDiagBCoreFunction)(i, c); + diagB = 1.0 / nDp * diagB; + return { sum(diagB), sum(diagB % diagB) }; } vec GWRBasic::calcDiagBCoreSerial(uword i, const vec& c) @@ -584,11 +585,11 @@ vec GWRBasic::calcDiagBCoreSerial(uword i, const vec& c) } catch (const std::exception& e) { - return { DBL_MAX, DBL_MAX }; + diagB.fill(DBL_MAX); + return diagB; } } - diagB = 1.0 / nDp * diagB; - return { sum(diagB), sum(diagB % diagB) }; + return diagB; } #ifdef ENABLE_OPENMP @@ -808,20 +809,21 @@ double GWRBasic::calcTrQtQCoreOmp() vec GWRBasic::calcDiagBCoreOmp(uword i, const vec& c) { arma::uword nDp = mX.n_rows; - vec diagB(nDp, fill::zeros); + mat diagB_all(nDp, mOmpThreadNum, fill::zeros); std::pair workRange = mWorkRange.value_or(make_pair(0, nDp)); int flag = true; #pragma omp parallel for num_threads(mOmpThreadNum) for (arma::uword k = workRange.first; k < workRange.second; k++) { if (flag) { + int thread = omp_get_thread_num(); vec w = mSpatialWeight.weightVector(k); mat xtw = trans(mX.each_col() % w); try { mat C = trans(xtw) * inv_sympd(xtw * mX); vec b = C.col(i); - diagB += (b % b - (1.0 / nDp) * (b % c)); + diagB_all.col(thread) += (b % b - (1.0 / nDp) * (b % c)); } catch (const std::exception& e) { @@ -831,10 +833,10 @@ vec GWRBasic::calcDiagBCoreOmp(uword i, const vec& c) } if (!flag) { - return { DBL_MAX, DBL_MAX }; + diagB_all.fill(DBL_MAX); } - diagB = 1.0 / nDp * diagB; - return { sum(diagB), sum(diagB % diagB) }; + vec diagB = sum(diagB_all, 1); + return diagB; } #endif @@ -1297,10 +1299,27 @@ double GWRBasic::calcTrQtQMpi() vec GWRBasic::calcDiagBMpi(uword i) { - vec diagBi = calcDiagBBase(i); - vec diagB; - MPI_Allreduce(diagBi.memptr(), diagB.memptr(), 2, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); - return diagB; + arma::uword nDp = mX.n_rows; + vec c(nDp, fill::zeros); + for (arma::uword j = 0; j < nDp; j++) + { + vec w = mSpatialWeight.weightVector(j); + mat xtw = trans(mX.each_col() % w); + try + { + mat C = trans(xtw) * inv_sympd(xtw * mX); + c += C.col(i); + } + catch (const std::exception& e) + { + return { DBL_MAX, DBL_MAX }; + } + } + vec diagBi = (this->*mCalcDiagBCoreFunction)(i, c); + vec diagB(nDp, arma::fill::zeros); + MPI_Allreduce(diagBi.memptr(), diagB.memptr(), int(nDp), MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); + diagB = 1.0 / double(nDp) * diagB; + return { sum(diagB), sum(diagB % diagB) }; } #endif // ENABLE_MPI diff --git a/test/mpi/testGWRBasicMpi.cpp b/test/mpi/testGWRBasicMpi.cpp index 2c44061..d435540 100644 --- a/test/mpi/testGWRBasicMpi.cpp +++ b/test/mpi/testGWRBasicMpi.cpp @@ -21,6 +21,11 @@ using namespace std; using namespace arma; using namespace gwm; +array convFTestArray(GWRBasic::FTestResult f) +{ + return { f.s, f.df1, f.df2, f.p }; +} + TEST_CASE("BasicGWR: LondonHP") { int iProcess, nProcess; @@ -303,6 +308,79 @@ TEST_CASE("BasicGWR: LondonHP") } } + SECTION("F test | adaptive bandwidth | no bandwidth optimization | no variable optimization") { + auto parallel = GENERATE_REF(values(parallel_list)); + INFO("Parallel:" << ParallelTypeDict.at(parallel)); + + CRSDistance distance(false); + BandwidthWeight bandwidth(36, true, BandwidthWeight::Gaussian); + SpatialWeight spatial(&bandwidth, &distance); + + GWRBasic algorithm; + algorithm.setCoords(londonhp100_coord); + algorithm.setDependentVariable(y); + algorithm.setIndependentVariables(x); + algorithm.setSpatialWeight(spatial); + algorithm.setParallelType(parallel); + algorithm.setWorkerId(iProcess); + algorithm.setWorkerNum(nProcess); +#ifdef ENABLE_OPENMP + if (parallel == ParallelType::OpenMP) + { + algorithm.setOmpThreadNum(omp_get_num_threads()); + } +#endif // ENABLE_OPENMP +#ifdef ENABLE_CUDA + if (parallel == ParallelType::CUDA) + { + algorithm.setGPUId(0); + algorithm.setGroupSize(64); + } +#endif // ENABLE_CUDA + algorithm.setIsDoFtest(true); + REQUIRE_NOTHROW(algorithm.fit()); + if (iProcess == 0) + { + auto results = algorithm.fTestResults(); + auto f1 = convFTestArray(get<0>(results)); + auto f2 = convFTestArray(get<1>(results)); + vector> f3; + for (auto &&i : get<2>(results)) + { + f3.push_back(convFTestArray(i)); + } + auto f4 = convFTestArray(get<3>(results)); + REQUIRE_THAT(f1[0], Catch::Matchers::WithinAbs(0.9342, 1e-3)); + REQUIRE_THAT(f1[1], Catch::Matchers::WithinAbs(93.3300, 1e-3)); + REQUIRE_THAT(f1[2], Catch::Matchers::WithinAbs(96.0000, 1e-3)); + REQUIRE_THAT(f1[3], Catch::Matchers::WithinAbs(0.3710, 1e-3)); + REQUIRE_THAT(f2[0], Catch::Matchers::WithinAbs(1.9762, 1e-3)); + REQUIRE_THAT(f2[1], Catch::Matchers::WithinAbs(13.1571, 1e-3)); + REQUIRE_THAT(f2[2], Catch::Matchers::WithinAbs(96.0000, 1e-3)); + REQUIRE_THAT(f2[3], Catch::Matchers::WithinAbs(0.0303, 1e-3)); + REQUIRE_THAT(f4[0], Catch::Matchers::WithinAbs(0.8752, 1e-3)); + REQUIRE_THAT(f4[1], Catch::Matchers::WithinAbs(89.9377, 1e-3)); + REQUIRE_THAT(f4[2], Catch::Matchers::WithinAbs(96.0000, 1e-3)); + REQUIRE_THAT(f4[3], Catch::Matchers::WithinAbs(0.2619, 1e-3)); + REQUIRE_THAT(f3[0][0], Catch::Matchers::WithinAbs(0.4655, 1e-3)); + REQUIRE_THAT(f3[0][1], Catch::Matchers::WithinAbs(27.1298, 1e-3)); + REQUIRE_THAT(f3[0][2], Catch::Matchers::WithinAbs(93.3300, 1e-3)); + REQUIRE_THAT(f3[0][3], Catch::Matchers::WithinAbs(0.9872, 1e-3)); + REQUIRE_THAT(f3[1][0], Catch::Matchers::WithinAbs(0.5022, 1e-3)); + REQUIRE_THAT(f3[1][1], Catch::Matchers::WithinAbs(18.3315, 1e-3)); + REQUIRE_THAT(f3[1][2], Catch::Matchers::WithinAbs(93.3300, 1e-3)); + REQUIRE_THAT(f3[1][3], Catch::Matchers::WithinAbs(0.9526, 1e-3)); + REQUIRE_THAT(f3[2][0], Catch::Matchers::WithinAbs(0.6415, 1e-3)); + REQUIRE_THAT(f3[2][1], Catch::Matchers::WithinAbs(29.6827, 1e-3)); + REQUIRE_THAT(f3[2][2], Catch::Matchers::WithinAbs(93.3300, 1e-3)); + REQUIRE_THAT(f3[2][3], Catch::Matchers::WithinAbs(0.9151, 1e-3)); + REQUIRE_THAT(f3[3][0], Catch::Matchers::WithinAbs(0.3019, 1e-3)); + REQUIRE_THAT(f3[3][1], Catch::Matchers::WithinAbs(24.7164, 1e-3)); + REQUIRE_THAT(f3[3][2], Catch::Matchers::WithinAbs(93.3300, 1e-3)); + REQUIRE_THAT(f3[3][3], Catch::Matchers::WithinAbs(0.9994, 1e-3)); + } + } + } int main(int argc, char *argv[])