forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
BinaryCompareKernel.cu
113 lines (100 loc) · 3.7 KB
/
BinaryCompareKernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/zmath.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
namespace at { namespace native {
// CUDA "less than" comparison kernel.
//
// Dispatches on iter.common_dtype() through AT_DISPATCH_ALL_TYPES_AND2 with
// kHalf and kBool as the two extra types, then passes a functor computing
// `lhs < rhs` (returned as bool) to gpu_kernel_with_scalars along with `iter`.
void lt_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "lt_cuda", [&]() {
    // scalar_t is not declared here; it comes from the dispatch macro.
    auto less_than = [] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> bool {
      return lhs < rhs;
    };
    gpu_kernel_with_scalars(iter, less_than);
  });
}
// CUDA "less than or equal" comparison kernel.
//
// Dispatches on iter.common_dtype() through AT_DISPATCH_ALL_TYPES_AND2 with
// kHalf and kBool as the two extra types, then passes a functor computing
// `lhs <= rhs` (returned as bool) to gpu_kernel_with_scalars along with `iter`.
void le_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "le_cuda", [&]() {
    // scalar_t is not declared here; it comes from the dispatch macro.
    auto less_equal = [] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> bool {
      return lhs <= rhs;
    };
    gpu_kernel_with_scalars(iter, less_equal);
  });
}
// CUDA "greater than" comparison kernel.
//
// Dispatches on iter.common_dtype() through AT_DISPATCH_ALL_TYPES_AND2 with
// kHalf and kBool as the two extra types, then passes a functor computing
// `lhs > rhs` (returned as bool) to gpu_kernel_with_scalars along with `iter`.
void gt_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "gt_cuda", [&]() {
    // scalar_t is not declared here; it comes from the dispatch macro.
    auto greater_than = [] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> bool {
      return lhs > rhs;
    };
    gpu_kernel_with_scalars(iter, greater_than);
  });
}
// CUDA "greater than or equal" comparison kernel.
//
// Dispatches on iter.common_dtype() through AT_DISPATCH_ALL_TYPES_AND2 with
// kHalf and kBool as the two extra types, then passes a functor computing
// `lhs >= rhs` (returned as bool) to gpu_kernel_with_scalars along with `iter`.
void ge_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "ge_cuda", [&]() {
    // scalar_t is not declared here; it comes from the dispatch macro.
    auto greater_equal = [] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> bool {
      return lhs >= rhs;
    };
    gpu_kernel_with_scalars(iter, greater_equal);
  });
}
// CUDA equality comparison kernel.
//
// Dispatches on iter.common_dtype() through
// AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2 with kHalf and kBool as the two extra
// types. The functor's operands are typed as ztype_cuda<scalar_t>::thrust_t
// rather than scalar_t (NOTE(review): presumably this swaps complex scalar
// types for their thrust counterparts -- confirm against zmath.cuh). The
// functor returns `lhs == rhs` as bool and is passed to
// gpu_kernel_with_scalars along with `iter`.
void eq_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.common_dtype(), "eq_cuda", [&]() {
    using operand_t = typename ztype_cuda<scalar_t>::thrust_t;
    auto is_equal = [] GPU_LAMBDA (operand_t lhs, operand_t rhs) -> bool {
      return lhs == rhs;
    };
    gpu_kernel_with_scalars(iter, is_equal);
  });
}
// CUDA inequality comparison kernel.
//
// Dispatches on iter.common_dtype() through
// AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2 with kHalf and kBool as the two extra
// types. The functor's operands are typed as ztype_cuda<scalar_t>::thrust_t
// rather than scalar_t (NOTE(review): presumably this swaps complex scalar
// types for their thrust counterparts -- confirm against zmath.cuh). The
// functor returns `lhs != rhs` as bool and is passed to
// gpu_kernel_with_scalars along with `iter`.
void ne_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.common_dtype(), "ne_cuda", [&]() {
    using operand_t = typename ztype_cuda<scalar_t>::thrust_t;
    auto is_not_equal = [] GPU_LAMBDA (operand_t lhs, operand_t rhs) -> bool {
      return lhs != rhs;
    };
    gpu_kernel_with_scalars(iter, is_not_equal);
  });
}
// CUDA elementwise maximum of two operands.
//
// The implementation is chosen from iter.dtype():
//   * Bool                -> logical OR of the two operands;
//   * integral (non-bool) -> ::max, dispatched via AT_DISPATCH_INTEGRAL_TYPES;
//   * everything else     -> ::max, dispatched via
//                            AT_DISPATCH_FLOATING_TYPES_AND_HALF.
// Each functor is passed to gpu_kernel along with `iter`.
void max_elementwise_kernel_cuda(TensorIterator& iter) {
  const auto dtype = iter.dtype();

  if (dtype == ScalarType::Bool) {
    // For booleans the larger of two values is their disjunction.
    auto logical_or = [] GPU_LAMBDA (bool lhs, bool rhs) -> bool {
      return lhs || rhs;
    };
    gpu_kernel(iter, logical_or);
    return;
  }

  if (isIntegralType(dtype, /*includeBool=*/ false)) {
    AT_DISPATCH_INTEGRAL_TYPES(dtype, "max_elementwise_cuda", [&]() {
      auto int_max = [] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> scalar_t {
        return ::max(lhs, rhs);
      };
      gpu_kernel(iter, int_max);
    });
    return;
  }

  // Remaining dtypes go through the floating-point (plus Half) dispatch.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "max_elementwise_cuda", [&]() {
    auto float_max = [] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> scalar_t {
      return ::max(lhs, rhs);
    };
    gpu_kernel(iter, float_max);
  });
}
// CUDA elementwise minimum of two operands.
//
// The implementation is chosen from iter.dtype():
//   * Bool                -> logical AND of the two operands;
//   * integral (non-bool) -> ::min, dispatched via AT_DISPATCH_INTEGRAL_TYPES;
//   * everything else     -> ::min, dispatched via
//                            AT_DISPATCH_FLOATING_TYPES_AND_HALF.
// Each functor is passed to gpu_kernel along with `iter`.
void min_elementwise_kernel_cuda(TensorIterator& iter) {
  const auto dtype = iter.dtype();

  if (dtype == ScalarType::Bool) {
    // For booleans the smaller of two values is their conjunction.
    auto logical_and = [] GPU_LAMBDA (bool lhs, bool rhs) -> bool {
      return lhs && rhs;
    };
    gpu_kernel(iter, logical_and);
    return;
  }

  if (isIntegralType(dtype, /*includeBool=*/ false)) {
    AT_DISPATCH_INTEGRAL_TYPES(dtype, "min_elementwise_cuda", [&]() {
      auto int_min = [] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> scalar_t {
        return ::min(lhs, rhs);
      };
      gpu_kernel(iter, int_min);
    });
    return;
  }

  // Remaining dtypes go through the floating-point (plus Half) dispatch.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "min_elementwise_cuda", [&]() {
    auto float_min = [] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> scalar_t {
      return ::min(lhs, rhs);
    };
    gpu_kernel(iter, float_min);
  });
}
// Bind each kernel defined in this file to its dispatch stub. The second
// argument of every REGISTER_DISPATCH call is the address of the kernel
// function, so it must be written with a leading `&` (e.g. `&lt_kernel_cuda`).
REGISTER_DISPATCH(lt_stub, &lt_kernel_cuda);
REGISTER_DISPATCH(le_stub, &le_kernel_cuda);
REGISTER_DISPATCH(gt_stub, &gt_kernel_cuda);
REGISTER_DISPATCH(ge_stub, &ge_kernel_cuda);
REGISTER_DISPATCH(eq_stub, &eq_kernel_cuda);
REGISTER_DISPATCH(ne_stub, &ne_kernel_cuda);
REGISTER_DISPATCH(max_elementwise_stub, &max_elementwise_kernel_cuda);
REGISTER_DISPATCH(min_elementwise_stub, &min_elementwise_kernel_cuda);
}} // namespace at::native