Skip to content

Commit

Permalink
fix(frontend-python): test_matmul.test_matmul, randomly failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
rudy-6-4 committed Aug 2, 2024
1 parent 2968e80 commit 513e7dd
Showing 1 changed file with 137 additions and 131 deletions.
268 changes: 137 additions & 131 deletions frontends/concrete-python/tests/execution/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,127 +168,135 @@ def rhs_function(x):
helpers.check_execution(rhs_function_circuit, rhs_function, rhs_sample)


test_matmul_shape_and_bounds = [
(
(3, 2),
(2, 3),
(0, 3),
),
(
(3, 2),
(2, 3),
(0, 127),
),
(
(1, 2),
(2, 1),
(0, 3),
),
(
(3, 3),
(3, 3),
(0, 3),
),
(
(2, 1),
(1, 2),
(0, 7),
),
(
(2,),
(2,),
(0, 7),
),
(
(5, 5),
(5,),
(0, 3),
),
(
(5,),
(5, 5),
(0, 3),
),
(
(5,),
(5, 5),
(-63, 63),
),
(
(2,),
(2, 7),
(-63, 0),
),
(
(5,),
(5, 3),
(0, 3),
),
(
(5, 3),
(3,),
(0, 3),
),
(
(5,),
(4, 5, 3),
(-5, 5),
),
(
(4, 5, 3),
(3,),
(0, 5),
),
(
(5,),
(2, 4, 5, 3),
(0, 5),
),
(
(2, 4, 5, 3),
(3,),
(-1, 5),
),
(
(5, 4, 3),
(3, 2),
(0, 5),
),
(
(4, 3),
(5, 3, 2),
(0, 5),
),
(
(2, 5, 4, 3),
(3, 2),
(0, 5),
),
(
(5, 4, 3),
(1, 3, 2),
(0, 5),
),
(
(1, 4, 3),
(5, 3, 2),
(0, 5),
),
(
(5, 4, 3),
(2, 1, 3, 2),
(0, 5),
),
(
(2, 1, 4, 3),
(5, 3, 2),
(0, 5),
),
]

@pytest.mark.parametrize(
"lhs_shape,rhs_shape,bounds",
"lhs_shape,rhs_shape,bounds,clear_rhs",
[
pytest.param(
(3, 2),
(2, 3),
(0, 3),
),
pytest.param(
(3, 2),
(2, 3),
(0, 127),
),
pytest.param(
(1, 2),
(2, 1),
(0, 3),
),
pytest.param(
(3, 3),
(3, 3),
(0, 3),
),
pytest.param(
(2, 1),
(1, 2),
(0, 7),
),
pytest.param(
(2,),
(2,),
(0, 7),
),
pytest.param(
(5, 5),
(5,),
(0, 3),
),
pytest.param(
(5,),
(5, 5),
(0, 3),
),
pytest.param(
(5,),
(5, 5),
(-63, 63),
),
pytest.param(
(2,),
(2, 7),
(-63, 0),
),
pytest.param(
(5,),
(5, 3),
(0, 3),
),
pytest.param(
(5, 3),
(3,),
(0, 3),
),
pytest.param(
(5,),
(4, 5, 3),
(-5, 5),
),
pytest.param(
(4, 5, 3),
(3,),
(0, 5),
),
pytest.param(
(5,),
(2, 4, 5, 3),
(0, 5),
),
pytest.param(
(2, 4, 5, 3),
(3,),
(-1, 5),
),
pytest.param(
(5, 4, 3),
(3, 2),
(0, 5),
),
pytest.param(
(4, 3),
(5, 3, 2),
(0, 5),
),
pytest.param(
(2, 5, 4, 3),
(3, 2),
(0, 5),
),
pytest.param(
(5, 4, 3),
(1, 3, 2),
(0, 5),
),
pytest.param(
(1, 4, 3),
(5, 3, 2),
(0, 5),
),
pytest.param(
(5, 4, 3),
(2, 1, 3, 2),
(0, 5),
),
pytest.param(
(2, 1, 4, 3),
(5, 3, 2),
(0, 5),
),
],
(
lhs_shape, rhs_shape, bounds, clear
)
for lhs_shape,rhs_shape,bounds in test_matmul_shape_and_bounds
for clear in [False, True]
]
)
def test_matmul(lhs_shape, rhs_shape, bounds, helpers):
def test_matmul(lhs_shape, rhs_shape, bounds, clear_rhs, helpers):
"""
Test matmul.
"""
Expand All @@ -305,22 +313,20 @@ def clear(x, y):
def encrypted(x, y):
return np.matmul(x, y)

for implementation in [clear, encrypted]:
inputset = [
(
np.random.randint(minimum, maximum, size=lhs_shape),
np.random.randint(minimum, maximum, size=rhs_shape),
)
for i in range(100)
]
circuit = implementation.compile(inputset, configuration)

sample = [
implementation = clear if clear_rhs else encrypted

inputset = [
(
np.random.randint(minimum, maximum, size=lhs_shape),
np.random.randint(minimum, maximum, size=rhs_shape),
]
)
for _ in range(100)
]
circuit = implementation.compile(inputset, configuration)

sample = list(inputset[-1])

helpers.check_execution(circuit, implementation, sample, retries=3)
helpers.check_execution(circuit, implementation, sample, retries=3)


@pytest.mark.parametrize("bit_width", [4, 10])
Expand Down

0 comments on commit 513e7dd

Please sign in to comment.