Skip to content

Commit

Permalink
Multithreaded rulebook
Browse files Browse the repository at this point in the history
  • Loading branch information
mpmisko committed Sep 5, 2019
1 parent 1171aae commit 70486f2
Showing 1 changed file with 50 additions and 11 deletions.
61 changes: 50 additions & 11 deletions sparseconvnet/SCN/Metadata/SubmanifoldConvolutionRules.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#ifndef SUBMANIFOLDCONVOLUTIONRULES_H
#define SUBMANIFOLDCONVOLUTIONRULES_H

#include <algorithm>

// Full input region for an output point
template <Int dimension>
RectangularRegion<dimension>
Expand All @@ -27,20 +29,57 @@ template <Int dimension>
double SubmanifoldConvolution_SgToRules(SparseGrid<dimension> &grid,
RuleBook &rules, long *size) {
double countActiveInputs = 0;
for (auto const &outputIter : grid.mp) {
auto inRegion =
InputRegionCalculator_Submanifold<dimension>(outputIter.first, size);
Int rulesOffset = 0;
for (auto inputPoint : inRegion) {
auto inputIter = grid.mp.find(inputPoint);
if (inputIter != grid.mp.end()) {
rules[rulesOffset].push_back(inputIter->second + grid.ctr);
rules[rulesOffset].push_back(outputIter.second + grid.ctr);
countActiveInputs++;
const Int threadCount = 4;
std::vector<std::thread> threads;
std::array<int, threadCount> activeInputs = {};
std::vector<RuleBook> rulebooks;
for (Int t = 0; t < threadCount; ++t) {
rulebooks.push_back(RuleBook(rules.size()));
}

auto func = [&](const int order) {
auto outputIter = grid.mp.begin();
auto &rb = rulebooks[order];
int rem = grid.mp.size();
int aciveInputCount = 0;

if (rem > order) {
std::advance(outputIter, order);
rem -= order;

for (; outputIter != grid.mp.end();
std::advance(outputIter, std::min(threadCount, rem)),
rem -= threadCount) {
auto inRegion = InputRegionCalculator_Submanifold<dimension>(
outputIter->first, size);
Int rulesOffset = 0;
for (auto inputPoint : inRegion) {
auto inputIter = grid.mp.find(inputPoint);
if (inputIter != grid.mp.end()) {
aciveInputCount++;
rb[rulesOffset].push_back(inputIter->second + grid.ctr);
rb[rulesOffset].push_back(outputIter->second + grid.ctr);
}
rulesOffset++;
}
}
rulesOffset++;
activeInputs[order] = aciveInputCount;
}
};

for (Int t = 0; t < threadCount; ++t) {
threads.push_back(std::thread(func, t));
}

for (Int t = 0; t < threadCount; ++t) {
threads[t].join();
countActiveInputs += activeInputs[t];
for (std::size_t i = 0; i < rulebooks[t].size(); ++i) {
rules[i].insert(rules[i].end(), rulebooks[t][i].begin(),
rulebooks[t][i].end());
}
}

return countActiveInputs;
}

Expand Down

0 comments on commit 70486f2

Please sign in to comment.