Skip to content

Commit

Permalink
Merge pull request #424 from no-longer-human/par-14407-sum-and-produc…
Browse files Browse the repository at this point in the history
…t-identity-element

 Product and Sum applications should return identity when given empty list as input
  • Loading branch information
quackzar authored Jun 21, 2024
2 parents 3119c1f + daaa6f4 commit e801cf7
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,46 +7,58 @@
import dk.alexandra.fresco.framework.value.SInt;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
* ComputationBuilder for multiplying a list of SInts.
* Builder for a {@link Computation}s that computes the product of a list of {@link SInts}
* secret-shared values.
*
* <p>Empty lists are allowed, and will always produce a value of {@code 1}.
*/
public class ProductSIntList implements Computation<SInt, ProtocolBuilderNumeric> {

private final List<DRes<SInt>> input;

/**
* Creates a new ProductSIntList.
* Creates a new {@link ProductSIntList}.
*
* @param list the list to sum
* @param list the list to sum. Not nullable.
*/
public ProductSIntList(List<DRes<SInt>> list) {
input = list;
input = Objects.requireNonNull(list);
}

@Override
public DRes<SInt> buildComputation(ProtocolBuilderNumeric iterationBuilder) {
return iterationBuilder.seq(seq ->
() -> input
).whileLoop(
(inputs) -> inputs.size() > 1,
(seq, inputs) -> seq.par(parallel -> {
List<DRes<SInt>> out = new ArrayList<>();
Numeric numericBuilder = parallel.numeric();
DRes<SInt> left = null;
for (DRes<SInt> input1 : inputs) {
if (left == null) {
left = input1;
} else {
out.add(numericBuilder.mult(left, input1));
left = null;
}
}
if (left != null) {
out.add(left);
}
return () -> out;
})
).seq((builder, currentInput) -> currentInput.get(0));
// Fast case if there is nothing to compute.
if (input.isEmpty()) {
return iterationBuilder.seq(seq -> seq.numeric().known(1));
}

// Slow case when there are elements to compute on.
return iterationBuilder
.seq(seq -> () -> input)
.whileLoop(
(inputs) -> inputs.size() > 1,
(seq, inputs) ->
seq.par(
parallel -> {
List<DRes<SInt>> out = new ArrayList<>();
Numeric numericBuilder = parallel.numeric();
DRes<SInt> left = null;
for (DRes<SInt> input1 : inputs) {
if (left == null) {
left = input1;
} else {
out.add(numericBuilder.mult(left, input1));
left = null;
}
}
if (left != null) {
out.add(left);
}
return () -> out;
}))
.seq((builder, currentInput) -> currentInput.get(0));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,46 +7,57 @@
import dk.alexandra.fresco.framework.value.SInt;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
* ComputationBuilder for summing a list of SInts.
* Builder for a {@link Computation}s that sum lists of {@link SInts} secret-shared values.
*
* <p>Empty lists are allowed, and will always produce a value of {@code 0}.
*/
public class SumSIntList implements Computation<SInt, ProtocolBuilderNumeric> {

private final List<DRes<SInt>> input;

/**
* Creates a new SumSIntList.
* Creates a new {@link SumSIntList}.
*
* @param list the list to sum
* @param list the list to sum. Not nullable.
*/
public SumSIntList(List<DRes<SInt>> list) {
input = list;
input = Objects.requireNonNull(list);
}

@Override
public DRes<SInt> buildComputation(ProtocolBuilderNumeric iterationBuilder) {
return iterationBuilder.seq(seq ->
() -> input
).whileLoop(
(inputs) -> inputs.size() > 1,
(seq, inputs) -> seq.par(parallel -> {
List<DRes<SInt>> out = new ArrayList<>();
Numeric numericBuilder = parallel.numeric();
DRes<SInt> left = null;
for (DRes<SInt> input1 : inputs) {
if (left == null) {
left = input1;
} else {
out.add(numericBuilder.add(left, input1));
left = null;
}
}
if (left != null) {
out.add(left);
}
return () -> out;
})
).seq((builder, currentInput) -> currentInput.get(0));
// Fast case if there is nothing to sum.
if (input.isEmpty()) {
return iterationBuilder.seq(seq -> seq.numeric().known(0));
}

// Slow case when there are elements to sum.
return iterationBuilder
.seq(seq -> () -> input)
.whileLoop(
(inputs) -> inputs.size() > 1,
(seq, inputs) ->
seq.par(
parallel -> {
List<DRes<SInt>> out = new ArrayList<>();
Numeric numericBuilder = parallel.numeric();
DRes<SInt> left = null;
for (DRes<SInt> input1 : inputs) {
if (left == null) {
left = input1;
} else {
out.add(numericBuilder.add(left, input1));
left = null;
}
}
if (left != null) {
out.add(left);
}
return () -> out;
}))
.seq((builder, currentInput) -> currentInput.get(0));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,59 +15,97 @@
import java.util.List;
import java.util.stream.Collectors;

/** Test of {@link ProductSIntList} and {@link SumSIntList}. */
public class TestProductAndSum {

public static class TestProduct<ResourcePoolT extends ResourcePool>
extends TestThreadFactory<ResourcePoolT, ProtocolBuilderNumeric> {
private static final class TestCase {
public final BigInteger expectedOutput;
public final List<BigInteger> inputs;

public TestCase(long expectedOutput, long... inputs) {
this.expectedOutput = BigInteger.valueOf(expectedOutput);
this.inputs =
Arrays.stream(inputs)
.mapToObj(BigInteger::valueOf)
.collect(Collectors.toUnmodifiableList());
}
}

List<BigInteger> inputs = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L).stream().map(
BigInteger::valueOf).collect(Collectors.toList());
private static final List<TestCase> TEST_CASES_SUM =
List.of(
new TestCase(0),
new TestCase(123, 123),
new TestCase(2, 1, 1),
new TestCase(4, 2, 2),
new TestCase(6, 3, 3),
new TestCase(15, 1, 2, 4, 8),
new TestCase(55, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10));

private static final List<TestCase> TEST_CASES_PRODUCT =
List.of(
new TestCase(1),
new TestCase(123, 123),
new TestCase(1, 1, 1),
new TestCase(4, 2, 2),
new TestCase(9, 3, 3),
new TestCase(64, 1, 2, 4, 8),
new TestCase(3628800, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10));

/** Test of {@link ProductSIntList}. */
public static final class TestProduct<ResourcePoolT extends ResourcePool>
extends TestThreadFactory<ResourcePoolT, ProtocolBuilderNumeric> {

@Override
public TestThread<ResourcePoolT, ProtocolBuilderNumeric> next() {
return new TestThread<ResourcePoolT, ProtocolBuilderNumeric>() {

@Override
public void test() throws Exception {
// define functionality to be tested
Application<BigInteger, ProtocolBuilderNumeric> testApplication =
root -> {
List<DRes<SInt>> closed = inputs.stream().map(root.numeric()::known)
.collect(Collectors.toList());
DRes<SInt> result = AdvancedNumeric.using(root).product(closed);
DRes<BigInteger> open = root.numeric().open(result);
return () -> open.out();
};
BigInteger output = runApplication(testApplication);
assertEquals(output, inputs.stream().reduce(BigInteger.ONE, (a, b) -> a.multiply(b)));
for (final TestCase testCase : TEST_CASES_PRODUCT) {
// define functionality to be tested
Application<BigInteger, ProtocolBuilderNumeric> testApplication =
root -> {
List<DRes<SInt>> closed =
testCase.inputs.stream()
.map(root.numeric()::known)
.collect(Collectors.toUnmodifiableList());
DRes<SInt> result = AdvancedNumeric.using(root).product(closed);
DRes<BigInteger> open = root.numeric().open(result);
return () -> open.out();
};
BigInteger output = runApplication(testApplication);
assertEquals(testCase.expectedOutput, output);
}
}
};
}
}

public static class TestSum<ResourcePoolT extends ResourcePool>
/** Test of {@link SumSIntList}. */
public static final class TestSum<ResourcePoolT extends ResourcePool>
extends TestThreadFactory<ResourcePoolT, ProtocolBuilderNumeric> {

List<BigInteger> inputs = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L).stream().map(
BigInteger::valueOf).collect(Collectors.toList());

@Override
public TestThread<ResourcePoolT, ProtocolBuilderNumeric> next() {
return new TestThread<ResourcePoolT, ProtocolBuilderNumeric>() {

@Override
public void test() throws Exception {
// define functionality to be tested
Application<BigInteger, ProtocolBuilderNumeric> testApplication =
root -> {
List<DRes<SInt>> closed = inputs.stream().map(root.numeric()::known)
.collect(Collectors.toList());
DRes<SInt> result = AdvancedNumeric.using(root).sum(closed);
DRes<BigInteger> open = root.numeric().open(result);
return () -> open.out();
};
BigInteger output = runApplication(testApplication);
assertEquals(output, inputs.stream().reduce(BigInteger.ZERO, (a, b) -> a.add(b)));
for (final TestCase testCase : TEST_CASES_SUM) {
// define functionality to be tested
Application<BigInteger, ProtocolBuilderNumeric> testApplication =
root -> {
List<DRes<SInt>> closed =
testCase.inputs.stream()
.map(root.numeric()::known)
.collect(Collectors.toUnmodifiableList());
DRes<SInt> result = AdvancedNumeric.using(root).sum(closed);
DRes<BigInteger> open = root.numeric().open(result);
return () -> open.out();
};
BigInteger output = runApplication(testApplication);
assertEquals(testCase.expectedOutput, output);
}
}
};
}
Expand Down

0 comments on commit e801cf7

Please sign in to comment.