forked from torch/cunn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Threshold.cu
117 lines (95 loc) · 3.85 KB
/
Threshold.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include "THCApply.cuh"
#include "utils.h"
// Out-of-place forward functor for nn.Threshold:
//   out = (in > threshold) ? in : val
// The input element is only read, and applying the functor never mutates
// it, so operator() is const and `in` is a pointer-to-const (consistent
// with the backward functors in this file).
struct ThresholdUpdateOutput {
  const float threshold_;
  const float val_;

  ThresholdUpdateOutput(float threshold, float val): threshold_(threshold),
                                                     val_(val) {}

  __device__ __forceinline__ void operator()(float* out, const float* in) const {
    const float x = *in;
    *out = (x > threshold_) ? x : val_;
  }
};
// in-place variant
// In-place forward functor for nn.Threshold:
//   x = (x > threshold) ? x : val
// operator() is const (the functor itself is never mutated), matching the
// backward functors in this file. The element is read once into a local.
struct ThresholdUpdateOutputIP {
  const float threshold_;
  const float val_;

  ThresholdUpdateOutputIP(float threshold, float val): threshold_(threshold),
                                                       val_(val) {}

  __device__ __forceinline__ void operator()(float* x) const {
    const float v = *x;
    *x = (v > threshold_) ? v : val_;
  }
};
// Lua binding: forward pass of nn.Threshold on torch.CudaTensor data.
//
// Lua stack:
//   1 - module table; fields read: output (torch.CudaTensor), val,
//       threshold (numbers), inplace (boolean)
//   2 - input (torch.CudaTensor)
//
// Computes output[i] = (input[i] > threshold) ? input[i] : val.
//   - inplace:   input is overwritten element-wise, then output is set to
//                refer to input via THCudaTensor_set.
//   - otherwise: output is resized like input and written element-wise.
//
// Returns 1 to Lua. NOTE(review): the Lua wrapper presumably reads
// self.output rather than the returned stack value -- verify.
static int cunn_Threshold_updateOutput(lua_State *L)
{
  THCState *state = getCutorchState(L);
  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
  double val = luaT_getfieldchecknumber(L, 1, "val");
  double threshold = luaT_getfieldchecknumber(L, 1, "threshold");
  bool inPlace = luaT_getfieldcheckboolean(L, 1, "inplace");
  THAssert(THCudaTensor_checkGPU(state, 2, input, output));

  // THCudaTensor_pointwiseApply* report through their bool result whether
  // the op was actually applied (presumably false on element-count or
  // dimension-limit problems -- see THCApply.cuh). Ignoring a false
  // result would silently leave output stale or uninitialised.
  if (inPlace) {
    bool applied = THCudaTensor_pointwiseApply1(state, input,
                                                ThresholdUpdateOutputIP(threshold, val));
    THAssert(applied);
    THCudaTensor_set(state, output, input);
  } else {
    THCudaTensor_resizeAs(state, output, input);
    bool applied = THCudaTensor_pointwiseApply2(state, output, input,
                                                ThresholdUpdateOutput(threshold, val));
    THAssert(applied);
  }
  // Surface any kernel-launch error from the apply above.
  THCudaCheck(cudaGetLastError());
  return 1;
}
// Out-of-place backward functor for nn.Threshold.
// Where the forward input exceeded the threshold, the incoming gradient is
// passed through unchanged. Everywhere else the gradient is zero.
struct ThresholdUpdateGradInput
{
  const float threshold_;

  ThresholdUpdateGradInput(float threshold) : threshold_(threshold) {}

  __device__ __forceinline__ void operator()(float* gradInput, float* input,
                                             float* gradOutput) const {
    if (*input > threshold_) {
      *gradInput = *gradOutput;
    } else {
      *gradInput = 0.0f;
    }
  }
};
// In-place backward functor for nn.Threshold.
// The gradient buffer is reused as the result: each element is kept where
// the corresponding input exceeds the threshold and zeroed otherwise.
struct ThresholdUpdateGradInputIP
{
  const float threshold_;

