forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
MultiMarginLoss.cu
384 lines (343 loc) · 12.8 KB
/
MultiMarginLoss.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/native/Resize.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
#include <c10/macros/Macros.h>

#include <limits>
namespace at {
namespace native {
namespace {
// Threads per block used by the launches of both kernels below. It also sizes
// their __shared__ per-thread partial-sum buffers, so a launch must not use
// more threads per block than this.
constexpr int MULTIMARGIN_THREADS = 128;
// Per-frame multi-class hinge loss:
//   output[k] = sum_{i != y} max(0, margin - x[k][y] + x[k][i])^P * w / denom
// where y = target[k], w = weights[y] (1 when weights is null), and
// denom = dim, or nframe * dim when sizeAverage is set.
//
// Launch layout: one block per frame (blockIdx.x is the frame index) with at
// most MULTIMARGIN_THREADS threads, because per-thread partial sums are staged
// in a __shared__ buffer of that length. Thread t handles classes
// t, t + blockDim.x, ...; thread 0 then adds the partials and writes output[k].
//
//   output  - receives one value per frame (output[k]).
//   input   - nframe rows of dim scores, row k starting at element k * dim.
//   target  - one class index per frame; must lie in [0, dim).
//   weights - optional per-class weights (nullptr if absent); only
//             weights[target[k]] is read.
template <int P, typename scalar_t>
__global__ void MultiMarginLoss_forward_kernel(
    scalar_t *output, const scalar_t *input, const int64_t *target,
    const scalar_t *weights, int nframe, int dim, bool sizeAverage,
    scalar_t margin) {
  using acc_t = at::acc_type<scalar_t, true>;
  __shared__ acc_t buffer[MULTIMARGIN_THREADS];

  const int k = blockIdx.x;
  // 64-bit row offset: k * dim overflows int once nframe * dim > INT_MAX.
  const scalar_t *input_k = input + static_cast<int64_t>(k) * dim;
  scalar_t *output_k = output + k;

  const int64_t target_raw = target[k];
  // An out-of-range class index would otherwise read outside the row.
  CUDA_KERNEL_ASSERT(target_raw >= 0 && target_raw < dim && "target index is out of bounds");
  const int target_k = static_cast<int>(target_raw);
  const scalar_t input_target_k = input_k[target_k];
  // Loop-invariant class weight; multiplying by 1 when absent is exact.
  const acc_t weight_k = weights ? static_cast<acc_t>(weights[target_k]) : acc_t(1);

  acc_t partial = 0;
  for (int i = threadIdx.x; i < dim; i += blockDim.x) {
    if (i == target_k) {
      continue;  // the target class contributes no hinge term
    }
    const scalar_t z = margin - input_target_k + input_k[i];
    if (z > 0) {
      // Square and weight in acc_t so half/bfloat16 terms are not rounded
      // (or overflowed) in scalar_t before being accumulated.
      const acc_t zf = static_cast<acc_t>(z);
      partial += ((P == 1) ? zf : zf * zf) * weight_k;
    }
  }
  buffer[threadIdx.x] = partial;
  __syncthreads();

  // Serial block reduction by thread 0.
  if (threadIdx.x == 0) {
    acc_t sum = 0;
    for (int i = 0; i < static_cast<int>(blockDim.x); ++i) {
      sum += buffer[i];
    }
    // Form the divisor in 64 bits: nframe * dim can overflow int.
    const int64_t denom = sizeAverage ? static_cast<int64_t>(nframe) * dim : dim;
    *output_k = static_cast<scalar_t>(sum / static_cast<acc_t>(denom));
  }
}
// Gradient of the per-frame multi-class hinge loss w.r.t. its input row.
//
// For frame k with y = target[k], every class i != y gets
//   d = (P == 1 ? g : 2 * g * z_i) * w   if z_i = margin - x[k][y] + x[k][i] > 0
//   d = 0                                 otherwise,
// with g = 1 / denom (denom = nframe * dim when sizeAverage && reduce, else
// dim) and w = weights[y] (1 when weights is null). The target column receives
// minus the sum of the other columns' (scalar_t-rounded) values. Finally every
// entry is scaled by the upstream gradient: gradOutput[0] when reduce is set,
// gradOutput[k] otherwise.
//
// Launch layout: one block per frame (blockIdx.x) with at most
// MULTIMARGIN_THREADS threads, because partial sums are staged in a
// __shared__ buffer of that length. Thread t handles columns
// t, t + blockDim.x, ...; target[k] must lie in [0, dim).
template <int P, typename scalar_t>
__global__ void MultiMarginLoss_backward_kernel(
    scalar_t *gradInput, const scalar_t *gradOutput, const scalar_t *input,
    const int64_t *target, const scalar_t *weights, int nframe, int dim,
    bool sizeAverage, scalar_t margin, bool reduce) {
  using acc_t = at::acc_type<scalar_t, true>;
  __shared__ acc_t buffer[MULTIMARGIN_THREADS];

  const int k = blockIdx.x;
  // 64-bit row offset: k * dim overflows int once nframe * dim > INT_MAX.
  const int64_t row_offset = static_cast<int64_t>(k) * dim;
  const scalar_t *input_k = input + row_offset;
  scalar_t *gradInput_k = gradInput + row_offset;

  const int64_t target_raw = target[k];
  // An out-of-range class index would otherwise read/write outside the row.
  CUDA_KERNEL_ASSERT(target_raw >= 0 && target_raw < dim && "target index is out of bounds");
  const int target_k = static_cast<int>(target_raw);
  const scalar_t input_target_k = input_k[target_k];

  // Reduced losses share one upstream gradient; unreduced ones have one per frame.
  const scalar_t grad_out = reduce ? gradOutput[0] : gradOutput[k];

  // Form the divisor in 64 bits: nframe * dim can overflow int.
  const int64_t denom = (sizeAverage && reduce) ? static_cast<int64_t>(nframe) * dim : dim;
  const acc_t g = acc_t(1) / static_cast<acc_t>(denom);
  // Loop-invariant class weight; multiplying by 1 when absent is exact.
  const acc_t weight_k = weights ? static_cast<acc_t>(weights[target_k]) : acc_t(1);

  acc_t partial = 0;
  for (int i = threadIdx.x; i < dim; i += blockDim.x) {
    if (i == target_k) {
      continue;  // written by thread 0 after the block reduction
    }
    const scalar_t z = margin - input_target_k + input_k[i];
    if (z > 0) {
      const acc_t h = ((P == 1) ? g : 2 * g * z) * weight_k;
      // Accumulate the same rounded value that is stored for column i.
      partial -= static_cast<scalar_t>(h);
      gradInput_k[i] = static_cast<scalar_t>(h);
    } else {
      gradInput_k[i] = static_cast<scalar_t>(0);
    }
  }
  buffer[threadIdx.x] = partial;
  __syncthreads();

  // Serial block reduction by thread 0 into the target column.
  if (threadIdx.x == 0) {
    acc_t gradInput_target_k = 0;
    for (int i = 0; i < static_cast<int>(blockDim.x); ++i) {
      gradInput_target_k += buffer[i];
    }
    gradInput_k[target_k] = static_cast<scalar_t>(gradInput_target_k);
  }
  // The thread owning column target_k rescales it below and must observe
  // thread 0's write first; without this barrier the two accesses race.
  __syncthreads();

  for (int i = threadIdx.x; i < dim; i += blockDim.x) {
    gradInput_k[i] *= grad_out;
  }
}
// Validates the shapes accepted by multi_margin_loss.
//
// input must be a 0-d tensor, a non-empty 1-d vector, or a 2-d matrix whose
// second (class) dimension is non-zero; its first (batch) dimension may be 0.
// target must have at most one dimension and exactly one entry per frame
// (one frame for 0-d/1-d input, size(0) frames for 2-d input).
void multi_margin_loss_shape_check(
    const Tensor &input, const Tensor &target) {
  const auto sizes = input.sizes();
  const auto ndim = sizes.size();

  const bool is_scalar = ndim == 0;
  const bool is_vector = ndim == 1 && sizes[0] != 0;
  const bool is_matrix = ndim == 2 && sizes[1] != 0;
  TORCH_CHECK(
      is_matrix || is_vector || is_scalar,
      "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
      sizes);

  const int64_t expected_frames = (ndim == 2) ? sizes[0] : 1;
  TORCH_CHECK(
      target.dim() <= 1 && target.numel() == expected_frames,
      "inconsistent target size, expected ", expected_frames, " but got ",
      target.sizes());
}
} // namespace (anonymous)
// CUDA out-variant of multi_margin_loss.
//
// input_ is one frame (0-d or 1-d scores) or a 2-d (nframe, dim) matrix whose
// batch size may be 0; target_ holds one class index per frame. p must be 1
// or 2. weights_, if defined, must contain exactly dim per-class weights.
// Each frame's loss is its hinge sum divided by dim. With reduction None the
// result has target_'s shape; Sum yields the scalar sum of frame losses and
// Mean additionally divides that by nframe. Kernels are launched on the
// current CUDA stream. Returns out_, resized as needed.
Tensor& multi_margin_loss_cuda_out(
    const Tensor &input_, const Tensor &target_, const Scalar &p_, const Scalar &margin_,
    const c10::optional<Tensor> &weights_, int64_t reduction, Tensor& out_) {
  auto p = p_.toLong();
  TORCH_CHECK(p == 1 || p == 2, "multi_margin_loss: Invalid p, expected 1 or 2 but got ", p);
  multi_margin_loss_shape_check(input_, target_);

  // shape_check guarantees input_ has at most two dimensions.
  const int64_t ndims = input_.dim();
  const int64_t nframe = ndims == 2 ? input_.size(0) : 1;
  const int64_t dim = ndims == 0 ? 1 : input_.size(ndims - 1);
  // The kernels take int sizes and are launched with one block per frame.
  TORCH_CHECK(
      nframe <= std::numeric_limits<int>::max() && dim <= std::numeric_limits<int>::max(),
      "multi_margin_loss: input sizes must fit in int32, but got ", input_.sizes());

  const bool has_weights = weights_ && weights_->defined();
  if (has_weights) {
    // The kernel reads weights[target[k]]; a shorter tensor would be read out of bounds.
    TORCH_CHECK(
        weights_->dim() <= 1 && weights_->numel() == dim,
        "inconsistent weight size, expected ", dim, " but got ", weights_->sizes());
  }

  if (reduction == at::Reduction::None) {
    resize_output(out_, target_.sizes());
  } else {
    // Sum and Mean always produce a scalar, for 2-d input as well.
    resize_output(out_, {});
  }
  if (input_.numel() == 0) {
    // Only a 2-d input with zero frames reaches here (shape_check rejects all
    // other empty inputs). Define the reduced result rather than returning
    // uninitialized memory: the empty sum is 0 and the empty mean is NaN.
    if (reduction == at::Reduction::Sum) {
      out_.zero_();
    } else if (reduction == at::Reduction::Mean) {
      out_.fill_(std::numeric_limits<double>::quiet_NaN());
    }
    return out_;
  }

  auto input = input_.contiguous();
  auto target = target_.contiguous();
  Tensor weights;
  if (has_weights) {
    weights = weights_->contiguous();
  }
  auto out = (out_.is_contiguous() ? out_ :
      at::empty(out_.sizes(), input.options()));

  // For 2-d input with Sum/Mean the kernel writes one loss per frame into a
  // temporary that is summed into `out` afterwards; otherwise it writes
  // `out` directly.
  const bool sum_frames = ndims == 2 && reduction != at::Reduction::None;
  Tensor frame_losses = sum_frames ? at::empty({nframe}, input.options()) : out;
  const bool size_average = reduction == at::Reduction::Mean;

  const auto stream = c10::cuda::getCurrentCUDAStream();
  const dim3 blocks(static_cast<unsigned int>(nframe));
  const dim3 threads(MULTIMARGIN_THREADS);
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "multi_margin_loss_cuda", [&] {
    const scalar_t margin = margin_.to<scalar_t>();
    scalar_t *out_ptr = frame_losses.data_ptr<scalar_t>();
    scalar_t *input_ptr = input.data_ptr<scalar_t>();
    int64_t *target_ptr = target.data_ptr<int64_t>();
    scalar_t *weights_ptr = weights.defined() ? weights.data_ptr<scalar_t>() : nullptr;
    if (p == 1) {
      MultiMarginLoss_forward_kernel<1> <<<blocks, threads, 0, stream>>>(
          out_ptr, input_ptr, target_ptr, weights_ptr,
          static_cast<int>(nframe), static_cast<int>(dim), size_average, margin);
    } else {
      MultiMarginLoss_forward_kernel<2> <<<blocks, threads, 0, stream>>>(
          out_ptr, input_ptr, target_ptr, weights_ptr,
          static_cast<int>(nframe), static_cast<int>(dim), size_average, margin);
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });

  if (sum_frames) {
    at::sum_out(out, frame_losses, /*dims=*/IntArrayRef{});
  }
  if (!out.is_alias_of(out_)) {
    out_.copy_(out);
  }
  return out_;
}
// Functional CUDA entry point for multi_margin_loss. It allocates a result
// tensor with input's options and fills it through multi_margin_loss_cuda_out,
// which resizes the result to the shape required by `reduction`.
Tensor multi_margin_loss_cuda(
    const Tensor &input, const Tensor &target, const Scalar &p, const Scalar &margin,
    const c10::optional<Tensor> &weights, int64_t reduction) {
  Tensor result = at::empty({}, input.options());
  return multi_margin_loss_cuda_out(input, target, p, margin, weights, reduction, result);
}
// CUDA out-variant of the multi_margin_loss backward pass.
//
// Writes d(loss)/d(input_) into grad_input_, which is resized to input_'s
// shape. input_, target_, p_, margin_, weights_ and reduction follow the same
// rules as the forward out-variant. grad_output_ must hold exactly one element
// when the forward was reduced (Sum/Mean) and one element per frame for
// reduction None. Kernels are launched on the current CUDA stream.
// Returns grad_input_.
Tensor& multi_margin_loss_cuda_backward_out(
    const Tensor &grad_output_,const Tensor &input_, const Tensor &target_,
    const Scalar &p_, const Scalar &margin_, const c10::optional<Tensor> &weights_,
    int64_t reduction, Tensor &grad_input_) {
  auto p = p_.toLong();
  TORCH_CHECK(p == 1 || p == 2,
      "multi_margin_loss_backward: Invalid p, expected 1 or 2 but got ", p);
  multi_margin_loss_shape_check(input_, target_);

  // shape_check guarantees input_ has at most two dimensions.
  const int64_t ndims = input_.dim();
  const int64_t nframe = ndims == 2 ? input_.size(0) : 1;
  const int64_t dim = ndims == 0 ? 1 : input_.size(ndims - 1);
  // The kernels take int sizes and are launched with one block per frame.
  TORCH_CHECK(
      nframe <= std::numeric_limits<int>::max() && dim <= std::numeric_limits<int>::max(),
      "multi_margin_loss_backward: input sizes must fit in int32, but got ", input_.sizes());

  const bool has_weights = weights_ && weights_->defined();
  if (has_weights) {
    // The kernel reads weights[target[k]]; a shorter tensor would be read out of bounds.
    TORCH_CHECK(
        weights_->dim() <= 1 && weights_->numel() == dim,
        "inconsistent weight size, expected ", dim, " but got ", weights_->sizes());
  }

  resize_output(grad_input_, input_.sizes());
  if (input_.numel() == 0) {
    return grad_input_;
  }

  const bool reduce = reduction != at::Reduction::None;
  // The kernel reads gradOutput[0] when reduced and gradOutput[k] per frame otherwise.
  const int64_t expected_grad_numel = reduce ? 1 : nframe;
  TORCH_CHECK(
      grad_output_.numel() == expected_grad_numel,
      "multi_margin_loss_backward: expected grad_output with ", expected_grad_numel,
      " element(s), but got ", grad_output_.sizes());

  auto input = input_.contiguous();
  auto grad_input = (grad_input_.is_contiguous() ? grad_input_ :
      at::empty(grad_input_.sizes(), input.options()));
  auto grad_output = grad_output_.contiguous();
  auto target = target_.contiguous();
  Tensor weights;
  if (has_weights) {
    weights = weights_->contiguous();
  }

  const bool size_average = reduction == at::Reduction::Mean;
  const auto stream = c10::cuda::getCurrentCUDAStream();
  const dim3 blocks(static_cast<unsigned int>(nframe));
  const dim3 threads(MULTIMARGIN_THREADS);
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
      "multi_margin_loss_backward_cuda", [&] {
    const scalar_t margin = margin_.to<scalar_t>();
    scalar_t *grad_input_ptr = grad_input.data_ptr<scalar_t>();
    scalar_t *grad_output_ptr = grad_output.data_ptr<scalar_t>();
    scalar_t *input_ptr = input.data_ptr<scalar_t>();
    int64_t *target_ptr = target.data_ptr<int64_t>();
    scalar_t *weights_ptr = weights.defined() ? weights.data_ptr<scalar_t>() : nullptr;
    if (p == 1) {
      MultiMarginLoss_backward_kernel<1> <<<blocks, threads, 0, stream>>>(
          grad_input_ptr, grad_output_ptr, input_ptr, target_ptr, weights_ptr,
          static_cast<int>(nframe), static_cast<int>(dim), size_average, margin, reduce);
    } else {
      MultiMarginLoss_backward_kernel<2> <<<blocks, threads, 0, stream>>>(
          grad_input_ptr, grad_output_ptr, input_ptr, target_ptr, weights_ptr,
          static_cast<int>(nframe), static_cast<int>(dim), size_average, margin, reduce);
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });

  if (!grad_input.is_alias_of(grad_input_)) {
    grad_input_.copy_(grad_input);
  }
  return grad_input_;
}
// Functional CUDA entry point for the multi_margin_loss backward pass. It
// allocates grad_input with input's options and delegates to
// multi_margin_loss_cuda_backward_out, which resizes it to input's shape.
Tensor multi_margin_loss_cuda_backward(
    const Tensor &grad_output, const Tensor &input, const Tensor &target,
    const Scalar &p, const Scalar &margin, const c10::optional<Tensor> &weights,
    int64_t reduction) {
  Tensor result = at::empty({}, input.options());
  return multi_margin_loss_cuda_backward_out(
      grad_output, input, target, p, margin, weights, reduction, result);
}
}} // namespace at::native