forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 2
/
functional_test.py
122 lines (98 loc) · 4.11 KB
/
functional_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import unittest
from caffe2.python import core
from hypothesis import given
import hypothesis.strategies as st
import caffe2.python.hypothesis_test_util as hu
from caffe2.python import workspace
from caffe2.python.functional import Functional
import numpy as np
@st.composite
def _tensor_splits(draw, add_axis=False):
    """Hypothesis strategy yielding ``(axis, split_info, pieces)`` triples.

    A tensor and one of its axes are drawn, and the tensor is cut along
    that axis.  ``split_info`` is an int32 array holding the extent of
    each piece along ``axis``; ``pieces`` is the list of sub-tensors.

    With ``add_axis=True`` every piece is a single slice with the axis
    removed (one dimension fewer than the drawn tensor) and
    ``split_info`` is all ones.  Otherwise the cut points are drawn at
    random and may repeat, so some pieces can be empty along ``axis``.
    """
    # Each dim has at least 4 elements.
    tensor = draw(hu.tensor(min_value=4))
    axis = draw(st.integers(0, len(tensor.shape) - 1))
    length = tensor.shape[axis]

    if add_axis:
        # Simple case: one slice per index along the axis; the axis is
        # dropped from each slice and is added back upon concatenation.
        slices = [np.array(tensor.take(k, axis=axis)) for k in range(length)]
        return axis, np.ones(length, dtype=np.int32), slices

    # General case: draw up to four cut points (possibly consecutive, even
    # non-unique) and always include both ends of the axis.
    drawn = draw(st.lists(elements=st.integers(0, length), max_size=4))
    bounds = sorted(drawn + [0, length])
    sizes = np.array(np.diff(bounds), dtype=np.int32)
    pieces = [
        tensor.take(range(start, stop), axis=axis)
        for start, stop in zip(bounds[:-1], bounds[1:])
    ]
    return axis, sizes, pieces
class TestFunctional(hu.HypothesisTestCase):
    """Checks operators called through ``Functional`` against references.

    Relu is compared with the same operator run directly in a workspace;
    Concat and Split are compared with NumPy-based reference results.
    """

    @given(X=hu.tensor(), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs)
    def test_relu(self, X, engine, gc, dc):
        """``Functional.Relu`` must equal a Relu operator run in a workspace."""
        # Move every element at least 0.02 away from zero (exact zeros
        # become 0.02) -- presumably to stay clear of Relu's kink at 0.
        X += 0.02 * np.sign(X)
        X[X == 0.0] += 0.02
        output = Functional.Relu(X, device_option=gc)
        # The same output is read both by position and by name.
        Y_l = output[0]
        Y_d = output["output_0"]

        with workspace.WorkspaceGuard("tmp_workspace"):
            op = core.CreateOperator("Relu", ["X"], ["Y"], engine=engine)
            workspace.FeedBlob("X", X)
            workspace.RunOperatorOnce(op)
            Y_ref = workspace.FetchBlob("Y")

        np.testing.assert_array_equal(
            Y_l, Y_ref, err_msg='Functional Relu result mismatch'
        )
        np.testing.assert_array_equal(
            Y_d, Y_ref, err_msg='Functional Relu result mismatch'
        )

    @given(tensor_splits=_tensor_splits(), **hu.gcs)
    def test_concat(self, tensor_splits, gc, dc):
        """``Functional.Concat`` must agree with ``np.concatenate``."""
        # Input Size: 1 -> inf
        axis, _, splits = tensor_splits
        concat_result, split_info = Functional.Concat(
            *splits, axis=axis, device_option=gc
        )

        concat_result_ref = np.concatenate(splits, axis=axis)
        # Expected second output: extent of every input along ``axis``.
        split_info_ref = np.array([a.shape[axis] for a in splits])

        np.testing.assert_array_equal(
            concat_result,
            concat_result_ref,
            err_msg='Functional Concat result mismatch'
        )
        np.testing.assert_array_equal(
            split_info,
            split_info_ref,
            err_msg='Functional Concat split info mismatch'
        )

    @given(tensor_splits=_tensor_splits(), split_as_arg=st.booleans(), **hu.gcs)
    def test_split(self, tensor_splits, split_as_arg, gc, dc):
        """``Functional.Split`` must reproduce the pieces that were joined."""
        # Output Size: 1 - inf
        axis, split_info, splits = tensor_splits

        # NOTE(review): this override discards the value drawn by
        # Hypothesis, so the ``else`` branch below (split sizes passed as a
        # second input tensor) never runs. Kept as-is because it may be a
        # deliberate workaround -- confirm before removing.
        split_as_arg = True

        if split_as_arg:
            input_tensors = [np.concatenate(splits, axis=axis)]
            kwargs = dict(axis=axis, split=split_info, num_output=len(splits))
        else:
            input_tensors = [np.concatenate(splits, axis=axis), split_info]
            kwargs = dict(axis=axis, num_output=len(splits))
        result = Functional.Split(*input_tensors, device_option=gc, **kwargs)

        def split_ref(data, split=split_info):
            # Cumulative offsets give the [start, stop) range of each piece.
            s = np.cumsum([0] + list(split))
            return [
                np.array(data.take(np.arange(s[i], s[i + 1]), axis=axis))
                for i in range(len(split))
            ]

        result_ref = split_ref(*input_tensors)
        for i, ref in enumerate(result_ref):
            np.testing.assert_array_equal(
                result[i], ref, err_msg='Functional Split result mismatch'
            )
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()