Skip to content

Commit

Permalink
Enable Wasm Tests (openmls#1483)
Browse files Browse the repository at this point in the history
Co-authored-by: Jan Winkelmann (keks) <[email protected]>
  • Loading branch information
keks and keks authored Jan 25, 2024
1 parent 303967c commit 5052cbe
Show file tree
Hide file tree
Showing 24 changed files with 102 additions and 30 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,9 @@ jobs:
else
echo "TEST_MODE=--release" >> $GITHUB_ENV
fi
- name: Build (wasm)
if: matrix.arch == 'wasm32-unknown-unknown'
run: cargo build $TEST_MODE --verbose --target ${{ matrix.arch }} -p openmls -F js
- name: Build
if: ${{ matrix.arch != 'wasm32-unknown-unknown' }}
run: cargo build $TEST_MODE --verbose --target ${{ matrix.arch }} -p openmls
10 changes: 9 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
ref: ${{ github.event.pull_request.head.sha }}
- uses: dtolnay/rust-toolchain@stable
with:
targets: i686-pc-windows-msvc, i686-unknown-linux-gnu
targets: i686-pc-windows-msvc, i686-unknown-linux-gnu, wasm32-unknown-unknown
- uses: Swatinem/rust-cache@v2

- name: Toggle rustc mode
Expand All @@ -44,6 +44,14 @@ jobs:
else
echo "TEST_MODE=--release" >> $GITHUB_ENV
fi
- name: Tests Wasm32 on linux
if: matrix.os == 'ubuntu-latest'
run: |
sudo apt update && sudo apt install nodejs
cargo install wasm-bindgen-cli
export CARGO_TARGET_WASM32_UNKNOWN_UNKNOWN_RUNNER=$HOME/.cargo/bin/wasm-bindgen-test-runner
cargo test $TEST_MODE -p openmls -vv --target wasm32-unknown-unknown -F js
- name: Tests
if: matrix.os != 'windows-latest'
run: cargo test $TEST_MODE -p openmls --verbose
Expand Down
1 change: 1 addition & 0 deletions book/src/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- [Processing incoming messages](user_manual/processing.md)
- [Persistence of group state](user_manual/persistence.md)
- [Credential validation](user_manual/credential_validation.md)
- [WebAssembly](user_manual/wasm.md)
- [Traits & External Types](./traits/README.md)
- [Traits](./traits/traits.md)
- [Types](./traits/types.md)
Expand Down
4 changes: 4 additions & 0 deletions book/src/user_manual/wasm.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# WebAssembly

OpenMLS can be built for WebAssembly. However, it does require two features that WebAssembly itself does not provide: access to secure randomness and the current time. Currently, this means that it can only run in a runtime that provides common JavaScript APIs (e.g. in the browser or node.js), accessed through the `web_sys` crate.
You can enable the `js` feature on the `openmls` crate to signal that the APIs are available.
9 changes: 8 additions & 1 deletion openmls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ openmls_basic_credential = { version = "0.2.0", path = "../basic_credential", op
] }
rstest = { version = "^0.16", optional = true }
rstest_reuse = { version = "0.4", optional = true }
wasm-bindgen-test = {version = "0.3.40", optional = true}
getrandom = {version = "0.2.12", optional = true, features = [ "js" ]}
fluvio-wasm-timer = {version = "0.2.5", optional = true}

[features]
default = ["backtrace"]
Expand All @@ -41,14 +44,16 @@ test-utils = [
"dep:rand",
"dep:rstest",
"dep:rstest_reuse",
"dep:wasm-bindgen-test",
"dep:openmls_basic_credential",
]
crypto-debug = [] # ☣️ Enable logging of sensitive cryptographic information
content-debug = [] # ☣️ Enable logging of sensitive message content
js = ["dep:getrandom", "dep:fluvio-wasm-timer"] # enable js randomness source for provider

[dev-dependencies]
backtrace = "0.3"
criterion = "^0.5"
criterion = {version = "^0.5", default-features = false} # need to disable default features for wasm
hex = { version = "0.4", features = ["serde"] }
itertools = "0.10"
lazy_static = "1.4"
Expand All @@ -60,6 +65,8 @@ pretty_env_logger = "0.5"
rstest = "^0.16"
rstest_reuse = "0.4"
tempfile = "3"
wasm-bindgen = "0.2.90"
wasm-bindgen-test = "0.3.40"

[[bench]]
name = "benchmark"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ pub fn run_test_vector(test_vector: TreeMathTestVector) -> Result<(), TmTestVect

