Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write the context proposal #28

Merged
merged 8 commits into from
May 16, 2024
2 changes: 1 addition & 1 deletion basic_credential/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl SignatureKeyPair {
}
}

fn id(&self) -> StorageId {
pub fn id(&self) -> StorageId {
StorageId {
value: id(&self.public, self.signature_scheme),
}
Expand Down
14 changes: 10 additions & 4 deletions memory_storage/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -803,12 +803,18 @@ impl StorageProvider<CURRENT_VERSION> for MemoryStorage {
&self,
group_id: &GroupId,
) -> Result<(), Self::Error> {
// Get all proposal refs for this group.
let proposal_refs: Vec<ProposalRef> =
self.read_list(PROPOSAL_QUEUE_REFS_LABEL, &serde_json::to_vec(group_id)?)?;
let mut values = self.values.write().unwrap();
for proposal_ref in proposal_refs {
// Delete all proposals.
let key = serde_json::to_vec(&(group_id, proposal_ref))?;
values.remove(&key);
}

let key = build_key::<CURRENT_VERSION, &GroupId>(QUEUED_PROPOSAL_LABEL, group_id);

// XXX #1566: also remove the proposal refs. can't be done now because they are stored in a
// non-recoverable way
// Delete the proposal refs from the store.
let key = build_key::<CURRENT_VERSION, &GroupId>(PROPOSAL_QUEUE_REFS_LABEL, group_id);
values.remove(&key);

Ok(())
Expand Down
81 changes: 81 additions & 0 deletions memory_storage/tests/proposals.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use openmls_memory_storage::MemoryStorage;
use openmls_traits::storage::{
traits::{self},
Entity, Key, StorageProvider, CURRENT_VERSION,
};
use serde::{Deserialize, Serialize};

// Test types
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
struct TestGroupId(Vec<u8>);
impl traits::GroupId<CURRENT_VERSION> for TestGroupId {}
impl Key<CURRENT_VERSION> for TestGroupId {}

#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Copy)]
struct ProposalRef(usize);
impl traits::ProposalRef<CURRENT_VERSION> for ProposalRef {}
impl Key<CURRENT_VERSION> for ProposalRef {}
impl Entity<CURRENT_VERSION> for ProposalRef {}

#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
struct Proposal(Vec<u8>);
impl traits::QueuedProposal<CURRENT_VERSION> for Proposal {}
impl Entity<CURRENT_VERSION> for Proposal {}

/// Write and read some proposals
#[test]
fn read_write_delete() {
let group_id = TestGroupId(b"TestGroupId".to_vec());
let proposals = (0..10)
.map(|i| Proposal(format!("TestProposal{i}").as_bytes().to_vec()))
.collect::<Vec<_>>();
let storage = MemoryStorage::default();

// Store proposals
for (i, proposal) in proposals.iter().enumerate() {
storage
.queue_proposal(&group_id, &ProposalRef(i), proposal)
.unwrap();
}

// Read proposal refs
let proposal_refs_read: Vec<ProposalRef> = storage.queued_proposal_refs(&group_id).unwrap();
assert_eq!(
(0..10).map(|i| ProposalRef(i)).collect::<Vec<_>>(),
proposal_refs_read
);

// Read proposals
let proposals_read: Vec<(ProposalRef, Proposal)> = storage.queued_proposals(&group_id).unwrap();
let proposals_expected: Vec<(ProposalRef, Proposal)> = (0..10)
.map(|i| ProposalRef(i))
.zip(proposals.clone().into_iter())
.collect();
assert_eq!(proposals_expected, proposals_read);

// Remove proposal 5
storage.remove_proposal(&group_id, &ProposalRef(5)).unwrap();

let proposal_refs_read: Vec<ProposalRef> = storage.queued_proposal_refs(&group_id).unwrap();
let mut expected = (0..10).map(|i| ProposalRef(i)).collect::<Vec<_>>();
expected.remove(5);
assert_eq!(expected, proposal_refs_read);

let proposals_read: Vec<(ProposalRef, Proposal)> = storage.queued_proposals(&group_id).unwrap();
let mut proposals_expected: Vec<(ProposalRef, Proposal)> = (0..10)
.map(|i| ProposalRef(i))
.zip(proposals.clone().into_iter())
.collect();
proposals_expected.remove(5);
assert_eq!(proposals_expected, proposals_read);

// Clear all proposals
storage
.clear_proposal_queue::<TestGroupId, ProposalRef>(&group_id)
.unwrap();
let proposal_refs_read: Vec<ProposalRef> = storage.queued_proposal_refs(&group_id).unwrap();
assert!(proposal_refs_read.is_empty());

let proposals_read: Vec<(ProposalRef, Proposal)> = storage.queued_proposals(&group_id).unwrap();
assert!(proposals_read.is_empty());
}
3 changes: 3 additions & 0 deletions openmls/src/group/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,9 @@ pub enum CreateGroupContextExtProposalError<StorageError> {
/// See [`CreateCommitError`] for more details.
#[error(transparent)]
CreateCommitError(#[from] CreateCommitError<StorageError>),
/// Error writing updated group to storage.
#[error("Error writing updated group data to storage.")]
StorageError(StorageError),
}

/// Error merging a commit.
Expand Down
5 changes: 5 additions & 0 deletions openmls/src/group/mls_group/proposal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,11 @@ impl MlsGroup {
create_commit_result.staged_commit,
)));

provider
.storage()
.write_group_state(self.group_id(), &self.group_state)
.map_err(CreateGroupContextExtProposalError::StorageError)?;

Ok((
mls_messages,
create_commit_result
Expand Down
2 changes: 1 addition & 1 deletion openmls/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub trait OpenMlsProvider:
}

impl<
Error: std::error::Error + PartialEq,
Error: std::error::Error,
SP: StorageProvider<Error = Error>,
OP: openmls_traits::OpenMlsProvider<StorageProvider = SP>,
> OpenMlsProvider for OP
Expand Down
Loading