Skip to content

Commit

Permalink
Improve test clarity and avoid code duplication
Browse files Browse the repository at this point in the history
Co-authored-by: Francisco Ferrari <[email protected]>
Co-authored-by: Martin Balao <[email protected]>
  • Loading branch information
franferrax and martinuy committed Jun 7, 2024
1 parent 2c6a3c0 commit 33a64b2
Showing 1 changed file with 68 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,25 @@ private static byte[] computeExpected(byte[] jointPlaintext)
return ciphertext;
}

private static byte[] join(byte[][] inputChunks, int totalLength) {
ByteBuffer outputBuf = ByteBuffer.allocate(totalLength);
for (byte[] inputChunk : inputChunks) {
outputBuf.put(inputChunk);
}
return outputBuf.array();
}

private static byte[][] split(byte[] input, int[] chunkSizes) {
ByteBuffer inputBuf = ByteBuffer.wrap(input);
byte[][] outputChunks = new byte[chunkSizes.length][];
for (int chunkIdx = 0; chunkIdx < chunkSizes.length; chunkIdx++) {
byte[] chunk = new byte[chunkSizes[chunkIdx]];
inputBuf.get(chunk);
outputChunks[chunkIdx] = chunk;
}
return outputChunks;
}

private enum CheckType {CIPHERTEXT, PLAINTEXT}

private enum OutputType {BYTE_ARRAY, DIRECT_BYTE_BUFFER}
Expand All @@ -110,92 +129,60 @@ private static void check(CheckType checkType, OutputType outputType,
}
}

private static ByteBuffer encryptOrDecryptMultipart(int operation,
OutputType outputType, byte[][] inputChunks, int totalLength)
throws Exception {
Cipher cipher = Cipher.getInstance(ALGORITHM, sunPKCS11);
cipher.init(operation, KEY, IV);
ByteBuffer output = null;
int outOfs = 1;
switch (outputType) {
case BYTE_ARRAY -> {
output = ByteBuffer.allocate(totalLength);
for (byte[] inputChunk : inputChunks) {
output.put(cipher.update(inputChunk));
}
// Check that the output array offset does not affect the
// penultimate block length calculation.
byte[] tmpOut = new byte[cipher.getOutputSize(0) + outOfs];
cipher.doFinal(tmpOut, outOfs);
output.put(tmpOut, outOfs, tmpOut.length - outOfs);
}
case DIRECT_BYTE_BUFFER -> {
output = ByteBuffer.allocateDirect(totalLength);
for (byte[] inputChunk : inputChunks) {
cipher.update(ByteBuffer.wrap(inputChunk), output);
}
// Check that the output array offset does not affect the
// penultimate block length calculation.
ByteBuffer tmpOut = ByteBuffer.allocateDirect(
cipher.getOutputSize(0) + outOfs);
tmpOut.position(outOfs);
cipher.doFinal(ByteBuffer.allocate(0), tmpOut);
tmpOut.position(outOfs);
output.put(tmpOut);
}
}
return output;
}

