From 513e7dd3bece62a3bcf6f2a07cbbf7a42101b138 Mon Sep 17 00:00:00 2001 From: rudy Date: Fri, 2 Aug 2024 10:20:10 +0200 Subject: [PATCH] fix(frontend-python): test_matmul.test_matmul, randomly failing test --- .../tests/execution/test_matmul.py | 268 +++++++++--------- 1 file changed, 137 insertions(+), 131 deletions(-) diff --git a/frontends/concrete-python/tests/execution/test_matmul.py b/frontends/concrete-python/tests/execution/test_matmul.py index a5836b2bd6..a7def594f9 100644 --- a/frontends/concrete-python/tests/execution/test_matmul.py +++ b/frontends/concrete-python/tests/execution/test_matmul.py @@ -168,127 +168,135 @@ def rhs_function(x): helpers.check_execution(rhs_function_circuit, rhs_function, rhs_sample) +test_matmul_shape_and_bounds = [ + ( + (3, 2), + (2, 3), + (0, 3), + ), + ( + (3, 2), + (2, 3), + (0, 127), + ), + ( + (1, 2), + (2, 1), + (0, 3), + ), + ( + (3, 3), + (3, 3), + (0, 3), + ), + ( + (2, 1), + (1, 2), + (0, 7), + ), + ( + (2,), + (2,), + (0, 7), + ), + ( + (5, 5), + (5,), + (0, 3), + ), + ( + (5,), + (5, 5), + (0, 3), + ), + ( + (5,), + (5, 5), + (-63, 63), + ), + ( + (2,), + (2, 7), + (-63, 0), + ), + ( + (5,), + (5, 3), + (0, 3), + ), + ( + (5, 3), + (3,), + (0, 3), + ), + ( + (5,), + (4, 5, 3), + (-5, 5), + ), + ( + (4, 5, 3), + (3,), + (0, 5), + ), + ( + (5,), + (2, 4, 5, 3), + (0, 5), + ), + ( + (2, 4, 5, 3), + (3,), + (-1, 5), + ), + ( + (5, 4, 3), + (3, 2), + (0, 5), + ), + ( + (4, 3), + (5, 3, 2), + (0, 5), + ), + ( + (2, 5, 4, 3), + (3, 2), + (0, 5), + ), + ( + (5, 4, 3), + (1, 3, 2), + (0, 5), + ), + ( + (1, 4, 3), + (5, 3, 2), + (0, 5), + ), + ( + (5, 4, 3), + (2, 1, 3, 2), + (0, 5), + ), + ( + (2, 1, 4, 3), + (5, 3, 2), + (0, 5), + ), +] + @pytest.mark.parametrize( - "lhs_shape,rhs_shape,bounds", + "lhs_shape,rhs_shape,bounds,clear_rhs", [ - pytest.param( - (3, 2), - (2, 3), - (0, 3), - ), - pytest.param( - (3, 2), - (2, 3), - (0, 127), - ), - pytest.param( - (1, 2), - (2, 1), - (0, 3), - ), - pytest.param( - (3, 3), - (3, 3), - (0, 3), - ), - pytest.param( - (2, 1), - (1, 2), - (0, 7), - ), - pytest.param( - (2,), - (2,), - (0, 7), - ), - pytest.param( - (5, 5), - (5,), - (0, 3), - ), - pytest.param( - (5,), - (5, 5), - (0, 3), - ), - pytest.param( - (5,), - (5, 5), - (-63, 63), - ), - pytest.param( - (2,), - (2, 7), - (-63, 0), - ), - pytest.param( - (5,), - (5, 3), - (0, 3), - ), - pytest.param( - (5, 3), - (3,), - (0, 3), - ), - pytest.param( - (5,), - (4, 5, 3), - (-5, 5), - ), - pytest.param( - (4, 5, 3), - (3,), - (0, 5), - ), - pytest.param( - (5,), - (2, 4, 5, 3), - (0, 5), - ), - pytest.param( - (2, 4, 5, 3), - (3,), - (-1, 5), - ), - pytest.param( - (5, 4, 3), - (3, 2), - (0, 5), - ), - pytest.param( - (4, 3), - (5, 3, 2), - (0, 5), - ), - pytest.param( - (2, 5, 4, 3), - (3, 2), - (0, 5), - ), - pytest.param( - (5, 4, 3), - (1, 3, 2), - (0, 5), - ), - pytest.param( - (1, 4, 3), - (5, 3, 2), - (0, 5), - ), - pytest.param( - (5, 4, 3), - (2, 1, 3, 2), - (0, 5), - ), - pytest.param( - (2, 1, 4, 3), - (5, 3, 2), - (0, 5), - ), - ], + ( + lhs_shape, rhs_shape, bounds, clear + ) + for lhs_shape,rhs_shape,bounds in test_matmul_shape_and_bounds + for clear in [False, True] + ] ) -def test_matmul(lhs_shape, rhs_shape, bounds, helpers): +def test_matmul(lhs_shape, rhs_shape, bounds, clear_rhs, helpers): """ Test matmul. """ @@ -305,22 +313,20 @@ def clear(x, y): def encrypted(x, y): return np.matmul(x, y) - for implementation in [clear, encrypted]: - inputset = [ - ( - np.random.randint(minimum, maximum, size=lhs_shape), - np.random.randint(minimum, maximum, size=rhs_shape), - ) - for i in range(100) - ] - circuit = implementation.compile(inputset, configuration) - - sample = [ + implementation = clear if clear_rhs else encrypted + + inputset = [ + ( np.random.randint(minimum, maximum, size=lhs_shape), np.random.randint(minimum, maximum, size=rhs_shape), - ] + ) + for _ in range(100) + ] + circuit = implementation.compile(inputset, configuration) + + sample = list(inputset[-1]) - helpers.check_execution(circuit, implementation, sample, retries=3) + helpers.check_execution(circuit, implementation, sample, retries=3) @pytest.mark.parametrize("bit_width", [4, 10])