diff --git a/lib/common/src/main/java/dk/alexandra/fresco/lib/common/math/integer/ProductSIntList.java b/lib/common/src/main/java/dk/alexandra/fresco/lib/common/math/integer/ProductSIntList.java index 6860b90e5..f88be1fc0 100644 --- a/lib/common/src/main/java/dk/alexandra/fresco/lib/common/math/integer/ProductSIntList.java +++ b/lib/common/src/main/java/dk/alexandra/fresco/lib/common/math/integer/ProductSIntList.java @@ -7,46 +7,58 @@ import dk.alexandra.fresco.framework.value.SInt; import java.util.ArrayList; import java.util.List; +import java.util.Objects; /** - * ComputationBuilder for multiplying a list of SInts. + * Builder for a {@link Computation}s that computes the product of a list of {@link SInts} + * secret-shared values. + * + *

Empty lists are allowed, and will always produce a value of {@code 1}. */ public class ProductSIntList implements Computation { private final List> input; /** - * Creates a new ProductSIntList. + * Creates a new {@link ProductSIntList}. * - * @param list the list to sum + * @param list the list to sum. Not nullable. */ public ProductSIntList(List> list) { - input = list; + input = Objects.requireNonNull(list); } @Override public DRes buildComputation(ProtocolBuilderNumeric iterationBuilder) { - return iterationBuilder.seq(seq -> - () -> input - ).whileLoop( - (inputs) -> inputs.size() > 1, - (seq, inputs) -> seq.par(parallel -> { - List> out = new ArrayList<>(); - Numeric numericBuilder = parallel.numeric(); - DRes left = null; - for (DRes input1 : inputs) { - if (left == null) { - left = input1; - } else { - out.add(numericBuilder.mult(left, input1)); - left = null; - } - } - if (left != null) { - out.add(left); - } - return () -> out; - }) - ).seq((builder, currentInput) -> currentInput.get(0)); + // Fast case if there is nothing to compute. + if (input.isEmpty()) { + return iterationBuilder.seq(seq -> seq.numeric().known(1)); + } + + // Slow case when there are elements to compute on. + return iterationBuilder + .seq(seq -> () -> input) + .whileLoop( + (inputs) -> inputs.size() > 1, + (seq, inputs) -> + seq.par( + parallel -> { + List> out = new ArrayList<>(); + Numeric numericBuilder = parallel.numeric(); + DRes left = null; + for (DRes input1 : inputs) { + if (left == null) { + left = input1; + } else { + out.add(numericBuilder.mult(left, input1)); + left = null; + } + } + if (left != null) { + out.add(left); + } + return () -> out; + })) + .seq((builder, currentInput) -> currentInput.get(0)); } } diff --git a/lib/common/src/main/java/dk/alexandra/fresco/lib/common/math/integer/SumSIntList.java b/lib/common/src/main/java/dk/alexandra/fresco/lib/common/math/integer/SumSIntList.java index 9092d413d..b03b8f826 100644 --- a/lib/common/src/main/java/dk/alexandra/fresco/lib/common/math/integer/SumSIntList.java +++ b/lib/common/src/main/java/dk/alexandra/fresco/lib/common/math/integer/SumSIntList.java @@ -7,46 +7,57 @@ import dk.alexandra.fresco.framework.value.SInt; import java.util.ArrayList; import java.util.List; +import java.util.Objects; /** - * ComputationBuilder for summing a list of SInts. + * Builder for a {@link Computation}s that sum lists of {@link SInts} secret-shared values. + * + *

Empty lists are allowed, and will always produce a value of {@code 0}. */ public class SumSIntList implements Computation { private final List> input; /** - * Creates a new SumSIntList. + * Creates a new {@link SumSIntList}. * - * @param list the list to sum + * @param list the list to sum. Not nullable. */ public SumSIntList(List> list) { - input = list; + input = Objects.requireNonNull(list); } @Override public DRes buildComputation(ProtocolBuilderNumeric iterationBuilder) { - return iterationBuilder.seq(seq -> - () -> input - ).whileLoop( - (inputs) -> inputs.size() > 1, - (seq, inputs) -> seq.par(parallel -> { - List> out = new ArrayList<>(); - Numeric numericBuilder = parallel.numeric(); - DRes left = null; - for (DRes input1 : inputs) { - if (left == null) { - left = input1; - } else { - out.add(numericBuilder.add(left, input1)); - left = null; - } - } - if (left != null) { - out.add(left); - } - return () -> out; - }) - ).seq((builder, currentInput) -> currentInput.get(0)); + // Fast case if there is nothing to sum. + if (input.isEmpty()) { + return iterationBuilder.seq(seq -> seq.numeric().known(0)); + } + + // Slow case when there are elements to sum. + return iterationBuilder + .seq(seq -> () -> input) + .whileLoop( + (inputs) -> inputs.size() > 1, + (seq, inputs) -> + seq.par( + parallel -> { + List> out = new ArrayList<>(); + Numeric numericBuilder = parallel.numeric(); + DRes left = null; + for (DRes input1 : inputs) { + if (left == null) { + left = input1; + } else { + out.add(numericBuilder.add(left, input1)); + left = null; + } + } + if (left != null) { + out.add(left); + } + return () -> out; + })) + .seq((builder, currentInput) -> currentInput.get(0)); } } diff --git a/lib/common/src/test/java/dk/alexandra/fresco/lib/common/math/integer/TestProductAndSum.java b/lib/common/src/test/java/dk/alexandra/fresco/lib/common/math/integer/TestProductAndSum.java index 65bb34a7a..9b541e62a 100644 --- a/lib/common/src/test/java/dk/alexandra/fresco/lib/common/math/integer/TestProductAndSum.java +++ b/lib/common/src/test/java/dk/alexandra/fresco/lib/common/math/integer/TestProductAndSum.java @@ -15,13 +15,45 @@ import java.util.List; import java.util.stream.Collectors; +/** Test of {@link ProductSIntList} and {@link SumSIntList}. */ public class TestProductAndSum { - public static class TestProduct - extends TestThreadFactory { + private static final class TestCase { + public final BigInteger expectedOutput; + public final List inputs; + + public TestCase(long expectedOutput, long... inputs) { + this.expectedOutput = BigInteger.valueOf(expectedOutput); + this.inputs = + Arrays.stream(inputs) + .mapToObj(BigInteger::valueOf) + .collect(Collectors.toUnmodifiableList()); + } + } - List inputs = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L).stream().map( - BigInteger::valueOf).collect(Collectors.toList()); + private static final List TEST_CASES_SUM = + List.of( + new TestCase(0), + new TestCase(123, 123), + new TestCase(2, 1, 1), + new TestCase(4, 2, 2), + new TestCase(6, 3, 3), + new TestCase(15, 1, 2, 4, 8), + new TestCase(55, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)); + + private static final List TEST_CASES_PRODUCT = + List.of( + new TestCase(1), + new TestCase(123, 123), + new TestCase(1, 1, 1), + new TestCase(4, 2, 2), + new TestCase(9, 3, 3), + new TestCase(64, 1, 2, 4, 8), + new TestCase(3628800, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)); + + /** Test of {@link ProductSIntList}. */ + public static final class TestProduct + extends TestThreadFactory { @Override public TestThread next() { @@ -29,45 +61,51 @@ public TestThread next() { @Override public void test() throws Exception { - // define functionality to be tested - Application testApplication = - root -> { - List> closed = inputs.stream().map(root.numeric()::known) - .collect(Collectors.toList()); - DRes result = AdvancedNumeric.using(root).product(closed); - DRes open = root.numeric().open(result); - return () -> open.out(); - }; - BigInteger output = runApplication(testApplication); - assertEquals(output, inputs.stream().reduce(BigInteger.ONE, (a, b) -> a.multiply(b))); + for (final TestCase testCase : TEST_CASES_PRODUCT) { + // define functionality to be tested + Application testApplication = + root -> { + List> closed = + testCase.inputs.stream() + .map(root.numeric()::known) + .collect(Collectors.toUnmodifiableList()); + DRes result = AdvancedNumeric.using(root).product(closed); + DRes open = root.numeric().open(result); + return () -> open.out(); + }; + BigInteger output = runApplication(testApplication); + assertEquals(testCase.expectedOutput, output); + } } }; } } - public static class TestSum + /** Test of {@link SumSIntList}. */ + public static final class TestSum extends TestThreadFactory { - List inputs = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L).stream().map( - BigInteger::valueOf).collect(Collectors.toList()); - @Override public TestThread next() { return new TestThread() { @Override public void test() throws Exception { - // define functionality to be tested - Application testApplication = - root -> { - List> closed = inputs.stream().map(root.numeric()::known) - .collect(Collectors.toList()); - DRes result = AdvancedNumeric.using(root).sum(closed); - DRes open = root.numeric().open(result); - return () -> open.out(); - }; - BigInteger output = runApplication(testApplication); - assertEquals(output, inputs.stream().reduce(BigInteger.ZERO, (a, b) -> a.add(b))); + for (final TestCase testCase : TEST_CASES_SUM) { + // define functionality to be tested + Application testApplication = + root -> { + List> closed = + testCase.inputs.stream() + .map(root.numeric()::known) + .collect(Collectors.toUnmodifiableList()); + DRes result = AdvancedNumeric.using(root).sum(closed); + DRes open = root.numeric().open(result); + return () -> open.out(); + }; + BigInteger output = runApplication(testApplication); + assertEquals(testCase.expectedOutput, output); + } } }; }