Skip to content

Commit

Permalink
Check if address/installation is in association state (#1336)
Browse files Browse the repository at this point in the history
* unintentional rabbit hole

* wip

* cleanup

* lint

* lint

* cleanup

* more lint

* take an inbox_id param

* lint
  • Loading branch information
codabrink authored Nov 25, 2024
1 parent e498595 commit faae68c
Show file tree
Hide file tree
Showing 12 changed files with 112 additions and 80 deletions.
5 changes: 4 additions & 1 deletion bindings_ffi/src/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,10 @@ impl FfiXmtpClient {
) -> Result<Vec<FfiInboxState>, GenericError> {
let state = self
.inner_client
.inbox_addresses(refresh_from_network, inbox_ids)
.inbox_addresses(
refresh_from_network,
inbox_ids.iter().map(String::as_str).collect(),
)
.await?;
Ok(state.into_iter().map(Into::into).collect())
}
Expand Down
44 changes: 43 additions & 1 deletion bindings_node/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use tracing_subscriber::{fmt, prelude::*};
pub use xmtp_api_grpc::grpc_api_helper::Client as TonicApiClient;
use xmtp_cryptography::signature::ed25519_public_key_to_address;
use xmtp_id::associations::builder::SignatureRequest;
use xmtp_id::associations::MemberIdentifier;
use xmtp_mls::builder::ClientBuilder;
use xmtp_mls::groups::scoped_client::LocalScopedGroupClient;
use xmtp_mls::identity::IdentityStrategy;
Expand Down Expand Up @@ -292,9 +293,50 @@ impl Client {
) -> Result<Vec<InboxState>> {
let state = self
.inner_client
.inbox_addresses(refresh_from_network, inbox_ids)
.inbox_addresses(
refresh_from_network,
inbox_ids.iter().map(String::as_str).collect(),
)
.await
.map_err(ErrorWrapper::from)?;
Ok(state.into_iter().map(Into::into).collect())
}

#[napi]
pub async fn is_address_authorized(&self, inbox_id: String, address: String) -> Result<bool> {
self
.is_member_of_association_state(&inbox_id, &MemberIdentifier::Address(address))
.await
}

#[napi]
pub async fn is_installation_authorized(
&self,
inbox_id: String,
installation: Vec<u8>,
) -> Result<bool> {
self
.is_member_of_association_state(&inbox_id, &MemberIdentifier::Installation(installation))
.await
}

async fn is_member_of_association_state(
&self,
inbox_id: &str,
identifier: &MemberIdentifier,
) -> Result<bool> {
let client = &self.inner_client;
let conn = self
.inner_client
.store()
.conn()
.map_err(ErrorWrapper::from)?;

let association_state = client
.get_association_state(&conn, inbox_id, None)
.await
.map_err(ErrorWrapper::from)?;

Ok(association_state.get(identifier).is_some())
}
}
4 changes: 2 additions & 2 deletions xmtp_debug/src/app/generate/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl GenerateIdentity {
let first = identities.next().ok_or(eyre::eyre!("Does not exist"))??;

let state = client
.get_latest_association_state(&connection, hex::encode(first.inbox_id))
.get_latest_association_state(&connection, &hex::encode(first.inbox_id))
.await?;
info!("Found generated identities, checking for registration on backend...",);
// we assume that if the first identity is registered, they all are
Expand Down Expand Up @@ -114,7 +114,7 @@ impl GenerateIdentity {
let future = |inbox_id: [u8; 32]| async move {
let id = hex::encode(inbox_id);
trace!(inbox_id = id, "getting association state");
let state = tmp.get_latest_association_state(&conn, id).await?;
let state = tmp.get_latest_association_state(&conn, &id).await?;
bar_ref.inc(1);
Ok(state)
};
Expand Down
2 changes: 1 addition & 1 deletion xmtp_debug/src/app/inspect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl Inspect {
match kind {
Associations => {
let state = client
.get_latest_association_state(&conn, hex::encode(*inbox_id))
.get_latest_association_state(&conn, &hex::encode(*inbox_id))
.await?;
info!(
inbox_id = state.inbox_id(),
Expand Down
8 changes: 4 additions & 4 deletions xmtp_id/src/associations/association_log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ impl IdentityAction for AddAssociation {

let existing_entity_id = match existing_member {
// If there is an existing member of the XID, use that member's ID
Some(member) => member.identifier,
Some(member) => member.identifier.clone(),
None => {
// Get the recovery address from the state as a MemberIdentifier
let recovery_identifier: MemberIdentifier =
Expand Down Expand Up @@ -228,7 +228,7 @@ impl IdentityAction for RevokeAssociation {
// Ensure that the new signature is on the same chain as the signature to create the account
let existing_member = existing_state.get(&self.recovery_address_signature.signer);
if let Some(member) = existing_member {
verify_chain_id_matches(&member, &self.recovery_address_signature)?;
verify_chain_id_matches(member, &self.recovery_address_signature)?;
}

if is_legacy_signature(&self.recovery_address_signature) {
Expand Down Expand Up @@ -289,7 +289,7 @@ impl IdentityAction for ChangeRecoveryAddress {

let existing_member = existing_state.get(&self.recovery_address_signature.signer);
if let Some(member) = existing_member {
verify_chain_id_matches(&member, &self.recovery_address_signature)?;
verify_chain_id_matches(member, &self.recovery_address_signature)?;
}

if is_legacy_signature(&self.recovery_address_signature) {
Expand Down Expand Up @@ -459,7 +459,7 @@ fn verify_chain_id_matches(
member: &Member,
signature: &VerifiedSignature,
) -> Result<(), AssociationError> {
if member.added_on_chain_id.ne(&signature.chain_id) {
if member.added_on_chain_id != signature.chain_id {
return Err(AssociationError::ChainIdMismatch(
member.added_on_chain_id.unwrap_or(0),
signature.chain_id.unwrap_or(0),
Expand Down
4 changes: 2 additions & 2 deletions xmtp_id/src/associations/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ impl AssociationState {
new_state
}

pub fn get(&self, identifier: &MemberIdentifier) -> Option<Member> {
self.members.get(identifier).cloned()
pub fn get(&self, identifier: &MemberIdentifier) -> Option<&Member> {
self.members.get(identifier)
}

pub fn add_seen_signatures(&self, signatures: Vec<Vec<u8>>) -> Self {
Expand Down
11 changes: 3 additions & 8 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,19 +365,14 @@ where
}

/// Get the [`AssociationState`] for each `inbox_id`
pub async fn inbox_addresses<InboxId: AsRef<str>>(
pub async fn inbox_addresses<'a>(
&self,
refresh_from_network: bool,
inbox_ids: Vec<InboxId>,
inbox_ids: Vec<InboxIdRef<'a>>,
) -> Result<Vec<AssociationState>, ClientError> {
let conn = self.store().conn()?;
if refresh_from_network {
load_identity_updates(
&self.api_client,
&conn,
&inbox_ids.iter().map(|s| s.as_ref()).collect::<Vec<&str>>(),
)
.await?;
load_identity_updates(&self.api_client, &conn, &inbox_ids).await?;
}
let state = self
.batch_get_association_state(
Expand Down
2 changes: 1 addition & 1 deletion xmtp_mls/src/groups/members.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ where
{
return None;
}
Some((id.clone(), Some(*sequence)))
Some((id.as_str(), Some(*sequence)))
})
.collect();

Expand Down
4 changes: 2 additions & 2 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1493,9 +1493,9 @@ async fn validate_initial_group_membership(

let futures: Vec<_> = membership
.members
.into_iter()
.iter()
.map(|(inbox_id, sequence_id)| {
client.get_association_state(conn, inbox_id, Some(sequence_id as i64))
client.get_association_state(conn, inbox_id, Some(*sequence_id as i64))
})
.collect();

Expand Down
40 changes: 20 additions & 20 deletions xmtp_mls/src/groups/scoped_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,17 @@ pub trait LocalScopedGroupClient: Send + Sync + Sized {
installation_ids: Vec<Vec<u8>>,
) -> Result<Vec<VerifiedKeyPackageV2>, ClientError>;

async fn get_association_state(
async fn get_association_state<'a>(
&self,
conn: &DbConnection,
inbox_id: String,
inbox_id: InboxIdRef<'a>,
to_sequence_id: Option<i64>,
) -> Result<AssociationState, ClientError>;

async fn batch_get_association_state(
async fn batch_get_association_state<'a>(
&self,
conn: &DbConnection,
identifiers: &[(String, Option<i64>)],
identifiers: &[(InboxIdRef<'a>, Option<i64>)],
) -> Result<Vec<AssociationState>, ClientError>;

async fn query_group_messages(
Expand Down Expand Up @@ -126,17 +126,17 @@ pub trait ScopedGroupClient: Sized {
installation_ids: Vec<Vec<u8>>,
) -> Result<Vec<VerifiedKeyPackageV2>, ClientError>;

async fn get_association_state(
async fn get_association_state<'a>(
&self,
conn: &DbConnection,
inbox_id: String,
inbox_id: InboxIdRef<'a>,
to_sequence_id: Option<i64>,
) -> Result<AssociationState, ClientError>;

async fn batch_get_association_state(
async fn batch_get_association_state<'a>(
&self,
conn: &DbConnection,
identifiers: &[(String, Option<i64>)],
identifiers: &[(InboxIdRef<'a>, Option<i64>)],
) -> Result<Vec<AssociationState>, ClientError>;

async fn query_group_messages(
Expand Down Expand Up @@ -201,10 +201,10 @@ where
.await
}

async fn get_association_state(
async fn get_association_state<'a>(
&self,
conn: &DbConnection,
inbox_id: String,
inbox_id: InboxIdRef<'a>,
to_sequence_id: Option<i64>,
) -> Result<AssociationState, ClientError> {
crate::Client::<ApiClient, Verifier>::get_association_state(
Expand All @@ -216,10 +216,10 @@ where
.await
}

async fn batch_get_association_state(
async fn batch_get_association_state<'a>(
&self,
conn: &DbConnection,
identifiers: &[(String, Option<i64>)],
identifiers: &[(InboxIdRef<'a>, Option<i64>)],
) -> Result<Vec<AssociationState>, ClientError> {
crate::Client::<ApiClient, Verifier>::batch_get_association_state(self, conn, identifiers)
.await
Expand Down Expand Up @@ -298,21 +298,21 @@ where
.await
}

async fn get_association_state(
async fn get_association_state<'a>(
&self,
conn: &DbConnection,
inbox_id: String,
inbox_id: InboxIdRef<'a>,
to_sequence_id: Option<i64>,
) -> Result<AssociationState, ClientError> {
(**self)
.get_association_state(conn, inbox_id, to_sequence_id)
.await
}

async fn batch_get_association_state(
async fn batch_get_association_state<'a>(
&self,
conn: &DbConnection,
identifiers: &[(String, Option<i64>)],
identifiers: &[(InboxIdRef<'a>, Option<i64>)],
) -> Result<Vec<AssociationState>, ClientError> {
(**self)
.batch_get_association_state(conn, identifiers)
Expand Down Expand Up @@ -392,21 +392,21 @@ where
.await
}

async fn get_association_state(
async fn get_association_state<'a>(
&self,
conn: &DbConnection,
inbox_id: String,
inbox_id: InboxIdRef<'a>,
to_sequence_id: Option<i64>,
) -> Result<AssociationState, ClientError> {
(**self)
.get_association_state(conn, inbox_id, to_sequence_id)
.await
}

async fn batch_get_association_state(
async fn batch_get_association_state<'a>(
&self,
conn: &DbConnection,
identifiers: &[(String, Option<i64>)],
identifiers: &[(InboxIdRef<'a>, Option<i64>)],
) -> Result<Vec<AssociationState>, ClientError> {
(**self)
.batch_get_association_state(conn, identifiers)
Expand Down
6 changes: 1 addition & 5 deletions xmtp_mls/src/groups/validated_commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,7 @@ impl ValidatedCommit {
.ok_or(CommitValidationError::SubjectDoesNotExist)?;

let inbox_state = client
.get_association_state(
conn,
participant.inbox_id.clone(),
Some(*to_sequence_id as i64),
)
.get_association_state(conn, &participant.inbox_id, Some(*to_sequence_id as i64))
.await
.map_err(InstallationDiffError::from)?;

Expand Down
Loading

0 comments on commit faae68c

Please sign in to comment.