Skip to content

Commit

Permalink
feat(rust): optimize sqlite queries
Browse files Browse the repository at this point in the history
  • Loading branch information
SanjoDeundiak committed Dec 19, 2024
1 parent 2ff6e39 commit c46e2a8
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,49 +53,25 @@ impl AuthorityEnrollmentTokenRepository for AuthorityEnrollmentTokenSqlxDatabase
one_time_code: OneTimeCode,
now: TimestampInSeconds,
) -> Result<Option<EnrollmentToken>> {
// We need to delete expired tokens regularly
// Also makes sure we don't get expired tokens later inside this function
let query1 = query("DELETE FROM authority_enrollment_token WHERE expires_at <= $1")
.bind(now.0 as i64);

let res = query1.execute(&*self.database.pool).await.into_core()?;
debug!("Deleted {} expired enrollment tokens", res.rows_affected());

let mut transaction = self.database.pool.begin().await.into_core()?;

let query2 = query_as("SELECT one_time_code, reference, issued_by, created_at, expires_at, ttl_count, attributes FROM authority_enrollment_token WHERE one_time_code = $1")
.bind(one_time_code);
let row: Option<EnrollmentTokenRow> =
query2.fetch_optional(&mut *transaction).await.into_core()?;
let token: Option<EnrollmentToken> = row.map(|r| r.try_into()).transpose()?;

if let Some(token) = &token {
if token.ttl_count <= 1 {
let query3 =
query("DElETE FROM authority_enrollment_token WHERE one_time_code = $1")
.bind(one_time_code);
query3.execute(&mut *transaction).await.void()?;
debug!(
"Deleted enrollment token because it has been used. Reference: {}",
token.reference()
);
} else {
let new_ttl_count = token.ttl_count - 1;
let query3 = query(
"UPDATE authority_enrollment_token SET ttl_count = $1 WHERE one_time_code = $2",
)
.bind(new_ttl_count as i64)
.bind(one_time_code);
query3.execute(&mut *transaction).await.void()?;
debug!(
"Decreasing enrollment token usage count to {}. Reference: {}",
new_ttl_count,
token.reference()
);
}
}

transaction.commit().await.void()?;
// FIXME: We need to delete expired tokens regularly
// FIXME 2: Also now need to clear tokens with ttl_count = 0

let query = query_as("UPDATE authority_enrollment_token SET ttl_count = ttl_count - 1 WHERE one_time_code = $1 AND expires_at > $2 AND ttl_count > 0 RETURNING one_time_code, reference, issued_by, created_at, expires_at, ttl_count, attributes")
.bind(one_time_code)
.bind(now);
let row: Option<EnrollmentTokenRow> = query
.fetch_optional(&*self.database.pool)
.await
.into_core()?;
let token = if let Some(row) = row {
let mut t = EnrollmentToken::try_from(row)?;

t.ttl_count += 1; // We decremented it inside the query

Some(t)
} else {
None
};

Ok(token)
}
Expand Down
73 changes: 73 additions & 0 deletions implementations/rust/ockam/ockam_api/tests/perf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use crate::common::common::{change_client_identifier, start_authority, AuthorityInfo};
use ockam::identity::SecureChannels;
use ockam_api::authenticator::direct::{
OCKAM_ROLE_ATTRIBUTE_ENROLLER_VALUE, OCKAM_ROLE_ATTRIBUTE_KEY,
};
use ockam_api::authenticator::enrollment_tokens::TokenIssuer;
use ockam_core::Result;
use ockam_node::Context;
use std::collections::BTreeMap;
use std::time::Instant;

mod common;

