forked from openmlsys/openmlsys-cuda
-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.cuh
62 lines (50 loc) · 1.85 KB
/
util.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#ifndef GEMM_UTIL_CUH
#define GEMM_UTIL_CUH
namespace openmlsys {
// Compile-time bundle of three integer extents, published as the static
// constants `m`, `n` and `k`. The third extent may be omitted, in which case
// it is 1.
template <int DimM, int DimN, int DimK = 1>
struct Layout {
  static constexpr int m{DimM};  // first template argument
  static constexpr int n{DimN};  // second template argument
  static constexpr int k{DimK};  // third template argument (1 when omitted)
};
// Four packed floats, declared with 16-byte alignment, with element access
// and the two arithmetic operators defined below. All members are callable
// from both host and device code.
struct __device_builtin__ __builtin_align__(16) float4 {
  float data[4];

  // Read-only access to component `idx`. No range check is performed.
  __host__ __device__ float operator[](unsigned idx) const { return data[idx]; }

  // Mutable access to component `idx`. No range check is performed.
  __host__ __device__ float &operator[](unsigned idx) { return data[idx]; }

  // Returns a copy of this vector with every component multiplied by `other`.
  __host__ __device__ float4 operator*(float other) const {
    float4 scaled = *this;
    scaled.data[0] *= other;
    scaled.data[1] *= other;
    scaled.data[2] *= other;
    scaled.data[3] *= other;
    return scaled;
  }

  // Returns the component-wise sum of this vector and `other`.
  __host__ __device__ float4 operator+(const float4 &other) const {
    float4 sum = *this;
    sum.data[0] += other.data[0];
    sum.data[1] += other.data[1];
    sum.data[2] += other.data[2];
    sum.data[3] += other.data[3];
    return sum;
  }
};
// View over a buffer of `T` addressed with (row, col) coordinates.
//
// Element (row, col) maps to
//   ptr[(row + _rowOffset) * cols + (col + _colOffset)]
// so `cols` doubles as the distance between consecutive rows. The view keeps
// a running (row, col) origin that `addOffset` moves. Nothing in this struct
// allocates or frees the buffer, and `operator()` performs no bounds check;
// use the `valid*Offset` predicates for that.
template <typename T>
struct __device_builtin__ Tensor2D {
  T *const __restrict__ ptr;          // base of the underlying buffer
  const unsigned rows, cols;          // extents, in elements of T
  int _rowOffset{0}, _colOffset{0};   // current origin of the view

  // Builds a view over `ptr`, reinterpreted as `T *`, with the given extents.
  // `t` is whatever pointer-like type the caller passes in.
  template <typename t>
  __host__ __device__ Tensor2D(t &&ptr, unsigned rows, unsigned cols)
      : ptr{reinterpret_cast<T *>(ptr)}, rows{rows}, cols{cols} {}

  // Moves the view's origin by `rowOffset` rows and `colOffset` columns.
  // `colOffset` is expressed in units of `t` (defaults to `T`) and is
  // converted to units of `T` via the ratio sizeof(t) / sizeof(T).
  //
  // The conversion is done in signed 64-bit arithmetic: multiplying the signed
  // offset by `sizeof` directly would promote it to an unsigned type and turn a
  // negative offset into a huge positive value before the division. The
  // quotient truncates toward zero when the byte offset is not a multiple of
  // sizeof(T).
  template <typename t = T>
  __host__ __device__ void addOffset(int rowOffset, int colOffset) {
    _rowOffset += rowOffset;
    const long long scaledCols = static_cast<long long>(colOffset) *
                                 static_cast<long long>(sizeof(t)) /
                                 static_cast<long long>(sizeof(T));
    _colOffset += static_cast<int>(scaledCols);
  }

  // True when `rowOffset` rows past the current origin is inside [0, rows).
  __host__ __device__ bool validRowOffset(int rowOffset) const {
    const int row = _rowOffset + rowOffset;
    // Test the sign first so a negative row never reaches the unsigned compare.
    return row >= 0 && static_cast<unsigned>(row) < rows;
  }

  // True when `colOffset` columns past the current origin is inside [0, cols).
  __host__ __device__ bool validColOffset(int colOffset) const {
    const int col = _colOffset + colOffset;
    return col >= 0 && static_cast<unsigned>(col) < cols;
  }

  // True when both the row and the column offset are inside the extents.
  __host__ __device__ bool validOffset(int rowOffset, int colOffset) const {
    return validRowOffset(rowOffset) && validColOffset(colOffset);
  }

  // Returns a reference to the element at (row, col) relative to the current
  // origin. The linear index is formed in 64-bit so that buffers with more
  // than 2^32 elements are not addressed modulo 2^32. No bounds check.
  __host__ __device__ T &operator()(int row, int col) const {
    const long long linear =
        static_cast<long long>(row + _rowOffset) * static_cast<long long>(cols) +
        static_cast<long long>(col + _colOffset);
    return ptr[linear];
  }
};
} // namespace openmlsys
#endif // GEMM_UTIL_CUH