Skip to content

Commit

Permalink
Merge pull request #11 from xmtp/nm/add-group-extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas authored Jan 9, 2024
2 parents d327edf + b0e3742 commit 9b6b977
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 7 deletions.
8 changes: 8 additions & 0 deletions openmls/src/group/core_group/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,14 @@ impl CoreGroupBuilder {
self
}

/// Set the `group_context_extensions` of the [`CoreGroup`].
pub fn with_group_context_extensions(mut self, extensions: Extensions) -> Self {
self.public_group_builder = self
.public_group_builder
.with_group_context_extensions(extensions);
self
}

/// Build the [`CoreGroup`].
/// Any values that haven't been set in the builder are set to their default
/// values (which might be random).
Expand Down
13 changes: 13 additions & 0 deletions openmls/src/group/mls_group/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ pub struct MlsGroupConfig {
pub(crate) lifetime: Lifetime,
/// Ciphersuite and protocol version
pub(crate) crypto_config: CryptoConfig,
// Other extensions
pub(crate) group_context_extensions: Extensions,
}

impl MlsGroupConfig {
Expand Down Expand Up @@ -118,6 +120,11 @@ impl MlsGroupConfig {
&self.crypto_config
}

/// Set the `group_context_extensions` property of the MlsGroupConfig.
pub fn group_context_extensions(&self) -> &Extensions {
&self.group_context_extensions
}

#[cfg(any(feature = "test-utils", test))]
pub fn test_default(ciphersuite: Ciphersuite) -> Self {
Self::builder()
Expand Down Expand Up @@ -220,6 +227,12 @@ impl MlsGroupConfigBuilder {
self
}

/// Sets the `group_context_extensions` property of the MlsGroupConfig.
pub fn group_context_extensions(mut self, extensions: Extensions) -> Self {
self.config.group_context_extensions = extensions;
self
}

/// Finalizes the builder and retursn an `[MlsGroupConfig`].
pub fn build(self) -> MlsGroupConfig {
self.config
Expand Down
1 change: 1 addition & 0 deletions openmls/src/group/mls_group/creation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ impl MlsGroup {
credential_with_key,
)
.with_config(group_config)
.with_group_context_extensions(mls_group_config.group_context_extensions.clone())
.with_required_capabilities(mls_group_config.required_capabilities.clone())
.with_external_senders(mls_group_config.external_senders.clone())
.with_max_past_epoch_secrets(mls_group_config.max_past_epochs)
Expand Down
59 changes: 59 additions & 0 deletions openmls/src/group/mls_group/test_mls_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -645,3 +645,62 @@ fn remove_prosposal_by_ref(ciphersuite: Ciphersuite, provider: &impl OpenMlsProv
_ => unreachable!("Expected a StagedCommit."),
}
}

#[apply(ciphersuites_and_providers)]
fn test_group_context_extensions(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) {
let group_id = GroupId::from_slice(b"Test Group");
let application_id = b"Test App ID";
let metadata = vec![1, 2, 3];

let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) =
setup_client("Alice", ciphersuite, provider);

// Define the MlsGroup configuration
let mls_group_config = MlsGroupConfig::builder()
.wire_format_policy(WireFormatPolicy::new(
OutgoingWireFormatPolicy::AlwaysPlaintext,
IncomingWireFormatPolicy::Mixed,
))
.crypto_config(CryptoConfig::with_default_version(ciphersuite))
.group_context_extensions(Extensions::single(Extension::ProtectedMetadata(
ProtectedMetadata::new(
&alice_signer,
application_id.to_vec(),
alice_credential_with_key.credential.clone(),
alice_credential_with_key.signature_key.as_slice().to_vec(),
metadata,
)
.unwrap(),
)))
.build();

// === Alice creates a group ===
let mut alice_group = MlsGroup::new_with_group_id(
provider,
&alice_signer,
&mls_group_config,
group_id.clone(),
alice_credential_with_key,
)
.expect("An unexpected error occurred.");

assert!(alice_group
.export_group_context()
.extensions()
.contains(ExtensionType::ProtectedMetadata));

// Check the internal state has changed
assert_eq!(alice_group.state_changed(), InnerState::Changed);

alice_group
.save(provider.key_store())
.expect("Could not write group state to file");

let alice_group_deserialized =
MlsGroup::load(&group_id, provider.key_store()).expect("Could not deserialize MlsGroup");

assert!(alice_group_deserialized
.export_group_context()
.extensions()
.contains(ExtensionType::ProtectedMetadata));
}
26 changes: 19 additions & 7 deletions openmls/src/group/public_group/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub(crate) struct TempBuilderPG1 {
required_capabilities: Option<RequiredCapabilitiesExtension>,
external_senders: Option<ExternalSendersExtension>,
leaf_extensions: Option<Extensions>,
group_context_extensions: Option<Extensions>,
}

impl TempBuilderPG1 {
Expand All @@ -34,6 +35,11 @@ impl TempBuilderPG1 {
self
}

pub(crate) fn with_group_context_extensions(mut self, extensions: Extensions) -> Self {
self.group_context_extensions = Some(extensions);
self
}

pub(crate) fn with_required_capabilities(
mut self,
required_capabilities: RequiredCapabilitiesExtension,
Expand Down Expand Up @@ -87,17 +93,22 @@ impl TempBuilderPG1 {
_ => LibraryError::custom("Unexpected ExtensionError").into(),
})?;
let required_capabilities = Extension::RequiredCapabilities(required_capabilities);
let extensions =
if let Some(ext_senders) = self.external_senders.map(Extension::ExternalSenders) {
vec![required_capabilities, ext_senders]
} else {
vec![required_capabilities]
};

let mut extensions = Extensions::from_vec(vec![required_capabilities])?;
if let Some(ext_senders) = self.external_senders.map(Extension::ExternalSenders) {
extensions.add(ext_senders)?;
}
if let Some(group_context_extensions) = self.group_context_extensions {
for extension in group_context_extensions.iter() {
extensions.add(extension.clone())?;
}
}

let group_context = GroupContext::create_initial_group_context(
self.crypto_config.ciphersuite,
self.group_id,
treesync.tree_hash().to_vec(),
Extensions::from_vec(extensions)?,
extensions,
);
let next_builder = TempBuilderPG2 {
treesync,
Expand Down Expand Up @@ -172,6 +183,7 @@ impl PublicGroup {
required_capabilities: None,
external_senders: None,
leaf_extensions: None,
group_context_extensions: None,
}
}
}

0 comments on commit 9b6b977

Please sign in to comment.