#[ockam_macros::test]
async fn enrollment_token_loop(ctx: &mut Context) -> Result<()> {
let secure_channels = SecureChannels::builder().await?.build();

let AuthorityInfo { admins, .. } = start_authority(ctx, secure_channels.clone(), 1).await?;
let admin = &admins[0];

let mut attributes = BTreeMap::<String, String>::default();
attributes.insert(
OCKAM_ROLE_ATTRIBUTE_KEY.to_string(),
OCKAM_ROLE_ATTRIBUTE_ENROLLER_VALUE.to_string(),
);
let otc = admin
.client
.create_token(ctx, attributes.clone(), None, None)
.await
.unwrap();

let enroller = secure_channels
.identities()
.identities_creation()
.create_identity()
.await?;
let enroller_client = change_client_identifier(&admin.client, &enroller, None);

{
use ockam_api::authenticator::enrollment_tokens::TokenAcceptor;
enroller_client.present_token(ctx, otc).await.unwrap();
}

let mut attributes_member = BTreeMap::<String, String>::default();
attributes_member.insert("KEY".to_string(), "VALUE".to_string());

let otc = enroller_client
.create_token(ctx, attributes_member.clone(), None, Some(1024))
.await
.unwrap();

let t1 = Instant::now();
for _ in 0..1000 {
let member = secure_channels
.identities()
.identities_creation()
.create_identity()
.await?;
let member_client = change_client_identifier(&admin.client, &member, None);

member_client.present_token(ctx, &otc).await.unwrap();

use ockam_api::enroll::enrollment::Enrollment;

member_client.issue_credential(ctx).await.unwrap();
}

let t2 = Instant::now();

println!("TIME: {:?}", t2 - t1);

Ok(())
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ impl IdentitiesAttributes {
subject: &Identifier,
attested_by: &Identifier,
) -> Result<Option<AttributesEntry>> {
self.repository.delete_expired_attributes(now()?).await?;
self.repository.get_attributes(subject, attested_by).await
let now = now()?;
// FIXME: This should be run periodically in a separate task
// self.repository.delete_expired_attributes(now()?).await?;
self.repository
.get_non_expired_attributes(subject, attested_by, now)
.await
}

/// Set the attributes associated with the given identity identifier.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ use ockam_node::retry;
#[async_trait]
pub trait IdentityAttributesRepository: Send + Sync + 'static {
/// Get the attributes associated with the given identity identifier
async fn get_attributes(
async fn get_non_expired_attributes(
&self,
subject: &Identifier,
attested_by: &Identifier,
now: TimestampInSeconds,
) -> Result<Option<AttributesEntry>>;

/// Set the attributes associated with the given identity identifier.
Expand All @@ -28,12 +29,15 @@ pub trait IdentityAttributesRepository: Send + Sync + 'static {
#[cfg(feature = "std")]
#[async_trait]
impl<T: IdentityAttributesRepository> IdentityAttributesRepository for AutoRetry<T> {
async fn get_attributes(
async fn get_non_expired_attributes(
&self,
subject: &Identifier,
attested_by: &Identifier,
now: TimestampInSeconds,
) -> Result<Option<AttributesEntry>> {
retry!(self.wrapped.get_attributes(subject, attested_by))
retry!(self
.wrapped
.get_non_expired_attributes(subject, attested_by, now))
}

async fn put_attributes(&self, subject: &Identifier, entry: AttributesEntry) -> Result<()> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,18 @@ impl IdentityAttributesSqlxDatabase {

#[async_trait]
impl IdentityAttributesRepository for IdentityAttributesSqlxDatabase {
async fn get_attributes(
async fn get_non_expired_attributes(
&self,
identity: &Identifier,
attested_by: &Identifier,
now: TimestampInSeconds,
) -> Result<Option<AttributesEntry>> {
let query = query_as(
"SELECT identifier, attributes, added, expires, attested_by FROM identity_attributes WHERE identifier = $1 AND attested_by = $2 AND node_name = $3"
"SELECT identifier, attributes, added, expires, attested_by FROM identity_attributes WHERE identifier = $1 AND attested_by = $2 AND (expires > $3 or expires IS NULL) AND node_name = $4"
)
.bind(identity)
.bind(attested_by)
.bind(now)
.bind(&self.node_name);
let identity_attributes: Option<IdentityAttributesRow> = query
.fetch_optional(&*self.database.pool)
Expand Down Expand Up @@ -186,12 +188,12 @@ mod tests {
.await?;

let result = repository
.get_attributes(&identifier1, &identifier1)
.get_non_expired_attributes(&identifier1, &identifier1, now)
.await?;
assert_eq!(result, Some(attributes1.clone()));

let result = repository
.get_attributes(&identifier2, &identifier2)
.get_non_expired_attributes(&identifier2, &identifier2, now)
.await?;
assert_eq!(result, Some(attributes2.clone()));

Expand Down Expand Up @@ -236,17 +238,17 @@ mod tests {
repository.delete_expired_attributes(now.add(10)).await?;

let result = repository
.get_attributes(&identifier1, &identifier1)
.get_non_expired_attributes(&identifier1, &identifier1, now)
.await?;
assert_eq!(result, None);

let result = repository
.get_attributes(&identifier2, &identifier2)
.get_non_expired_attributes(&identifier2, &identifier2, now)
.await?;
assert_eq!(result, None);

let result = repository
.get_attributes(&identifier3, &identifier3)
.get_non_expired_attributes(&identifier3, &identifier3, now)
.await?;
assert_eq!(
result,
Expand All @@ -255,7 +257,7 @@ mod tests {
);

let result = repository
.get_attributes(&identifier4, &identifier4)
.get_non_expired_attributes(&identifier4, &identifier4, now)
.await?;
assert_eq!(
result,
Expand Down

0 comments on commit c46e2a8

Please sign in to comment.