Skip to content

Commit

Permalink
feat(rust): persist tcp inlets data to reuse the same address when re…
Browse files Browse the repository at this point in the history
…creating them
  • Loading branch information
adrianbenavides committed Sep 19, 2023
1 parent 4819e08 commit 74b9623
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 30 deletions.
151 changes: 125 additions & 26 deletions implementations/rust/ockam/ockam_app/src/invitations/commands.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use miette::IntoDiagnostic;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::str::FromStr;
use tauri::{AppHandle, Manager, Runtime, State};
use tokio::sync::RwLockWriteGuard;
use tracing::{debug, info, trace, warn};

use ockam_api::address::get_free_address;
Expand All @@ -13,6 +15,7 @@ use ockam_api::nodes::models::portal::InletStatus;

use crate::app::{AppState, PROJECT_NAME};
use crate::cli::cli_bin;
use crate::invitations::state::{InvitationState, TcpInlet};
use crate::projects::commands::{create_enrollment_ticket, SyncAdminProjectsState};
use crate::shared_service::relay::RELAY_NAME;

Expand Down Expand Up @@ -130,34 +133,40 @@ pub async fn refresh_invitations<R: Runtime>(app: AppHandle<R>) -> Result<(), St
let invitation_state: State<'_, SyncInvitationsState> = app.state();
let mut writer = invitation_state.write().await;
writer.replace_by(invitations.clone());
refresh_inlets(state, writer)
.await
.map_err(|e| e.to_string())?;
}
refresh_inlets(&app, invitations.accepted.as_ref())
.await
.map_err(|e| e.to_string())?;
app.trigger_global(REFRESHED_INVITATIONS, None);
Ok(())
}

