diff --git a/mm2src/mm2_main/src/lp_swap.rs b/mm2src/mm2_main/src/lp_swap.rs index 5b073b8dac..b4eac43929 100644 --- a/mm2src/mm2_main/src/lp_swap.rs +++ b/mm2src/mm2_main/src/lp_swap.rs @@ -87,7 +87,7 @@ use std::convert::TryFrom; use std::num::NonZeroUsize; use std::path::PathBuf; use std::str::FromStr; -use std::sync::{Arc, Mutex, Weak}; +use std::sync::{Arc, Mutex}; use std::time::Duration; use uuid::Uuid; @@ -529,8 +529,11 @@ struct LockedAmountInfo { locked_amount: LockedAmount, } +/// A running swap is the swap accompanied by the abort handle of the thread the swap is running on. +type RunningSwap = (Arc, AbortOnDropHandle); + struct SwapsContext { - running_swaps: Mutex, AbortOnDropHandle)>>, + running_swaps: Mutex>, active_swaps_v2_infos: Mutex>, banned_pubkeys: Mutex>, swap_msgs: Mutex>, @@ -546,7 +549,7 @@ impl SwapsContext { fn from_ctx(ctx: &MmArc) -> Result, String> { Ok(try_s!(from_ctx(&ctx.swaps_ctx, move || { Ok(SwapsContext { - running_swaps: Mutex::new(vec![]), + running_swaps: Mutex::new(HashMap::new()), active_swaps_v2_infos: Mutex::new(HashMap::new()), banned_pubkeys: Mutex::new(HashMap::new()), swap_msgs: Mutex::new(HashMap::new()), @@ -631,11 +634,9 @@ pub fn get_locked_amount(ctx: &MmArc, coin: &str) -> MmNumber { let swap_ctx = SwapsContext::from_ctx(ctx).unwrap(); let swap_lock = swap_ctx.running_swaps.lock().unwrap(); - let mut locked = swap_lock - .iter() - .filter_map(|(swap, _)| swap.upgrade()) - .flat_map(|swap| swap.locked_amount()) - .fold(MmNumber::from(0), |mut total_amount, locked| { + let mut locked = swap_lock.values().flat_map(|(swap, _)| swap.locked_amount()).fold( + MmNumber::from(0), + |mut total_amount, locked| { if locked.coin == coin { total_amount += locked.amount; } @@ -645,7 +646,8 @@ pub fn get_locked_amount(ctx: &MmArc, coin: &str) -> MmNumber { } } total_amount - }); + }, + ); drop(swap_lock); let locked_amounts = swap_ctx.locked_amounts.lock().unwrap(); @@ -669,11 +671,8 @@ pub fn get_locked_amount(ctx: &MmArc, coin: &str) -> MmNumber { /// Get number of currently running swaps pub fn running_swaps_num(ctx: &MmArc) -> u64 { let swap_ctx = SwapsContext::from_ctx(ctx).unwrap(); - let swaps = swap_ctx.running_swaps.lock().unwrap(); - swaps.iter().fold(0, |total, (swap, _)| match swap.upgrade() { - Some(_) => total + 1, - None => total, - }) + let count = swap_ctx.running_swaps.lock().unwrap().len(); + count as u64 } /// Get total amount of selected coin locked by all currently ongoing swaps except the one with selected uuid @@ -682,10 +681,9 @@ fn get_locked_amount_by_other_swaps(ctx: &MmArc, except_uuid: &Uuid, coin: &str) let swap_lock = swap_ctx.running_swaps.lock().unwrap(); swap_lock - .iter() - .filter_map(|(swap, _)| swap.upgrade()) - .filter(|swap| swap.uuid() != except_uuid) - .flat_map(|swap| swap.locked_amount()) + .values() + .filter(|(swap, _)| swap.uuid() != except_uuid) + .flat_map(|(swap, _)| swap.locked_amount()) .fold(MmNumber::from(0), |mut total_amount, locked| { if locked.coin == coin { total_amount += locked.amount; @@ -703,11 +701,9 @@ pub fn active_swaps_using_coins(ctx: &MmArc, coins: &HashSet) -> Result< let swap_ctx = try_s!(SwapsContext::from_ctx(ctx)); let swaps = try_s!(swap_ctx.running_swaps.lock()); let mut uuids = vec![]; - for (swap, _) in swaps.iter() { - if let Some(swap) = swap.upgrade() { - if coins.contains(&swap.maker_coin().to_string()) || coins.contains(&swap.taker_coin().to_string()) { - uuids.push(*swap.uuid()) - } + for (swap, _) in swaps.values() { + if coins.contains(&swap.maker_coin().to_string()) || coins.contains(&swap.taker_coin().to_string()) { + uuids.push(*swap.uuid()) } } drop(swaps); @@ -723,15 +719,13 @@ pub fn active_swaps_using_coins(ctx: &MmArc, coins: &HashSet) -> Result< pub fn active_swaps(ctx: &MmArc) -> Result, String> { let swap_ctx = try_s!(SwapsContext::from_ctx(ctx)); - let swaps = swap_ctx.running_swaps.lock().unwrap(); - let mut uuids = vec![]; - for (swap, _) in swaps.iter() { - if let Some(swap) = swap.upgrade() { - uuids.push((*swap.uuid(), LEGACY_SWAP_TYPE)) - } - } - - drop(swaps); + let mut uuids: Vec<_> = swap_ctx + .running_swaps + .lock() + .unwrap() + .keys() + .map(|uuid| (*uuid, LEGACY_SWAP_TYPE)) + .collect(); let swaps_v2 = swap_ctx.active_swaps_v2_infos.lock().unwrap(); uuids.extend(swaps_v2.iter().map(|(uuid, info)| (*uuid, info.swap_type))); diff --git a/mm2src/mm2_main/src/lp_swap/maker_swap.rs b/mm2src/mm2_main/src/lp_swap/maker_swap.rs index 5d0866c69b..bd3b3702da 100644 --- a/mm2src/mm2_main/src/lp_swap/maker_swap.rs +++ b/mm2src/mm2_main/src/lp_swap/maker_swap.rs @@ -2090,10 +2090,10 @@ pub async fn run_maker_swap(swap: RunMakerSwapInput, ctx: MmArc) { }; } let running_swap = Arc::new(swap); - let weak_ref = Arc::downgrade(&running_swap); let swap_ctx = SwapsContext::from_ctx(&ctx).unwrap(); swap_ctx.init_msg_store(running_swap.uuid, running_swap.taker); - let mut swap_fut = Box::pin( + let mut swap_fut = Box::pin({ + let running_swap = running_swap.clone(); async move { let mut events; loop { @@ -2150,8 +2150,8 @@ pub async fn run_maker_swap(swap: RunMakerSwapInput, ctx: MmArc) { } } } - .fuse(), - ); + .fuse() + }); // Run the swap in an abortable task and wait for it to finish. let (swap_ended_notifier, swap_ended_notification) = oneshot::channel(); let abortable_swap = spawn_abortable(async move { @@ -2163,9 +2163,15 @@ pub async fn run_maker_swap(swap: RunMakerSwapInput, ctx: MmArc) { error!("Swap listener stopped listening!"); } }); - swap_ctx.running_swaps.lock().unwrap().push((weak_ref, abortable_swap)); + let uuid = running_swap.uuid; + swap_ctx + .running_swaps + .lock() + .unwrap() + .insert(uuid, (running_swap, abortable_swap)); // Halt this function until the swap has finished (or interrupted, i.e. aborted/panic). swap_ended_notification.await.error_log_with_msg("Swap interrupted!"); + swap_ctx.running_swaps.lock().unwrap().remove(&uuid); } pub struct MakerSwapPreparedParams { diff --git a/mm2src/mm2_main/src/lp_swap/swap_v2_rpcs.rs b/mm2src/mm2_main/src/lp_swap/swap_v2_rpcs.rs index 95edd534bb..1dac93f78d 100644 --- a/mm2src/mm2_main/src/lp_swap/swap_v2_rpcs.rs +++ b/mm2src/mm2_main/src/lp_swap/swap_v2_rpcs.rs @@ -532,11 +532,10 @@ pub(crate) struct StopSwapResponse { pub(crate) async fn stop_swap_rpc(ctx: MmArc, req: StopSwapRequest) -> MmResult { let swap_ctx = SwapsContext::from_ctx(&ctx).map_err(StopSwapErr::Internal)?; - let mut running_swaps = swap_ctx.running_swaps.lock().unwrap(); - let Some(position) = running_swaps.iter().position(|(swap, _)| swap.upgrade().map_or(true, |swap| swap.uuid() == &req.uuid)) else { + // By just removing the swap's abort handle from the running swaps map, the swap will terminate. + if swap_ctx.running_swaps.lock().unwrap().remove(&req.uuid).is_none() { return MmError::err(StopSwapErr::NotRunning); - }; - let (_swap, _abort_handle) = running_swaps.swap_remove(position); + } Ok(StopSwapResponse { result: "Success".to_string(), }) @@ -582,12 +581,8 @@ pub(crate) async fn kickstart_swap_rpc( // up with the same swap being kickstarted twice, but we have filesystem swap locks for that. This check is // rather for convenience. let swap_ctx = SwapsContext::from_ctx(&ctx).map_err(KickStartSwapErr::Internal)?; - for (swap, _) in swap_ctx.running_swaps.lock().unwrap().iter() { - if let Some(swap) = swap.upgrade() { - if swap.uuid() == &req.uuid { - return MmError::err(KickStartSwapErr::AlreadyRunning); - } - } + if swap_ctx.running_swaps.lock().unwrap().contains_key(&req.uuid) { + return MmError::err(KickStartSwapErr::AlreadyRunning); } // Load the swap from the DB. let swap = match SavedSwap::load_my_swap_from_db(&ctx, req.uuid).await { @@ -647,7 +642,7 @@ pub(crate) async fn kickstart_swap_rpc( ))); }, }; - // Kickstart the swap. A new aborthandle will show up shortly for the swap. + // Kickstart the swap. A new abort handle will show up shortly for the swap. match swap { SavedSwap::Maker(saved_swap) => ctx.spawner().spawn(run_maker_swap( RunMakerSwapInput::KickStart { diff --git a/mm2src/mm2_main/src/lp_swap/taker_swap.rs b/mm2src/mm2_main/src/lp_swap/taker_swap.rs index 609021f5fa..5c16eb6e9a 100644 --- a/mm2src/mm2_main/src/lp_swap/taker_swap.rs +++ b/mm2src/mm2_main/src/lp_swap/taker_swap.rs @@ -463,10 +463,10 @@ pub async fn run_taker_swap(swap: RunTakerSwapInput, ctx: MmArc) { let uuid = swap.uuid.to_string(); let to_broadcast = !(swap.maker_coin.is_privacy() || swap.taker_coin.is_privacy()); let running_swap = Arc::new(swap); - let weak_ref = Arc::downgrade(&running_swap); let swap_ctx = SwapsContext::from_ctx(&ctx).unwrap(); swap_ctx.init_msg_store(running_swap.uuid, running_swap.maker); - let mut swap_fut = Box::pin( + let mut swap_fut = Box::pin({ + let running_swap = running_swap.clone(); async move { let mut events; loop { @@ -516,8 +516,8 @@ pub async fn run_taker_swap(swap: RunTakerSwapInput, ctx: MmArc) { } } } - .fuse(), - ); + .fuse() + }); // Run the swap in an abortable task and wait for it to finish. let (swap_ended_notifier, swap_ended_notification) = oneshot::channel(); let abortable_swap = spawn_abortable(async move { @@ -529,9 +529,15 @@ pub async fn run_taker_swap(swap: RunTakerSwapInput, ctx: MmArc) { error!("Swap listener stopped listening!"); } }); - swap_ctx.running_swaps.lock().unwrap().push((weak_ref, abortable_swap)); + let uuid = running_swap.uuid; + swap_ctx + .running_swaps + .lock() + .unwrap() + .insert(uuid, (running_swap, abortable_swap)); // Halt this function until the swap has finished (or interrupted, i.e. aborted/panic). swap_ended_notification.await.error_log_with_msg("Swap interrupted!"); + swap_ctx.running_swaps.lock().unwrap().remove(&uuid); } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] @@ -3217,10 +3223,13 @@ mod taker_swap_tests { .unwrap(); let swaps_ctx = SwapsContext::from_ctx(&ctx).unwrap(); let arc = Arc::new(swap); - let weak_ref = Arc::downgrade(&arc); // Create a dummy abort handle as if it was a running swap. let abortable_swap = spawn_abortable(async move {}); - swaps_ctx.running_swaps.lock().unwrap().push((weak_ref, abortable_swap)); + swaps_ctx + .running_swaps + .lock() + .unwrap() + .insert(arc.uuid, (arc, abortable_swap)); let actual = get_locked_amount(&ctx, "RICK"); assert_eq!(actual, MmNumber::from(0));