Skip to content

Commit

Permalink
Added test and new update_group_context_extensions function
Browse files Browse the repository at this point in the history
  • Loading branch information
cameronvoell committed Apr 9, 2024
1 parent fa33f5e commit c6bc4cd
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 4 deletions.
2 changes: 2 additions & 0 deletions openmls/src/group/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,8 @@ pub enum CreateGroupContextExtProposalError {
/// See [`LeafNodeValidationError`] for more details.
#[error(transparent)]
LeafNodeValidation(#[from] LeafNodeValidationError),
#[error(transparent)]
GroupStateError(#[from] MlsGroupStateError),
}

/// Error merging a commit.
Expand Down
52 changes: 49 additions & 3 deletions openmls/src/group/mls_group/proposal.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use core_group::create_commit_params::CreateCommitParams;
use openmls_traits::{
key_store::OpenMlsKeyStore, signatures::Signer, types::Ciphersuite, OpenMlsProvider,
};

use super::{
errors::{ProposalError, ProposeAddMemberError, ProposeRemoveMemberError},
MlsGroup,
core_group, errors::{ProposalError, ProposeAddMemberError, ProposeRemoveMemberError}, CreateGroupContextExtProposalError, GroupContextExtensionProposal, GroupContextExtensionsProposalValidationError, MlsGroup, MlsGroupState, PendingCommitState, Proposal
};
use crate::{
binary_tree::LeafNodeIndex,
Expand All @@ -14,7 +14,7 @@ use crate::{
framing::MlsMessageOut,
group::{errors::CreateAddProposalError, GroupId, QueuedProposal},
key_packages::KeyPackage,
messages::proposals::ProposalOrRefType,
messages::{group_info::GroupInfo, proposals::ProposalOrRefType},
prelude::LibraryError,
schedule::PreSharedKeyId,
treesync::LeafNode,
Expand Down Expand Up @@ -349,4 +349,50 @@ impl MlsGroup {

Ok((mls_message, proposal_ref))
}

pub fn update_group_context_extensions(
&mut self,
provider: &impl OpenMlsProvider,
extensions: Extensions,
signer: &impl Signer,
) -> Result<(MlsMessageOut, MlsMessageOut, Option<GroupInfo>), CreateGroupContextExtProposalError>
{
self.is_operational()?;

// if key_packages.is_empty() {
// return Err(CreateGroupContextExtProposalError::EmptyInput(EmptyInputError::AddMembers));
// }

// Create inline add proposals from key packages
let mut inline_proposals = vec![];
inline_proposals.push(Proposal::GroupContextExtensions(GroupContextExtensionProposal {
extensions,
}));

let params = CreateCommitParams::builder()
.framing_parameters(self.framing_parameters())
.proposal_store(&self.proposal_store)
.inline_proposals(inline_proposals)
.build();
let create_commit_result = self.group.create_commit(params, provider, signer).unwrap();
let welcome = match create_commit_result.welcome_option {
Some(welcome) => welcome,
None => {
return Err(LibraryError::custom("No secrets to generate commit message.").into())
}
};
let mls_messages = self.content_to_mls_message(create_commit_result.commit, provider)?;
self.group_state = MlsGroupState::PendingCommit(Box::new(PendingCommitState::Member(
create_commit_result.staged_commit,
)));

// Since the state of the group might be changed, arm the state flag
self.flag_state_change();

Ok((
mls_messages,
MlsMessageOut::from_welcome(welcome, self.group.version()),
create_commit_result.group_info,
))
}
}
163 changes: 163 additions & 0 deletions openmls/src/group/mls_group/test_mls_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,169 @@ fn unknown_extensions(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider)
.expect("Error creating group from staged join");
}

// Test the successful update of Group Context Extension with type Extension::Unknown(0xff11)
#[apply(ciphersuites_and_providers)]
fn update_group_context_with_unknown_extension(
ciphersuite: Ciphersuite,
provider: &impl OpenMlsProvider,
) {
let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) =
setup_client("Alice", ciphersuite, provider);

// === Define the unknown group context extension and initial data ===
let unknown_extension_data = vec![1, 2];
let unknown_gc_extension = Extension::Unknown(0xff11, UnknownExtension(unknown_extension_data));
let required_extension_types = &[ExtensionType::Unknown(0xff11)];
let required_capabilities = Extension::RequiredCapabilities(
RequiredCapabilitiesExtension::new(required_extension_types, &[], &[]),
);
let capabilities = Capabilities::new(None, None, Some(required_extension_types), None, None);
let test_gc_extensions = Extensions::from_vec(vec![
unknown_gc_extension.clone(),
required_capabilities.clone(),
])
.expect("error creating test group context extensions");
let mls_group_create_config = MlsGroupCreateConfig::builder()
.with_group_context_extensions(test_gc_extensions.clone())
.expect("error adding unknown extension to config")
.capabilities(capabilities.clone())
.crypto_config(CryptoConfig::with_default_version(ciphersuite))
.build();

// === Alice creates a group ===
let mut alice_group = MlsGroup::new(
provider,
&alice_signer,
&mls_group_create_config,
alice_credential_with_key,
)
.expect("error creating group");

// === Verify the initial group context extension data is correct ===
let group_context_extensions = alice_group.group().context().extensions();
let mut extracted_data = None;
for extension in group_context_extensions.iter() {
if let Extension::Unknown(0xff11, UnknownExtension(data)) = extension {
extracted_data = Some(data.clone());
}
}
assert_eq!(
extracted_data.unwrap(),
vec![1, 2],
"The data of Extension::Unknown(0xff11) does not match the expected data"
);

// === Alice adds Bob ===
let (bob_credential_with_key, _bob_kpb, bob_signer, _bob_pk) =
setup_client("Bob", ciphersuite, provider);

let bob_key_package = KeyPackage::builder()
.leaf_node_capabilities(capabilities)
.build(
CryptoConfig::with_default_version(ciphersuite),
provider,
&bob_signer,
bob_credential_with_key,
)
.expect("error building key package");

let (_, welcome, _) = alice_group
.add_members(provider, &alice_signer, &[bob_key_package.clone()])
.unwrap();
alice_group.merge_pending_commit(provider).unwrap();

let welcome: MlsMessageIn = welcome.into();
let welcome = welcome
.into_welcome()
.expect("expected message to be a welcome");

let bob_group = StagedWelcome::new_from_welcome(
provider,
&MlsGroupJoinConfig::default(),
welcome,
Some(alice_group.export_ratchet_tree().into()),
)
.expect("Error creating staged join from Welcome")
.into_group(provider)
.expect("Error creating group from staged join");

// === Verify Bob's initial group context extension data is correct ===
let group_context_extensions = bob_group.group().context().extensions();
let mut extracted_data_2 = None;
for extension in group_context_extensions.iter() {
if let Extension::Unknown(0xff11, UnknownExtension(data)) = extension {
extracted_data_2 = Some(data.clone());
}
}
assert_eq!(
extracted_data_2.unwrap(),
vec![1, 2],
"The data of Extension::Unknown(0xff11) does not match the expected data"
);

// === Propose the new group context extension ===
let updated_unknown_extension_data = vec![3, 4]; // Sample data for the extension
let updated_unknown_gc_extension = Extension::Unknown(
0xff11,
UnknownExtension(updated_unknown_extension_data.clone()),
);

let mut updated_extensions = test_gc_extensions.clone();
updated_extensions.add_or_replace(updated_unknown_gc_extension);
alice_group
.propose_group_context_extensions(provider, updated_extensions, &alice_signer)
.expect("failed to propose group context extensions with unknown extension");

assert_eq!(
alice_group.pending_proposals().count(),
1,
"Expected one pending proposal"
);

// === Commit to the proposed group context extension ===
alice_group
.commit_to_pending_proposals(provider, &alice_signer)
.expect("failed to commit to pending group context extensions");

alice_group
.merge_pending_commit(provider)
.expect("error merging pending commit");

alice_group
.save(provider.key_store())
.expect("error saving group");

// === Verify the group context extension was updated ===
let group_context_extensions = alice_group.group().context().extensions();
let mut extracted_data_updated = None;
for extension in group_context_extensions.iter() {
if let Extension::Unknown(0xff11, UnknownExtension(data)) = extension {
extracted_data_updated = Some(data.clone());
}
}
assert_eq!(
extracted_data_updated.unwrap(),
vec![3, 4],
"The data of Extension::Unknown(0xff11) does not match the expected data"
);

// === Verify Bob sees the group context extension updated ===
let bob_group_loaded = MlsGroup::load(bob_group.group().group_id(), provider.key_store())
.expect("error loading group");
let group_context_extensions_2 = bob_group_loaded.export_group_context().extensions();
let mut extracted_data_2 = None;
for extension in group_context_extensions_2.iter() {
if let Extension::Unknown(0xff11, UnknownExtension(data)) = extension {
extracted_data_2 = Some(data.clone());
}
}
assert_eq!(
extracted_data_2.unwrap(),
vec![3, 4],
"The data of Extension::Unknown(0xff11) does not match the expected data"
);
}

#[apply(ciphersuites_and_providers)]
fn join_multiple_groups_last_resort_extension(
ciphersuite: Ciphersuite,
Expand Down
2 changes: 1 addition & 1 deletion openmls/src/messages/proposals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ pub struct AppAckProposal {
TlsSize,
)]
pub struct GroupContextExtensionProposal {
extensions: Extensions,
pub(crate) extensions: Extensions,
}

impl GroupContextExtensionProposal {
Expand Down

0 comments on commit c6bc4cd

Please sign in to comment.