#[test]
fn read_test_vectors_tm() {
let tests: Vec<TreeMathTestVector> = read("test_vectors/tree-math.json");
let tests: Vec<TreeMathTestVector> = read_json!("../../../test_vectors/tree-math.json");
for test_vector in tests {
match run_test_vector(test_vector) {
Ok(_) => {}
Expand Down
2 changes: 1 addition & 1 deletion openmls/src/ciphersuite/tests/kat_crypto_basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ fn read_test_vectors() {

let provider = OpenMlsRustCrypto::default();

let tests: Vec<CryptoBasicsTestCase> = read("test_vectors/crypto-basics.json");
let tests: Vec<CryptoBasicsTestCase> = read_json!("../../../test_vectors/crypto-basics.json");
for test in tests {
match run_test_vector(test, &provider) {
Ok(_) => {}
Expand Down
9 changes: 5 additions & 4 deletions openmls/src/group/core_group/test_core_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,15 @@ pub(crate) fn setup_alice_group(
fn test_core_group_persistence(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) {
let (alice_group, _, _, _) = setup_alice_group(ciphersuite, provider);

let mut file_out = tempfile::NamedTempFile::new().expect("Could not create file");
// we need something that implements io::Write
let mut file_out: Vec<u8> = vec![];
alice_group
.save(&mut file_out)
.expect("Could not write group state to file");

let file_in = file_out
.reopen()
.expect("Error re-opening serialized group state file");
// make it into a type that implements io::Read
let file_in: &[u8] = &file_out;

let alice_group_deserialized =
CoreGroup::load(file_in).expect("Could not deserialize mls group");

Expand Down
2 changes: 1 addition & 1 deletion openmls/src/group/tests/kat_messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ pub fn run_test_vector(tv: MessagesTestVector) -> Result<(), EncodingMismatch> {

#[test]
fn read_test_vectors_messages() {
let tests: Vec<MessagesTestVector> = read("test_vectors/messages.json");
let tests: Vec<MessagesTestVector> = read_json!("../../../test_vectors/messages.json");

for test_vector in tests {
match run_test_vector(test_vector) {
Expand Down
4 changes: 1 addition & 3 deletions openmls/src/kat_vl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
use serde::Deserialize;
use tls_codec::{Deserialize as TlsDeserialize, VLBytes};

use crate::test_utils::read;

#[derive(Deserialize)]
struct TestElement {
#[serde(with = "hex")]
Expand Down Expand Up @@ -57,7 +55,7 @@ fn read_test_vectors_deserialize() {
let _ = pretty_env_logger::try_init();
log::debug!("Reading test vectors ...");

let tests: Vec<TestElement> = read("test_vectors/deserialization.json");
let tests: Vec<TestElement> = read_json!("../test_vectors/deserialization.json");

for test_vector in tests {
match run_test_vector(test_vector) {
Expand Down
3 changes: 3 additions & 0 deletions openmls/src/key_packages/lifetime.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#[cfg(target_arch = "wasm32")]
use fluvio_wasm_timer::{SystemTime, UNIX_EPOCH};
#[cfg(not(target_arch = "wasm32"))]
use std::time::{SystemTime, UNIX_EPOCH};

use serde::{Deserialize, Serialize};
Expand Down
9 changes: 9 additions & 0 deletions openmls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@
target_pointer_width = "128"
))]

#[cfg(all(target_arch = "wasm32", not(feature = "js")))]
compile_error!("In order for OpenMLS to build for WebAssembly, JavaScript APIs must be available (for access to secure randomness and the current time). This can be signalled by setting the `js` feature on OpenMLS.");

// === Testing ===

/// Single place, re-exporting all structs and functions needed for integration tests
Expand Down Expand Up @@ -187,3 +190,9 @@ mod tree;

/// Single place, re-exporting the most used public functions.
pub mod prelude;

// this is a workaround, see https://github.com/la10736/rstest/issues/211#issuecomment-1701238125
#[cfg(any(test, feature = "test-utils"))]
pub mod wasm {
pub use wasm_bindgen_test::wasm_bindgen_test as test;
}
4 changes: 2 additions & 2 deletions openmls/src/schedule/kat_key_schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use tls_codec::Serialize as TlsSerializeTrait;

use super::{errors::KsTestVectorError, CommitSecret};
#[cfg(test)]
use crate::test_utils::{read, write};
use crate::test_utils::write;
use crate::{ciphersuite::*, extensions::Extensions, group::*, schedule::*, test_utils::*};

#[derive(Serialize, Deserialize, Debug, Clone, Default)]
Expand Down Expand Up @@ -258,7 +258,7 @@ fn write_test_vectors() {
fn read_test_vectors_key_schedule(provider: &impl OpenMlsProvider) {
let _ = pretty_env_logger::try_init();

let tests: Vec<KeyScheduleTestVector> = read("test_vectors/key-schedule.json");
let tests: Vec<KeyScheduleTestVector> = read_json!("../../test_vectors/key-schedule.json");

for test_vector in tests {
match run_test_vector(test_vector, provider) {
Expand Down
2 changes: 1 addition & 1 deletion openmls/src/schedule/kat_psk_secret.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ fn read_test_vectors_ps(provider: &impl OpenMlsProvider) {
let _ = pretty_env_logger::try_init();
log::debug!("Reading test vectors ...");

let tests: Vec<TestElement> = read("test_vectors/psk_secret.json");
let tests: Vec<TestElement> = read_json!("../../test_vectors/psk_secret.json");

for test_vector in tests {
match run_test_vector(test_vector, provider) {
Expand Down
12 changes: 12 additions & 0 deletions openmls/src/test_utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ pub(crate) fn write(file_name: &str, obj: impl Serialize) {
.expect("Error writing test vector file");
}

// the macro is used in other files, suppress false positive
#[allow(unused_macros)]
macro_rules! read_json {
($file_name:expr) => {{
let data = include_str!($file_name);
serde_json::from_str(data).expect(&format!("Error reading file {}", $file_name))
}};
}

pub(crate) fn read<T: DeserializeOwned>(file_name: &str) -> T {
let file = match File::open(file_name) {
Ok(f) => f,
Expand Down Expand Up @@ -212,6 +221,7 @@ pub use openmls_rust_crypto::OpenMlsRustCrypto;
)
]
#[allow(non_snake_case)]
#[cfg_attr(target_arch = "wasm32", openmls::wasm::test)]
pub fn providers(provider: &impl OpenMlsProvider) {}

// === Ciphersuites ===
Expand All @@ -233,6 +243,7 @@ pub fn providers(provider: &impl OpenMlsProvider) {}
)
)]
#[allow(non_snake_case)]
#[cfg_attr(target_arch = "wasm32", openmls::wasm::test)]
pub fn ciphersuites(ciphersuite: Ciphersuite) {}

// === Ciphersuites & providers ===
Expand All @@ -246,4 +257,5 @@ pub fn ciphersuites(ciphersuite: Ciphersuite) {}
)
]
#[allow(non_snake_case)]
#[cfg_attr(target_arch = "wasm32", openmls::wasm::test)]
pub fn ciphersuites_and_providers(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) {}
15 changes: 11 additions & 4 deletions openmls/src/test_utils/test_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,13 @@ use openmls_traits::{
types::{Ciphersuite, HpkeKeyPair, SignatureScheme},
OpenMlsProvider,
};
use rayon::prelude::*;

use std::{collections::HashMap, sync::RwLock};
use tls_codec::*;

#[cfg(not(target_arch = "wasm32"))]
use rayon::prelude::*;

pub mod client;
pub mod errors;

Expand Down Expand Up @@ -363,9 +366,13 @@ impl MlsGroupTestSetup {
authentication_service: AS,
) {
let clients = self.clients.read().expect("An unexpected error occurred.");
let messages = group
.members
.par_iter()

#[cfg(not(target_arch = "wasm32"))]
let group_members = group.members.par_iter();
#[cfg(target_arch = "wasm32")]
let group_members = group.members.iter();

let messages = group_members
.filter_map(|(_, m_id)| {
let m = clients
.get(m_id)
Expand Down
3 changes: 2 additions & 1 deletion openmls/src/tree/tests_and_kats/kats/kat_encryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,8 @@ fn read_test_vectors_encryption(provider: &impl OpenMlsProvider) {
let _ = pretty_env_logger::try_init();
log::debug!("Reading test vectors ...");

let tests: Vec<EncryptionTestVector> = read("test_vectors/kat_encryption_openmls.json");
let tests: Vec<EncryptionTestVector> =
read_json!("../../../../test_vectors/kat_encryption_openmls.json");

for test_vector in tests {
match run_test_vector(test_vector, provider) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,8 @@ fn read_test_vectors_mp(provider: &impl OpenMlsProvider) {
let _ = pretty_env_logger::try_init();
log::debug!("Reading test vectors ...");

let tests: Vec<MessageProtectionTest> = read("test_vectors/message-protection.json");
let tests: Vec<MessageProtectionTest> =
read_json!("../../../../test_vectors/message-protection.json");

for test_vector in tests {
match run_test_vector(test_vector, provider) {
Expand Down
2 changes: 1 addition & 1 deletion openmls/src/tree/tests_and_kats/kats/secret_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ fn read_test_vectors_st(provider: &impl OpenMlsProvider) {
let _ = pretty_env_logger::try_init();
log::debug!("Reading test vectors ...");

let tests: Vec<SecretTree> = read("test_vectors/secret-tree.json");
let tests: Vec<SecretTree> = read_json!("../../../../test_vectors/secret-tree.json");

for test_vector in tests {
match run_test_vector(test_vector, provider) {
Expand Down
14 changes: 12 additions & 2 deletions openmls/src/treesync/node/parent_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//! [`UpdatePathNode`] instances.
use openmls_traits::crypto::OpenMlsCrypto;
use openmls_traits::types::{Ciphersuite, HpkeCiphertext};
#[cfg(not(target_arch = "wasm32"))]
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use thiserror::*;
Expand Down Expand Up @@ -66,8 +67,12 @@ impl PlainUpdatePathNode {
public_keys: &[EncryptionKey],
group_context: &[u8],
) -> Result<UpdatePathNode, LibraryError> {
#[cfg(target_arch = "wasm32")]
let public_keys = public_keys.iter();
#[cfg(not(target_arch = "wasm32"))]
let public_keys = public_keys.par_iter();

public_keys
.par_iter()
.map(|pk| {
self.path_secret
.encrypt(crypto, ciphersuite, pk, group_context)
Expand Down Expand Up @@ -131,8 +136,13 @@ impl ParentNode {
);

// Iterate over the path secrets and derive a key pair

#[cfg(not(target_arch = "wasm32"))]
let path_secrets = path_secrets.into_par_iter();
#[cfg(target_arch = "wasm32")]
let path_secrets = path_secrets.into_iter();

let (path_with_keypairs, update_path_nodes): PathDerivationResults = path_secrets
.into_par_iter()
.zip(path_indices)
.map(|(path_secret, index)| {
// Derive a key pair from the path secret. This includes the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ fn read_test_vectors_tree_operations(provider: &impl OpenMlsProvider) {
let _ = pretty_env_logger::try_init();
log::debug!("Reading test vectors ...");

let tests: Vec<TestElement> = read("test_vectors/tree-operations.json");
let tests: Vec<TestElement> = read_json!("../../../../test_vectors/tree-operations.json");

for test_vector in tests {
match run_test_vector(test_vector, provider) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ fn read_test_vectors_tree_validation(provider: &impl OpenMlsProvider) {
let _ = pretty_env_logger::try_init();
log::debug!("Reading test vectors ...");

let tests: Vec<TestElement> = read("test_vectors/tree-validation.json");
let tests: Vec<TestElement> = read_json!("../../../../test_vectors/tree-validation.json");

for test_vector in tests {
match run_test_vector(test_vector, provider) {
Expand Down
4 changes: 2 additions & 2 deletions openmls/src/treesync/tests_and_kats/kats/kat_treekem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
messages::PathSecret,
prelude_test::Secret,
schedule::CommitSecret,
test_utils::{hex_to_bytes, read},
test_utils::hex_to_bytes,
treesync::{
node::encryption_keys::EncryptionKeyPair,
treekem::{DecryptPathParams, UpdatePath, UpdatePathIn},
Expand Down Expand Up @@ -385,7 +385,7 @@ fn apply_update_path(
#[test]
fn read_test_vectors_treekem() {
let _ = pretty_env_logger::try_init();
let tests: Vec<TreeKemTest> = read("test_vectors/treekem.json");
let tests: Vec<TreeKemTest> = read_json!("../../../../test_vectors/treekem.json");

let provider = OpenMlsRustCrypto::default();

Expand Down
10 changes: 8 additions & 2 deletions openmls/src/treesync/treekem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use openmls_traits::{
crypto::OpenMlsCrypto,
types::{Ciphersuite, HpkeCiphertext},
};
#[cfg(not(target_arch = "wasm32"))]
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
Expand Down Expand Up @@ -70,8 +71,13 @@ impl<'a> TreeSyncDiff<'a> {
debug_assert_eq!(copath_resolutions.len(), path.len());

// Encrypt the secrets
path.par_iter()
.zip(copath_resolutions.par_iter())

#[cfg(not(target_arch = "wasm32"))]
let resolved_path = path.par_iter().zip(copath_resolutions.par_iter());
#[cfg(target_arch = "wasm32")]
let resolved_path = path.iter().zip(copath_resolutions.iter());

resolved_path
.map(|(node, resolution)| node.encrypt(crypto, ciphersuite, resolution, group_context))
.collect::<Result<Vec<UpdatePathNode>, LibraryError>>()
}
Expand Down

0 comments on commit 5052cbe

Please sign in to comment.