forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ParallelNative.cpp
292 lines (256 loc) · 7.48 KB
/
ParallelNative.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
#if AT_PARALLEL_NATIVE
#include <ATen/Parallel.h>
#include <ATen/PTThreadPool.h>
#ifndef C10_MOBILE
#include <c10/core/thread_pool.h>
#else
#include <caffe2/utils/threadpool/ThreadPool.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#endif // C10_MOBILE
#include <atomic>
#include <utility>
#ifdef _OPENMP
#include <omp.h>
#endif
#ifdef TH_BLAS_MKL
#include <mkl.h>
#endif
namespace at {
namespace {
// Per-thread flag set through _set_in_parallel_region(); the master thread is
// flagged too while it executes its share of a parallel primitive.
thread_local bool in_parallel_region_{false};
// Per-thread task id (task_id) assigned by the running parallel primitive;
// reset to 0 by _unset_thread_num().
thread_local size_t thread_num_{0};

// Marks (true) or unmarks (false) the calling thread as inside a parallel region.
void _set_in_parallel_region(bool in_region) {
  in_parallel_region_ = in_region;
}

// Records the task id the calling thread is currently executing.
void _set_thread_num(size_t thread_num) {
  thread_num_ = thread_num;
}

// Resets the calling thread's task id to 0.
void _unset_thread_num() {
  _set_thread_num(0);
}
#ifndef C10_MOBILE
const int NOT_SET = -1;
const int CONSUMED = -2;
// Number of threads set by the user
// NOT_SET -> positive value -> CONSUMED
// or
// NOT_SET -> CONSUMED
// Meaning:
// - NOT_SET - pool not initialized, user value is not set
// - positive value - pool not initialized, user value set
// - CONSUMED - pool is initialized
std::atomic<int> num_intraop_threads{NOT_SET};
// Translates a num_intraop_threads value (NOT_SET, or a positive user request)
// into the number of worker threads the pool should own. The calling
// ("master") thread also executes tasks, so the pool gets one thread fewer.
int _num_pool_threads(int nthreads) {
  const bool user_requested = nthreads != NOT_SET;
  if (user_requested) {
    TORCH_INTERNAL_ASSERT(nthreads > 0);
  }
  const int total_threads =
      user_requested ? nthreads : intraop_default_num_threads();
  return total_threads - 1;
}
// Returns the intra-op thread pool, creating it on first use. Creation
// atomically swaps num_intraop_threads to CONSUMED: a size recorded earlier by
// set_num_threads() is used here, and later requests cannot resize the pool.
TaskThreadPoolBase& _get_intraop_pool() {
  static const std::shared_ptr<TaskThreadPoolBase> intraop_pool = [] {
    const int pool_size =
        _num_pool_threads(num_intraop_threads.exchange(CONSUMED));
    return ThreadPoolRegistry()->Create(
        "C10",
        /* device_id */ 0,
        /* pool_size */ pool_size,
        /* create_new */ true); // a dedicated pool just for intra-op work
  }();
  return *intraop_pool;
}
#endif // C10_MOBILE
// Run lambda function `fn` over `task_id` in [0, `range`) with threadpool.
// `fn` will be called with params: (thread_pool_task_id, task_id).
void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
#ifndef C10_MOBILE
for (size_t i = 1; i < range; ++i) {
_get_intraop_pool().run([fn, i]() { fn((int)i, i); });
}
// Run the first task on the current thread directly.
fn(0, 0);
#else
caffe2::ThreadPool* pool = caffe2::mobile_threadpool();
if (pool) {
// caffe2::ThreadPool can utilize the current thread.
pool->run(fn, range);
} else {
for (size_t i = 0; i < range; ++i) {
fn(0, i);
}
}
#endif // C10_MOBILE
}
// RAII guard backing the in_parallel_region() and get_thread_num() APIs.
// On construction it records `task_id` as the calling thread's task id and
// marks the thread as inside a parallel region; on destruction it clears the
// flag and resets the task id to 0.
// NOTE(review): the previous values are not saved/restored, so a guard nested
// inside another on the same thread would clear the outer region's state.
struct ParallelRegionGuard {
  explicit ParallelRegionGuard(int64_t task_id) {
    _set_thread_num(task_id);
    _set_in_parallel_region(true);
  }
  ~ParallelRegionGuard() {
    _set_in_parallel_region(false);
    _unset_thread_num();
  }
  // A copy or move would reset the thread-local state a second time when it
  // is destroyed, so the guard is pinned to its scope.
  ParallelRegionGuard(const ParallelRegionGuard&) = delete;
  ParallelRegionGuard& operator=(const ParallelRegionGuard&) = delete;
  ParallelRegionGuard(ParallelRegionGuard&&) = delete;
  ParallelRegionGuard& operator=(ParallelRegionGuard&&) = delete;
};
} // namespace
namespace internal {
// Splits [begin, end) into num_tasks chunks of `chunk_size` elements (as
// computed by calc_num_tasks_and_chunk_size) and calls
// `f(chunk_begin, chunk_end, task_id)` for each chunk, using the intra-op pool
// plus the calling thread. Blocks until every task has finished. If any call
// to `f` throws, the first captured exception is rethrown on the calling
// thread once all tasks are done.
void _parallel_run(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const std::function<void(int64_t, int64_t, size_t)>& f) {
  at::internal::lazy_init_num_threads();

  size_t num_tasks, chunk_size;
  std::tie(num_tasks, chunk_size) =
      internal::calc_num_tasks_and_chunk_size(begin, end, grain_size);
  if (num_tasks == 0) {
    // _run_with_pool always runs task 0 inline; with zero tasks that would
    // decrement `remaining` below zero and the wait below would never end.
    return;
  }

  // State shared with the workers. It lives on this stack frame, which is
  // only safe because this function does not return until every task has
  // decremented `remaining`.
  struct {
    std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
    std::exception_ptr eptr;  // first exception thrown by `f`
    std::mutex mutex;         // guards `remaining`
    size_t remaining = 0;     // tasks not yet finished
    std::condition_variable cv;
  } state;

  // `f` and `state` are captured by reference: both outlive every use inside
  // the tasks because of the wait below.
  auto task = [&f, &state, begin, end, chunk_size](
                  int /* unused */, size_t task_id) {
    const int64_t local_start =
        begin + static_cast<int64_t>(task_id * chunk_size);
    if (local_start < end) {
      const int64_t local_end =
          std::min(end, local_start + static_cast<int64_t>(chunk_size));
      try {
        ParallelRegionGuard guard(task_id);
        f(local_start, local_end, task_id);
      } catch (...) {
        // Keep only the first exception; later ones are dropped.
        if (!state.err_flag.test_and_set()) {
          state.eptr = std::current_exception();
        }
      }
    }
    {
      std::unique_lock<std::mutex> lk(state.mutex);
      if (--state.remaining == 0) {
        state.cv.notify_one();
      }
    }
  };
  state.remaining = num_tasks;
  _run_with_pool(task, num_tasks);

  // Wait for all tasks to finish. The predicate overload re-checks
  // `remaining` after every wake-up, so a spurious wake-up cannot let this
  // function return (destroying `state`) while tasks are still running.
  {
    std::unique_lock<std::mutex> lk(state.mutex);
    state.cv.wait(lk, [&state] { return state.remaining == 0; });
  }
  if (state.eptr) {
    std::rethrow_exception(state.eptr);
  }
}
} // namespace internal
// Per-thread initialization for the native backend. When compiled with OpenMP
// and/or MKL, their thread counts are forced to 1 (presumably so they do not
// oversubscribe cores alongside the native intra-op pool). On mobile,
// caffe2::mobile_threadpool() is called and its result discarded, presumably
// to create the pool eagerly.
void init_num_threads() {
#ifdef _OPENMP
  omp_set_num_threads(1);
#endif
#ifdef TH_BLAS_MKL
  mkl_set_num_threads(1);
#endif
#ifdef C10_MOBILE
  (void)caffe2::mobile_threadpool();
#endif
}
// Requests `nthreads` intra-op threads (master thread included). The request
// takes effect only if it is the first one and the pool has not been created
// yet. Otherwise a warning is emitted, unless the requested count equals the
// one already in effect. Not supported on mobile.
void set_num_threads(int nthreads) {
#ifndef C10_MOBILE
  TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
  int expected = NOT_SET;
  if (num_intraop_threads.compare_exchange_strong(expected, nthreads)) {
    // First request before the pool exists: it will size the pool.
    return;
  }
  // Either an earlier request is pending (positive) or the pool exists
  // (CONSUMED); compare against whichever count is in effect.
  int current_nthreads = num_intraop_threads.load();
  if (current_nthreads <= 0) {
    current_nthreads = _get_intraop_pool().size() + 1; // + master thread
  }
  if (current_nthreads == nthreads) {
    return;
  }
  TORCH_WARN(
      "Cannot set number of intraop threads "
      "after parallel work has started or after set_num_threads call "
      "when using native parallel backend");
#else
  TORCH_CHECK(false, "set_num_threads is not supported for mobile.");
#endif // C10_MOBILE
}
// Returns the number of threads intra-op work may use, counting the caller.
int get_num_threads() {
#ifndef C10_MOBILE
  // Inspect the recorded state first instead of touching the pool: creating
  // the pool here would fix its size, and it cannot be resized afterwards.
  const int recorded = num_intraop_threads.load();
  if (recorded == NOT_SET) {
    return intraop_default_num_threads();
  }
  if (recorded > 0) {
    return recorded; // pending set_num_threads() request
  }
  TORCH_INTERNAL_ASSERT(recorded == CONSUMED);
  // Pool workers plus the master thread.
  return _get_intraop_pool().size() + 1;
#else
  caffe2::ThreadPool* const mobile_pool = caffe2::mobile_threadpool();
  if (mobile_pool == nullptr || in_parallel_region()) {
    return 1; // only the current thread
  }
  // caffe2::ThreadPool::getNumThreads() counts the current thread.
  return mobile_pool->getNumThreads();
#endif // C10_MOBILE
}
// Returns the task id recorded for the calling thread by the parallel
// primitive it is running (0 when none has been recorded).
int get_thread_num() {
  return static_cast<int>(thread_num_);
}
// True when the calling thread is executing inside a parallel primitive or,
// on non-mobile builds, is one of the intra-op pool's worker threads.
bool in_parallel_region() {
  if (in_parallel_region_) {
    return true;
  }
#ifndef C10_MOBILE
  // intraop_launch() does not set in_parallel_region_, so pool workers are
  // detected separately. The pool is only consulted once it exists
  // (CONSUMED), so this query never creates it.
  return num_intraop_threads.load() == CONSUMED &&
      _get_intraop_pool().inThreadPool();
#else
  return false;
#endif // C10_MOBILE
}
// Runs `func` on the intra-op pool when the caller is not already in a
// parallel region and more than one thread is configured; otherwise runs it
// inline. In the pool case this returns without waiting for `func`.
void intraop_launch(std::function<void()> func) {
#ifndef C10_MOBILE
  if (!in_parallel_region() && get_num_threads() > 1) {
    // `func` is owned by this call (taken by value): hand it over rather
    // than copying the callable and its captured state.
    _get_intraop_pool().run(std::move(func));
  } else {
    // execute inline if we're in parallel region
    func();
  }
#else
  // TODO: caffe2::ThreadPool doesn't support submitting tasks separately and
  // running in parallel. Should fix it when this API becomes popular.
  func();
#endif // C10_MOBILE
}
// Like intraop_launch(), but returns a future that is marked completed once
// `func` has returned. When `func` runs inline (inside a parallel region,
// with a single thread, or on mobile), the returned future is already
// completed.
// NOTE(review): if `func` throws on a pool thread, markCompleted() is never
// reached, so a waiter on the future would block; consider reporting the
// error through the Future's error API.
std::shared_ptr<c10::ivalue::Future> intraop_launch_future(
    std::function<void()> func) {
  auto future = std::make_shared<c10::ivalue::Future>(c10::NoneType::get());
#ifndef C10_MOBILE
  if (!in_parallel_region() && get_num_threads() > 1) {
    // Move the by-value `func` into the task instead of copying it.
    _get_intraop_pool().run([func = std::move(func), future]() {
      func();
      future->markCompleted();
    });
    return future;
  }
#else
  // TODO: caffe2::ThreadPool doesn't support submitting tasks separately and
  // running in parallel. Should fix it when this API becomes popular.
#endif // C10_MOBILE
  // Inline path: run on the calling thread and complete the future here.
  func();
  future->markCompleted();
  return future;
}
} // namespace at
#endif