diff --git a/sparseconvnet/SCN/Metadata/SubmanifoldConvolutionRules.h b/sparseconvnet/SCN/Metadata/SubmanifoldConvolutionRules.h index b63ca26..b8ae665 100644 --- a/sparseconvnet/SCN/Metadata/SubmanifoldConvolutionRules.h +++ b/sparseconvnet/SCN/Metadata/SubmanifoldConvolutionRules.h @@ -7,6 +7,8 @@ #ifndef SUBMANIFOLDCONVOLUTIONRULES_H #define SUBMANIFOLDCONVOLUTIONRULES_H +#include + // Full input region for an output point template RectangularRegion @@ -27,20 +29,58 @@ template double SubmanifoldConvolution_SgToRules(SparseGrid &grid, RuleBook &rules, long *size) { double countActiveInputs = 0; - for (auto const &outputIter : grid.mp) { - auto inRegion = - InputRegionCalculator_Submanifold(outputIter.first, size); - Int rulesOffset = 0; - for (auto inputPoint : inRegion) { - auto inputIter = grid.mp.find(inputPoint); - if (inputIter != grid.mp.end()) { - rules[rulesOffset].push_back(inputIter->second + grid.ctr); - rules[rulesOffset].push_back(outputIter.second + grid.ctr); - countActiveInputs++; + const Int threadCount = 4; + std::vector threads; + std::array activeInputs = {}; + std::vector rulebooks; + for (Int t = 0; t < threadCount; ++t) { + rulebooks.push_back(RuleBook(rules.size())); + } + + auto func = [&](const int order) { + auto outputIter = grid.mp.begin(); + auto &rb = rulebooks[order]; + int rem = grid.mp.size(); + int aciveInputCount = 0; + + if (rem > order) { + std::advance(outputIter, order); + rem -= order; + + for (; outputIter != grid.mp.end(); + std::advance(outputIter, std::min(threadCount, rem)), + rem -= threadCount) { + auto inRegion = InputRegionCalculator_Submanifold( + outputIter->first, size); + Int rulesOffset = 0; + for (auto inputPoint : inRegion) { + auto inputIter = grid.mp.find(inputPoint); + if (inputIter != grid.mp.end()) { + aciveInputCount++; + rb[rulesOffset].push_back(inputIter->second + grid.ctr); + rb[rulesOffset].push_back(outputIter->second + grid.ctr); + } + rulesOffset++; + } } - rulesOffset++; } + + activeInputs[order] = aciveInputCount; + }; + + for (Int t = 0; t < threadCount; ++t) { + threads.push_back(std::thread(func, t)); } + + for (Int t = 0; t < threadCount; ++t) { + threads[t].join(); + countActiveInputs += activeInputs[t]; + for (std::size_t i = 0; i < rulebooks[t].size(); ++i) { + rules[i].insert(rules[i].end(), rulebooks[t][i].begin(), + rulebooks[t][i].end()); + } + } + return countActiveInputs; }