forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Lerp.cpp
56 lines (46 loc) · 1.9 KB
/
Lerp.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include <ATen/native/Lerp.h>

#include <algorithm>

#include <ATen/ATen.h>
#include <ATen/CPUApplyUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
namespace at {
namespace native {
Tensor& lerp_cpu_tensor_out(Tensor& result, const Tensor& self,
const Tensor& end, const Tensor& weight) {
TORCH_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
"weight should be of dimension max(self.dim(), end.dim()) or lesser");
lerp_kernel_tensor_weight(kCPU, result, self, end, weight);
return result;
}
Tensor& lerp_cpu_scalar_out(Tensor& result, const Tensor& self,
const Tensor& end, Scalar weight) {
lerp_kernel_scalar_weight(kCPU, result, self, end, weight);
return result;
}
Tensor& lerp_cpu_tensor_(Tensor& self, const Tensor& end, const Tensor& weight) {
TORCH_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
"weight should be of dimension max(self.dim(), end.dim()) or lesser");
lerp_kernel_tensor_weight(kCPU, self, self, end, weight);
return self;
}
Tensor& lerp_cpu_scalar_(Tensor& self, const Tensor& end, Scalar weight) {
lerp_kernel_scalar_weight(kCPU, self, self, end, weight);
return self;
}
Tensor lerp_cpu_tensor(const Tensor& self, const Tensor& end, const Tensor& weight) {
TORCH_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
"weight should be of dimension max(self.dim(), end.dim()) or lesser");
Tensor result = at::empty({0}, self.options());
lerp_kernel_tensor_weight(kCPU, result, self, end, weight);
return result;
}
Tensor lerp_cpu_scalar(const Tensor& self, const Tensor& end, Scalar weight) {
Tensor result = at::empty({0}, self.options());
lerp_kernel_scalar_weight(kCPU, result, self, end, weight);
return result;
}
// Definitions of the two dispatch stubs invoked (with kCPU) by the lerp_cpu_*
// wrappers in this file. NOTE(review): the matching declarations are presumably
// DECLARE_DISPATCH entries in ATen/native/Lerp.h, and the per-device kernels are
// registered elsewhere (REGISTER_DISPATCH) — confirm against those files.
DEFINE_DISPATCH(lerp_kernel_scalar_weight);
DEFINE_DISPATCH(lerp_kernel_tensor_weight);
} // namespace native
} // namespace at