forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
operations.cpp
57 lines (51 loc) · 1.57 KB
/
operations.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include <initializer_list>

#include <gtest/gtest.h>
#include <torch/torch.h>

#include <test/cpp/api/support.h>
// Shared fixture for the operation tests in this file. It derives from
// torch::test::SeedingFixture (declared in test/cpp/api/support.h);
// NOTE(review): presumably that base seeds torch's RNG so the torch::rand
// inputs below are reproducible — confirm against support.h.
struct OperationTest : torch::test::SeedingFixture {
 protected:
  // Replaces the base fixture's SetUp() with a no-op.
  // NOTE(review): if SeedingFixture does its seeding in SetUp() rather than in
  // its constructor, this override would skip it — confirm against support.h.
  void SetUp() override {}

  // Number of randomized repetitions each test performs. A compile-time
  // constant rather than a per-instance const data member.
  static constexpr int TEST_AMOUNT = 10;
};
// Checks torch::lerp against the reference formula
//   start + weight * (end - start)
// for both a scalar weight and an element-wise tensor weight, on fresh random
// 3x5 inputs in every repetition.
TEST_F(OperationTest, Lerp) {
  for (auto i = 0; i < TEST_AMOUNT; i++) {
    auto start = torch::rand({3, 5});
    auto end = torch::rand({3, 5});

    // test lerp_kernel_scalar
    // Several weights on both sides of 0.5 are checked instead of only 0.5.
    // NOTE(review): presumably the kernel may evaluate the interpolation
    // differently depending on the weight, so a single value could leave
    // part of it unexercised.
    for (const double scalar : {0.25, 0.5, 0.75}) {
      auto scalar_expected = start + scalar * (end - start);
      auto scalar_out = torch::lerp(start, end, scalar);
      ASSERT_EQ(scalar_out.dtype(), scalar_expected.dtype());
      // allclose() lets its inputs broadcast, so pin the shape explicitly.
      ASSERT_EQ(scalar_out.sizes(), scalar_expected.sizes());
      ASSERT_TRUE(scalar_out.allclose(scalar_expected));
    }

    // test lerp_kernel_tensor
    auto weight = torch::rand({3, 5});
    auto tensor_expected = start + weight * (end - start);
    auto tensor_out = torch::lerp(start, end, weight);
    ASSERT_EQ(tensor_out.dtype(), tensor_expected.dtype());
    ASSERT_EQ(tensor_out.sizes(), tensor_expected.sizes());
    ASSERT_TRUE(tensor_out.allclose(tensor_expected));
  }
}
// Checks torch::cross on a batch of row vectors against a cross product
// computed explicitly, component by component, for each row:
//   u x v = (u2*v3 - u3*v2, u3*v1 - u1*v3, u1*v2 - u2*v1)
TEST_F(OperationTest, Cross) {
  // Number of 3-component row vectors per input tensor.
  constexpr int kRows = 10;
  for (auto i = 0; i < TEST_AMOUNT; i++) {
    // input: one 3-vector per row
    auto a = torch::rand({kRows, 3});
    auto b = torch::rand({kRows, 3});

    // expected: filled row by row. Named `expected` rather than `exp` to
    // avoid shadowing std::exp.
    auto expected = torch::empty({kRows, 3});
    for (int j = 0; j < kRows; j++) {
      auto u1 = a[j][0], u2 = a[j][1], u3 = a[j][2];
      auto v1 = b[j][0], v2 = b[j][1], v3 = b[j][2];
      expected[j][0] = u2 * v3 - v2 * u3;
      expected[j][1] = v1 * u3 - u1 * v3;
      expected[j][2] = u1 * v2 - v1 * u2;
    }

    // actual: pass dim explicitly. Without it, cross infers the first
    // dimension of size 3, which is deprecated and ambiguous for inputs that
    // have more than one such dimension.
    auto out = torch::cross(a, b, /*dim=*/1);

    // compare (allclose() broadcasts, so check the shape explicitly too)
    ASSERT_EQ(out.dtype(), expected.dtype());
    ASSERT_EQ(out.sizes(), expected.sizes());
    ASSERT_TRUE(out.allclose(expected));
  }
}