diff --git a/Cargo.toml b/Cargo.toml index 6e0c8e10..be2f210f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,4 +8,3 @@ license = "Apache-2.0" homepage = "https://github.com/CDCgov/ixa" [dependencies] -derivative = "2.2.0" diff --git a/src/context.rs b/src/context.rs index e0dbc93c..ffbe03af 100644 --- a/src/context.rs +++ b/src/context.rs @@ -29,7 +29,7 @@ use crate::plan::{PlanId, PlanQueue}; type Callback = dyn FnOnce(&mut Context); pub struct Context { - plan_queue: PlanQueue, + plan_queue: PlanQueue>, callback_queue: VecDeque>, data_plugins: HashMap>, current_time: f64, @@ -47,7 +47,7 @@ impl Context { pub fn add_plan(&mut self, time: f64, callback: impl FnOnce(&mut Context) + 'static) -> PlanId { // TODO: Handle invalid times (past, NAN, etc) - self.plan_queue.add_plan(time, callback) + self.plan_queue.add_plan(time, Box::new(callback)) } pub fn cancel_plan(&mut self, id: PlanId) { @@ -99,12 +99,12 @@ impl Context { continue; } - // There aren't any callbacks, so look at the first timed plan. - if let Some(timed_plan) = self.plan_queue.get_next_timed_plan() { - self.current_time = timed_plan.time; - (timed_plan.callback)(self); + // There aren't any callbacks, so look at the first plan. + if let Some(plan) = self.plan_queue.get_next_plan() { + self.current_time = plan.time; + (plan.data)(self); } else { - // OK, there aren't any timed plans, so we're done. + // OK, there aren't any plans, so we're done. break; } } diff --git a/src/plan.rs b/src/plan.rs index 15c1bd30..532a8515 100644 --- a/src/plan.rs +++ b/src/plan.rs @@ -1,90 +1,87 @@ use std::{ cmp::Ordering, - collections::{BinaryHeap, HashSet}, + collections::{BinaryHeap, HashMap}, }; -use derivative::Derivative; - -use crate::context::Context; - pub struct PlanId { - id: u64, + id: usize, } -#[derive(Derivative)] -#[derivative(Eq, PartialEq, Debug)] -pub struct TimedPlan { +pub struct Plan { pub time: f64, - plan_id: u64, - #[derivative(PartialEq = "ignore", Debug = "ignore")] - pub callback: Box, + pub data: T, } -impl Ord for TimedPlan { +#[derive(PartialEq, Debug)] +pub struct PlanRecord { + pub time: f64, + id: usize, +} + +impl Eq for PlanRecord {} + +impl PartialOrd for PlanRecord { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PlanRecord { fn cmp(&self, other: &Self) -> Ordering { let time_ordering = self.time.partial_cmp(&other.time).unwrap().reverse(); if time_ordering == Ordering::Equal { // Break time ties in order of plan id - self.plan_id.cmp(&other.plan_id).reverse() + self.id.cmp(&other.id).reverse() } else { time_ordering } } } -impl PartialOrd for TimedPlan { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - #[derive(Debug)] -pub struct PlanQueue { - queue: BinaryHeap, - invalid_set: HashSet, - plan_counter: u64, +pub struct PlanQueue { + queue: BinaryHeap, + data_map: HashMap, + plan_counter: usize, } -impl Default for PlanQueue { +impl Default for PlanQueue { fn default() -> Self { Self::new() } } -impl PlanQueue { - pub fn new() -> PlanQueue { +impl PlanQueue { + pub fn new() -> PlanQueue { PlanQueue { queue: BinaryHeap::new(), - invalid_set: HashSet::new(), + data_map: HashMap::new(), plan_counter: 0, } } - pub fn add_plan(&mut self, time: f64, callback: impl FnOnce(&mut Context) + 'static) -> PlanId { - // Add plan to queue and increment counter - let plan_id = self.plan_counter; - self.queue.push(TimedPlan { - time, - plan_id, - callback: Box::new(callback), - }); + pub fn add_plan(&mut self, time: f64, data: T) -> PlanId { + // Add plan to queue, store data, and increment counter + let id = self.plan_counter; + self.queue.push(PlanRecord { time, id }); + self.data_map.insert(id, data); self.plan_counter += 1; - PlanId { id: plan_id } + PlanId { id } } pub fn cancel_plan(&mut self, id: PlanId) { - self.invalid_set.insert(id.id); + self.data_map.remove(&id.id).expect("Plan does not exist"); } - pub fn get_next_timed_plan(&mut self) -> Option { + pub fn get_next_plan(&mut self) -> Option> { loop { - let next_timed_plan = self.queue.pop(); - match next_timed_plan { - Some(timed_plan) => { - if self.invalid_set.contains(&timed_plan.plan_id) { - self.invalid_set.remove(&timed_plan.plan_id); - } else { - return Some(timed_plan); + match self.queue.pop() { + Some(plan_record) => { + if let Some(data) = self.data_map.remove(&plan_record.id) { + return Some(Plan { + time: plan_record.time, + data, + }); } } None => { @@ -94,3 +91,44 @@ impl PlanQueue { } } } + +#[cfg(test)] +mod tests { + use super::PlanQueue; + + #[test] + fn test_add_cancel() { + // Add some plans and cancel and make sure ordering occurs as expected + let mut plan_queue = PlanQueue::::new(); + plan_queue.add_plan(1.0, 1); + plan_queue.add_plan(3.0, 3); + plan_queue.add_plan(3.0, 4); + let plan_to_cancel = plan_queue.add_plan(1.5, 0); + plan_queue.add_plan(2.0, 2); + plan_queue.cancel_plan(plan_to_cancel); + + assert_eq!(plan_queue.get_next_plan().unwrap().time, 1.0); + assert_eq!(plan_queue.get_next_plan().unwrap().time, 2.0); + + // Check tie handling + let next_plan = plan_queue.get_next_plan().unwrap(); + assert_eq!(next_plan.time, 3.0); + assert_eq!(next_plan.data, 3); + + let next_plan = plan_queue.get_next_plan().unwrap(); + assert_eq!(next_plan.time, 3.0); + assert_eq!(next_plan.data, 4); + + assert!(plan_queue.get_next_plan().is_none()); + } + + #[test] + #[should_panic] + fn test_invalid_cancel() { + // Cancel a plan that has already occured and make sure it panics + let mut plan_queue = PlanQueue::<()>::new(); + let plan_to_cancel = plan_queue.add_plan(1.0, ()); + plan_queue.get_next_plan(); + plan_queue.cancel_plan(plan_to_cancel); + } +}