Skip to content

Commit

Permalink
Add tests for plan queue and handle invalid cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonasher committed Jul 31, 2024
1 parent e5b5d20 commit 18672a1
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 54 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@ license = "Apache-2.0"
homepage = "https://github.com/CDCgov/ixa"

[dependencies]
derivative = "2.2.0"
14 changes: 7 additions & 7 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::plan::{PlanId, PlanQueue};

type Callback = dyn FnOnce(&mut Context);
pub struct Context {
plan_queue: PlanQueue,
plan_queue: PlanQueue<Box<Callback>>,
callback_queue: VecDeque<Box<Callback>>,
data_plugins: HashMap<TypeId, Box<dyn Any>>,
current_time: f64,
Expand All @@ -47,7 +47,7 @@ impl Context {

pub fn add_plan(&mut self, time: f64, callback: impl FnOnce(&mut Context) + 'static) -> PlanId {
// TODO: Handle invalid times (past, NAN, etc)
self.plan_queue.add_plan(time, callback)
self.plan_queue.add_plan(time, Box::new(callback))
}

pub fn cancel_plan(&mut self, id: PlanId) {
Expand Down Expand Up @@ -99,12 +99,12 @@ impl Context {
continue;
}

// There aren't any callbacks, so look at the first timed plan.
if let Some(timed_plan) = self.plan_queue.get_next_timed_plan() {
self.current_time = timed_plan.time;
(timed_plan.callback)(self);
// There aren't any callbacks, so look at the first plan.
if let Some(plan) = self.plan_queue.get_next_plan() {
self.current_time = plan.time;
(plan.data)(self);
} else {
// OK, there aren't any timed plans, so we're done.
// OK, there aren't any plans, so we're done.
break;
}
}
Expand Down
130 changes: 84 additions & 46 deletions src/plan.rs
Original file line number Diff line number Diff line change
@@ -1,90 +1,87 @@
use std::{
cmp::Ordering,
collections::{BinaryHeap, HashSet},
collections::{BinaryHeap, HashMap},
};

use derivative::Derivative;

use crate::context::Context;

pub struct PlanId {
id: u64,
id: usize,
}

#[derive(Derivative)]
#[derivative(Eq, PartialEq, Debug)]
pub struct TimedPlan {
pub struct Plan<T> {
pub time: f64,
plan_id: u64,
#[derivative(PartialEq = "ignore", Debug = "ignore")]
pub callback: Box<dyn FnOnce(&mut Context)>,
pub data: T,
}

impl Ord for TimedPlan {
#[derive(PartialEq, Debug)]
pub struct PlanRecord {
pub time: f64,
id: usize,
}

impl Eq for PlanRecord {}

impl PartialOrd for PlanRecord {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl Ord for PlanRecord {
fn cmp(&self, other: &Self) -> Ordering {
let time_ordering = self.time.partial_cmp(&other.time).unwrap().reverse();
if time_ordering == Ordering::Equal {
// Break time ties in order of plan id
self.plan_id.cmp(&other.plan_id).reverse()
self.id.cmp(&other.id).reverse()
} else {
time_ordering
}
}
}

impl PartialOrd for TimedPlan {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

#[derive(Debug)]
pub struct PlanQueue {
queue: BinaryHeap<TimedPlan>,
invalid_set: HashSet<u64>,
plan_counter: u64,
pub struct PlanQueue<T> {
queue: BinaryHeap<PlanRecord>,
data_map: HashMap<usize, T>,
plan_counter: usize,
}

impl Default for PlanQueue {
impl<T> Default for PlanQueue<T> {
fn default() -> Self {
Self::new()
}
}

impl PlanQueue {
pub fn new() -> PlanQueue {
impl<T> PlanQueue<T> {
pub fn new() -> PlanQueue<T> {
PlanQueue {
queue: BinaryHeap::new(),
invalid_set: HashSet::new(),
data_map: HashMap::new(),
plan_counter: 0,
}
}

pub fn add_plan(&mut self, time: f64, callback: impl FnOnce(&mut Context) + 'static) -> PlanId {
// Add plan to queue and increment counter
let plan_id = self.plan_counter;
self.queue.push(TimedPlan {
time,
plan_id,
callback: Box::new(callback),
});
pub fn add_plan(&mut self, time: f64, data: T) -> PlanId {
// Add plan to queue, store data, and increment counter
let id = self.plan_counter;
self.queue.push(PlanRecord { time, id });
self.data_map.insert(id, data);
self.plan_counter += 1;
PlanId { id: plan_id }
PlanId { id }
}

pub fn cancel_plan(&mut self, id: PlanId) {
self.invalid_set.insert(id.id);
self.data_map.remove(&id.id).expect("Plan does not exist");
}

pub fn get_next_timed_plan(&mut self) -> Option<TimedPlan> {
pub fn get_next_plan(&mut self) -> Option<Plan<T>> {
loop {
let next_timed_plan = self.queue.pop();
match next_timed_plan {
Some(timed_plan) => {
if self.invalid_set.contains(&timed_plan.plan_id) {
self.invalid_set.remove(&timed_plan.plan_id);
} else {
return Some(timed_plan);
match self.queue.pop() {
Some(plan_record) => {
if let Some(data) = self.data_map.remove(&plan_record.id) {
return Some(Plan {
time: plan_record.time,
data,
});
}
}
None => {
Expand All @@ -94,3 +91,44 @@ impl PlanQueue {
}
}
}

#[cfg(test)]
mod tests {
use super::PlanQueue;

#[test]
fn test_add_cancel() {
// Add some plans and cancel and make sure ordering occurs as expected
let mut plan_queue = PlanQueue::<usize>::new();
plan_queue.add_plan(1.0, 1);
plan_queue.add_plan(3.0, 3);
plan_queue.add_plan(3.0, 4);
let plan_to_cancel = plan_queue.add_plan(1.5, 0);
plan_queue.add_plan(2.0, 2);
plan_queue.cancel_plan(plan_to_cancel);

assert_eq!(plan_queue.get_next_plan().unwrap().time, 1.0);
assert_eq!(plan_queue.get_next_plan().unwrap().time, 2.0);

// Check tie handling
let next_plan = plan_queue.get_next_plan().unwrap();
assert_eq!(next_plan.time, 3.0);
assert_eq!(next_plan.data, 3);

let next_plan = plan_queue.get_next_plan().unwrap();
assert_eq!(next_plan.time, 3.0);
assert_eq!(next_plan.data, 4);

assert!(plan_queue.get_next_plan().is_none());
}

#[test]
#[should_panic]
fn test_invalid_cancel() {
// Cancel a plan that has already occured and make sure it panics
let mut plan_queue = PlanQueue::<()>::new();
let plan_to_cancel = plan_queue.add_plan(1.0, ());
plan_queue.get_next_plan();
plan_queue.cancel_plan(plan_to_cancel);
}
}

0 comments on commit 18672a1

Please sign in to comment.