forked from plaidml/plaidml
-
Notifications
You must be signed in to change notification settings - Fork 0
/
platform_test.cc
95 lines (78 loc) · 2.49 KB
/
platform_test.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// Copyright 2018, Intel Corporation.
#include <gtest/gtest.h>
#include "tile/base/platform_test.h"
#include "tile/platform/local_machine/platform.h"
namespace vertexai {
namespace tile {
namespace testing {
namespace {
Param supported_params[] = {
{DataType::INT8, 1}, //
{DataType::INT8, 2}, //
{DataType::INT8, 4}, //
{DataType::INT8, 8}, //
{DataType::INT8, 16}, //
{DataType::INT16, 1}, //
{DataType::INT16, 2}, //
{DataType::INT16, 4}, //
{DataType::INT16, 8}, //
{DataType::INT16, 16}, //
{DataType::INT32, 1}, //
{DataType::INT32, 2}, //
{DataType::INT32, 4}, //
{DataType::INT32, 8}, //
{DataType::INT32, 16}, //
{DataType::INT64, 1}, //
{DataType::INT64, 2}, //
{DataType::UINT8, 1}, //
{DataType::UINT8, 2}, //
{DataType::UINT8, 4}, //
{DataType::UINT8, 8}, //
{DataType::UINT8, 16}, //
{DataType::UINT16, 1}, //
{DataType::UINT16, 2}, //
{DataType::UINT16, 4}, //
{DataType::UINT16, 8}, //
{DataType::UINT16, 16}, //
{DataType::UINT32, 1}, //
{DataType::UINT32, 2}, //
{DataType::UINT32, 4}, //
{DataType::UINT32, 8}, //
{DataType::UINT32, 16}, //
{DataType::UINT64, 1}, //
{DataType::UINT64, 2}, //
{DataType::FLOAT16, 1}, //
{DataType::FLOAT16, 2}, //
{DataType::FLOAT16, 4}, //
{DataType::FLOAT16, 8}, //
{DataType::FLOAT16, 16}, //
{DataType::FLOAT32, 1}, //
{DataType::FLOAT32, 2}, //
{DataType::FLOAT32, 4}, //
{DataType::FLOAT32, 8}, //
{DataType::FLOAT32, 16}, //
// TODO: enable these tests by querying the target device for
// support of 64-bit floating point types
// {DataType::FLOAT64, 1}, //
// {DataType::FLOAT64, 2}, //
};
std::vector<FactoryParam> SupportedParams() {
std::vector<FactoryParam> params;
for (const Param& param : supported_params) {
auto factory = [param] {
context::Context ctx;
local_machine::proto::Platform config;
auto hw_config = config.add_hardware_configs();
hw_config->mutable_sel()->set_value(true);
hw_config->mutable_settings()->set_vec_size(param.vec_size);
return std::make_unique<local_machine::Platform>(ctx, config);
};
params.push_back({factory, param});
}
return params;
}
INSTANTIATE_TEST_CASE_P(OpenCl, PlatformTest, ::testing::ValuesIn(SupportedParams()));
} // namespace
} // namespace testing
} // namespace tile
} // namespace vertexai