async fn refresh_inlets<R: Runtime>(
app: &AppHandle<R>,
accepted_invitations: Option<&Vec<InvitationWithAccess>>,
async fn refresh_inlets(
app_state: State<'_, AppState>,
mut invitations_state: RwLockWriteGuard<'_, InvitationState>,
) -> crate::Result<()> {
debug!("Refreshing inlets");
let accepted_invitations = match accepted_invitations {
Some(accepted_invitations) => accepted_invitations,
None => {
debug!("No accepted invitations, skipping inlets refresh");
return Ok(());
}
};
let app_state: State<'_, AppState> = app.state();
if invitations_state.accepted.invitations.is_empty() {
debug!("No accepted invitations, skipping inlets refresh");
return Ok(());
}

let cli_state = app_state.state().await;
let cli_bin = cli_bin()?;
let mut inlets_socket_addrs = vec![];
for invitation in accepted_invitations {
match InletDataFromInvitation::new(&cli_state, invitation) {
for invitation in &invitations_state.accepted.invitations {
match InletDataFromInvitation::new(
&cli_state,
invitation,
&invitations_state.accepted.inlets,
) {
Ok(i) => match i {
Some(i) => {
if !i.enabled {
debug!(node = %i.local_node_name, "TCP inlet is disabled by the user, skipping");
continue;
}

let mut inlet_is_running = false;
debug!(node = %i.local_node_name, "Checking node status");
if let Ok(node) = cli_state.nodes.get(&i.local_node_name) {
Expand Down Expand Up @@ -228,15 +237,11 @@ async fn refresh_inlets<R: Runtime>(
}
}
}
{
let invitations_state: State<'_, SyncInvitationsState> = app.state();
let mut writer = invitations_state.write().await;
for (invitation_id, inlet_socket_addr) in inlets_socket_addrs {
writer
.accepted
.inlets
.insert(invitation_id, inlet_socket_addr);
}
for (invitation_id, inlet_socket_addr) in inlets_socket_addrs {
invitations_state
.accepted
.inlets
.insert(invitation_id, TcpInlet::new(inlet_socket_addr));
}
info!("Inlets refreshed");
Ok(())
Expand All @@ -247,12 +252,20 @@ async fn refresh_inlets<R: Runtime>(
async fn create_inlet(inlet_data: &InletDataFromInvitation) -> crate::Result<SocketAddr> {
debug!(service_name = ?inlet_data.service_name, "Creating TCP inlet for accepted invitation");
let InletDataFromInvitation {
enabled,
local_node_name,
service_name,
service_route,
enrollment_ticket_hex,
socket_addr,
} = inlet_data;
let from = get_free_address()?;
if !enabled {
return Err("TCP inlet is disabled by the user".into());
}
let from = match socket_addr {
Some(socket_addr) => *socket_addr,
None => get_free_address()?,
};
let from_str = from.to_string();
let cli_bin = cli_bin()?;
if let Some(enrollment_ticket_hex) = enrollment_ticket_hex {
Expand Down Expand Up @@ -312,16 +325,19 @@ async fn create_inlet(inlet_data: &InletDataFromInvitation) -> crate::Result<Soc

#[derive(Debug)]
struct InletDataFromInvitation {
pub enabled: bool,
pub local_node_name: String,
pub service_name: String,
pub service_route: String,
pub enrollment_ticket_hex: Option<String>,
pub socket_addr: Option<SocketAddr>,
}

impl InletDataFromInvitation {
pub fn new(
cli_state: &CliState,
invitation: &InvitationWithAccess,
inlets: &HashMap<String, TcpInlet>,
) -> crate::Result<Option<Self>> {
match &invitation.service_access_details {
Some(d) => {
Expand Down Expand Up @@ -359,11 +375,17 @@ impl InletDataFromInvitation {
*RELAY_NAME
);

let inlet = inlets.get(&invitation.invitation.id);
let enabled = inlet.map(|i| i.enabled).unwrap_or(true);
let socket_addr = inlet.map(|i| i.socket_addr);

Ok(Some(Self {
enabled,
local_node_name,
service_name,
service_route,
enrollment_ticket_hex,
socket_addr,
}))
} else {
warn!(?invitation, "No project data found in enrollment ticket");
Expand All @@ -380,3 +402,80 @@ impl InletDataFromInvitation {
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use ockam::identity::OneTimeCode;
use ockam_api::cloud::share::{
ReceivedInvitation, RoleInShare, ServiceAccessDetails, ShareScope,
};
use ockam_api::config::lookup::ProjectLookup;
use ockam_api::identity::EnrollmentTicket;

#[test]
fn test_inlet_data_from_invitation() {
let cli_state = CliState::test().unwrap();
let mut inlets = HashMap::new();
let mut invitation = InvitationWithAccess {
invitation: ReceivedInvitation {
id: "invitation_id".to_string(),
expires_at: "2020-09-12T15:07:14.00".to_string(),
grant_role: RoleInShare::Admin,
owner_email: "owner_email".to_string(),
scope: ShareScope::Project,
target_id: "target_id".to_string(),
},
service_access_details: None,
};

// InletDataFromInvitation will be none because `service_access_details` is none
assert!(
InletDataFromInvitation::new(&cli_state, &invitation, &inlets)
.unwrap()
.is_none()
);

invitation.service_access_details = Some(ServiceAccessDetails {
project_identity: "project_identity".to_string(),
project_route: "project_route".to_string(),
project_authority_identity: "project_authority_identity".to_string(),
project_authority_route: "project_authority_route".to_string(),
shared_node_identity: "shared_node_identity".to_string(),
shared_node_route: "shared_node_route".to_string(),
enrollment_ticket: EnrollmentTicket::new(
OneTimeCode::new(),
Some(ProjectLookup {
node_route: None,
id: "project_identity".to_string(),
name: "project_name".to_string(),
identity_id: None,
authority: None,
okta: None,
}),
None,
)
.hex_encoded()
.unwrap(),
});

// Validate the inlet data, with no prior inlet data
let inlet_data = InletDataFromInvitation::new(&cli_state, &invitation, &inlets)
.unwrap()
.unwrap();
assert!(inlet_data.socket_addr.is_none());

// Validate the inlet data, with prior inlet data
inlets.insert(
"invitation_id".to_string(),
TcpInlet {
socket_addr: "127.0.0.1:1000".parse().unwrap(),
enabled: true,
},
);
let inlet_data = InletDataFromInvitation::new(&cli_state, &invitation, &inlets)
.unwrap()
.unwrap();
assert!(inlet_data.socket_addr.is_some());
}
}
66 changes: 65 additions & 1 deletion implementations/rust/ockam/ockam_app/src/invitations/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,71 @@ pub struct AcceptedInvitations {

/// Inlets for accepted invitations, keyed by invitation id.
#[serde(default)]
pub(crate) inlets: HashMap<String, SocketAddr>,
pub(crate) inlets: HashMap<String, TcpInlet>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TcpInlet {
pub(crate) socket_addr: SocketAddr,
pub(crate) enabled: bool,
}

impl TcpInlet {
pub fn new(socket_addr: SocketAddr) -> Self {
Self {
socket_addr,
enabled: true,
}
}
}

pub(crate) type SyncInvitationsState = Arc<RwLock<InvitationState>>;

#[cfg(test)]
mod tests {
use super::*;
use ockam_api::cloud::share::{RoleInShare, ShareScope};

#[test]
fn test_replace_by() {
let mut state = InvitationState::default();
assert!(state.sent.is_empty());
assert!(state.received.is_empty());
assert!(state.accepted.invitations.is_empty());
let list = InvitationList {
sent: Some(vec![SentInvitation {
id: "id".to_string(),
expires_at: "expires_at".to_string(),
grant_role: RoleInShare::Admin,
owner_id: 0,
recipient_email: "".to_string(),
remaining_uses: 0,
scope: ShareScope::Project,
target_id: "target_id".to_string(),
}]),
received: Some(vec![ReceivedInvitation {
id: "id".to_string(),
expires_at: "expires_at".to_string(),
grant_role: RoleInShare::Admin,
owner_email: "owner_email".to_string(),
scope: ShareScope::Project,
target_id: "target_id".to_string(),
}]),
accepted: Some(vec![InvitationWithAccess {
invitation: ReceivedInvitation {
id: "id".to_string(),
expires_at: "expires_at".to_string(),
grant_role: RoleInShare::Admin,
owner_email: "owner_email".to_string(),
scope: ShareScope::Project,
target_id: "target_id".to_string(),
},
service_access_details: None,
}]),
};
state.replace_by(list);
assert_eq!(state.sent.len(), 1);
assert_eq!(state.received.len(), 1);
assert_eq!(state.accepted.invitations.len(), 1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,13 @@ fn add_accepted_menus<R: Runtime>(
let mut submenu_builder = SubmenuBuilder::new(app_handle, owner_email);
submenu_builder = invitations
.into_iter()
.map(|(invitation_id, access_details, inlet_socket_addr)| {
accepted_invite_menu(app_handle, invitation_id, access_details, inlet_socket_addr)
.map(|(invitation_id, access_details, inlet)| {
accepted_invite_menu(
app_handle,
invitation_id,
access_details,
inlet.map(|i| &i.socket_addr),
)
})
.fold(submenu_builder, |menu, submenu| menu.item(&submenu));
submenus.push(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub struct CreateCommand {
at: Option<String>,

/// Address on which to accept tcp connections.
#[arg(long, display_order = 900, id = "SOCKET_ADDRESS", default_value_t = default_from_addr(), value_parser = socket_addr_parser)]
#[arg(long, display_order = 900, id = "SOCKET_ADDRESS", hide_default_value = true, default_value_t = default_from_addr(), value_parser = socket_addr_parser)]
from: SocketAddr,

/// Route to a tcp outlet.
Expand Down

0 comments on commit 74b9623

Please sign in to comment.