  ThresholdUpdateGradInputIP(float threshold) : threshold_(threshold) {}

  __device__ __forceinline__ void operator()(float* gradOutput,
                                             float* input) const {
    float result = 0.0f;
    if (*input > threshold_) {
      result = *gradOutput;
    }
    *gradOutput = result;
  }
};
// Lua binding: backward pass of nn.Threshold on torch.CudaTensor data.
//
// Lua stack:
//   1 - module table; fields read: output, gradInput (torch.CudaTensor),
//       val, threshold (numbers), inplace (boolean)
//   2 - input (torch.CudaTensor)
//   3 - gradOutput (torch.CudaTensor)
//
// Computes gradInput[i] = (input[i] > threshold) ? gradOutput[i] : 0.
//   - inplace:   gradOutput is overwritten element-wise, then gradInput is
//                set to refer to gradOutput via THCudaTensor_set.
//   - otherwise: gradInput is resized like input and written element-wise.
//
// Returns 1 to Lua. NOTE(review): the Lua wrapper presumably reads
// self.gradInput rather than the returned stack value -- verify.
static int cunn_Threshold_updateGradInput(lua_State *L)
{
  THCState *state = getCutorchState(L);
  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
  THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
  // `val` does not enter the gradient. The field is still read so that any
  // validation luaT_getfieldchecknumber performs on self.val is preserved.
  double val = luaT_getfieldchecknumber(L, 1, "val");
  (void)val;
  double threshold = luaT_getfieldchecknumber(L, 1, "threshold");
  bool inPlace = luaT_getfieldcheckboolean(L, 1, "inplace");
  THAssert(THCudaTensor_checkGPU(state, 4, input, output, gradInput, gradOutput));

  // THCudaTensor_pointwiseApply* report through their bool result whether
  // the op was actually applied (presumably false when e.g. gradOutput and
  // input differ in element count -- see THCApply.cuh). Ignoring a false
  // result would silently hand back an unwritten gradInput.
  if (inPlace) {
    // NOTE(review): if the forward pass also ran in place, `input` now holds
    // thresholded values, so this mask is only correct when val <= threshold.
    // Presumably the Lua module enforces that -- verify.
    bool applied = THCudaTensor_pointwiseApply2(state, gradOutput, input,
                                                ThresholdUpdateGradInputIP(threshold));
    THAssert(applied);
    THCudaTensor_set(state, gradInput, gradOutput);
  } else {
    // gradInput must match the tensor it is the gradient of. `output`
    // has this shape only if the forward pass ran on this exact input.
    THCudaTensor_resizeAs(state, gradInput, input);
    bool applied = THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput,
                                                ThresholdUpdateGradInput(threshold));
    THAssert(applied);
  }
  // Surface any kernel-launch error from the apply above.
  THCudaCheck(cudaGetLastError());
  return 1;
}
// Lua-visible name -> C implementation for the Threshold module.
// The trailing {NULL, NULL} entry is the sentinel that terminates the list.
static const luaL_Reg cunn_Threshold__[] = {
  { "Threshold_updateOutput",    cunn_Threshold_updateOutput    },
  { "Threshold_updateGradInput", cunn_Threshold_updateGradInput },
  { NULL,                        NULL                           }
};
// Registers the Threshold bindings listed in cunn_Threshold__.
// The torch.CudaTensor metatable is pushed so that luaT_registeratname can
// attach the functions under the name "nn". Presumably they become fields
// of a table stored at that key in the metatable -- verify against luaT.
void cunn_Threshold_init(lua_State *L)
{
  const char *tensorType = "torch.CudaTensor";
  luaT_pushmetatable(L, tensorType);
  luaT_registeratname(L, cunn_Threshold__, "nn");
  // Drop the metatable pushed above, leaving the Lua stack as we found it.
  lua_pop(L, 1);
}