private static void doMultipart(int... chunkSizes) throws Exception {
int totalLength = IntStream.of(chunkSizes).sum();
byte[][] plaintextChunks = generateChunks(totalLength, chunkSizes);

ByteBuffer jointPlaintextBuf = ByteBuffer.allocate(totalLength);
for (byte[] plaintextChunk : plaintextChunks) {
jointPlaintextBuf.put(plaintextChunk);
}
byte[] jointPlaintext = jointPlaintextBuf.array();
byte[] jointPlaintext = join(plaintextChunks, totalLength);
byte[] expectedCiphertext = computeExpected(jointPlaintext);

// Check that the output array offset does not affect the penultimate
// block length calculation.
int outOfs = 1;

Cipher cipher = Cipher.getInstance(ALGORITHM, sunPKCS11);

// Encryption test, with byte[]
cipher.init(Cipher.ENCRYPT_MODE, KEY, IV);
ByteBuffer actualCiphertextBuf = ByteBuffer.allocate(totalLength);
for (byte[] plaintextChunk : plaintextChunks) {
actualCiphertextBuf.put(cipher.update(plaintextChunk));
byte[][] ciphertextChunks = split(expectedCiphertext, chunkSizes);

for (OutputType outputType : OutputType.values()) {
// Encryption test
check(CheckType.CIPHERTEXT, outputType, expectedCiphertext,
encryptOrDecryptMultipart(Cipher.ENCRYPT_MODE, outputType,
plaintextChunks, totalLength));
// Decryption test
check(CheckType.PLAINTEXT, outputType, jointPlaintext,
encryptOrDecryptMultipart(Cipher.DECRYPT_MODE, outputType,
ciphertextChunks, totalLength));
}

byte [] outArray = new byte[cipher.getOutputSize(0) + outOfs];
cipher.doFinal(outArray, outOfs);
actualCiphertextBuf.put(outArray, outOfs, outArray.length - outOfs);

check(CheckType.CIPHERTEXT, OutputType.BYTE_ARRAY,
expectedCiphertext, actualCiphertextBuf);

// Encryption test, with direct output buffer
cipher.init(Cipher.ENCRYPT_MODE, KEY, IV);
ByteBuffer actualCiphertextDir = ByteBuffer.allocateDirect(totalLength);
for (byte[] plaintextChunk : plaintextChunks) {
cipher.update(ByteBuffer.wrap(plaintextChunk), actualCiphertextDir);
}

ByteBuffer outBuffer = ByteBuffer.allocateDirect(
cipher.getOutputSize(0) + outOfs);
outBuffer.position(outOfs);
cipher.doFinal(ByteBuffer.allocate(0), outBuffer);
outBuffer.position(outOfs);
actualCiphertextDir.put(outBuffer);

check(CheckType.CIPHERTEXT, OutputType.DIRECT_BYTE_BUFFER,
expectedCiphertext, actualCiphertextDir);

// Decryption test, with byte[]
cipher.init(Cipher.DECRYPT_MODE, KEY, IV);
ByteBuffer actualPlaintextBuf = ByteBuffer.allocate(totalLength);
actualCiphertextBuf.position(0);
for (byte[] plaintextChunk : plaintextChunks) {
// Use the same chunk sizes as the plaintext
byte[] ciphertextChunk = new byte[plaintextChunk.length];
actualCiphertextBuf.get(ciphertextChunk);
actualPlaintextBuf.put(cipher.update(ciphertextChunk));
}

outArray = new byte[cipher.getOutputSize(0) + outOfs];
cipher.doFinal(outArray, outOfs);
actualPlaintextBuf.put(outArray, outOfs, outArray.length - outOfs);

check(CheckType.PLAINTEXT, OutputType.BYTE_ARRAY,
jointPlaintext, actualPlaintextBuf);

// Decryption test, with direct output buffer
cipher.init(Cipher.DECRYPT_MODE, KEY, IV);
ByteBuffer actualPlaintextDir = ByteBuffer.allocateDirect(totalLength);
actualCiphertextBuf.position(0);
for (byte[] plaintextChunk : plaintextChunks) {
// Use the same chunk sizes as the plaintext
byte[] ciphertextChunk = new byte[plaintextChunk.length];
actualCiphertextBuf.get(ciphertextChunk);
cipher.update(ByteBuffer.wrap(ciphertextChunk), actualPlaintextDir);
}

outBuffer = ByteBuffer.allocateDirect(
cipher.getOutputSize(0) + outOfs);
outBuffer.position(outOfs);
cipher.doFinal(ByteBuffer.allocate(0), outBuffer);
outBuffer.position(outOfs);
actualPlaintextDir.put(outBuffer);

check(CheckType.PLAINTEXT, OutputType.DIRECT_BYTE_BUFFER,
jointPlaintext, actualPlaintextDir);
}

private static String repr(byte[] data) {
Expand Down

0 comments on commit 33a64b2

Please sign in to comment.