diff --git a/examples/births-deaths/infection_manager.rs b/examples/births-deaths/infection_manager.rs index 4657fe8..a00bd92 100644 --- a/examples/births-deaths/infection_manager.rs +++ b/examples/births-deaths/infection_manager.rs @@ -142,17 +142,19 @@ mod test { ); let population_size: usize = 10; - for _ in 0..population_size { + for index in 0..population_size { let person = context.create_new_person(0); context.add_plan(1.0, move |context| { context.set_person_property(person, InfectionStatusType, InfectionStatus::I); }); - } - context.add_plan(1.1, move |context| { - context.kill_person(context.get_person_id(0)); - }); + if index == 0 { + context.add_plan(1.1, move |context| { + context.kill_person(person); + }); + } + } context.execute(); assert_eq!(population_size, context.get_current_population()); diff --git a/examples/births-deaths/population_manager.rs b/examples/births-deaths/population_manager.rs index 6ffbb83..b5df1b3 100644 --- a/examples/births-deaths/population_manager.rs +++ b/examples/births-deaths/population_manager.rs @@ -129,13 +129,6 @@ pub trait ContextPopulationExt { fn get_current_group_population(&mut self, age_group: AgeGroupRisk) -> usize; fn sample_person(&mut self, age_group: AgeGroupRisk) -> Option; #[allow(dead_code)] - fn get_population_by_property( - &mut self, - property: T, - value: T::Value, - ) -> usize - where - ::Value: PartialEq; fn sample_person_by_property( &mut self, property: T, @@ -155,51 +148,17 @@ impl ContextPopulationExt for Context { } fn get_current_group_population(&mut self, age_group: AgeGroupRisk) -> usize { - let mut current_population = 0; - for i in 0..self.get_current_population() { - let person_id = self.get_person_id(i); - if self.get_person_property(person_id, Alive) - && self.get_person_property(person_id, AgeGroupFoi) == age_group - { - current_population += 1; - } - } - current_population + self.query_people_count(((Alive, true), (AgeGroupFoi, age_group))) } fn sample_person(&mut self, age_group: AgeGroupRisk) -> Option { - let mut people_vec = Vec::::new(); - for i in 0..self.get_current_population() { - let person_id = self.get_person_id(i); - if self.get_person_property(person_id, Alive) - && self.get_person_property(person_id, AgeGroupFoi) == age_group - { - people_vec.push(person_id); - } - } + let people_vec = self.query_people(((Alive, true), (AgeGroupFoi, age_group))); if people_vec.is_empty() { None } else { Some(people_vec[self.sample_range(PeopleRng, 0..people_vec.len())]) } } - fn get_population_by_property( - &mut self, - property: T, - value: T::Value, - ) -> usize - where - ::Value: PartialEq, - { - let mut population_counter = 0; - for i in 0..self.get_current_population() { - let person_id = self.get_person_id(i); - if self.get_person_property(person_id, property) == value { - population_counter += 1; - } - } - population_counter - } fn sample_person_by_property( &mut self, @@ -209,13 +168,7 @@ impl ContextPopulationExt for Context { where ::Value: PartialEq, { - let mut people_vec = Vec::::new(); - for i in 0..self.get_current_population() { - let person_id = self.get_person_id(i); - if self.get_person_property(person_id, property) == value { - people_vec.push(person_id); - } - } + let people_vec = self.query_people((property, value)); if people_vec.is_empty() { None } else { @@ -229,32 +182,37 @@ mod test { use super::*; use crate::parameters_loader::{FoiAgeGroups, ParametersValues}; use ixa::context::Context; + use std::cell::RefCell; + use std::rc::Rc; #[test] fn test_birth_death() { let mut context = Context::new(); - let person = context.create_new_person(10); - context.add_plan(380.0, |context| { - _ = context.create_new_person(0); + let person1 = context.create_new_person(10); + let person2 = Rc::>>::new(RefCell::new(None)); + let person2_clone = Rc::clone(&person2); + + context.add_plan(380.0, move |context| { + *person2_clone.borrow_mut() = Some(context.create_new_person(0)); }); context.add_plan(400.0, move |context| { - context.kill_person(person); + context.kill_person(person1); }); context.add_plan(390.0, |context| { - let pop = context.get_population_by_property(Alive, true); + let pop = context.query_people_count((Alive, true)); assert_eq!(pop, 2); }); context.add_plan(401.0, |context| { - let pop = context.get_population_by_property(Alive, true); + let pop = context.query_people_count((Alive, true)); assert_eq!(pop, 1); }); context.execute(); let population = context.get_current_population(); // Even if these people have died during simulation, we can still get their properties - let age_0 = context.get_person_property(context.get_person_id(0), Age); - let age_1 = context.get_person_property(context.get_person_id(1), Age); + let age_0 = context.get_person_property(person1, Age); + let age_1 = context.get_person_property((*person2).borrow().unwrap(), Age); assert_eq!(age_0, 10); assert_eq!(age_1, 0); @@ -321,17 +279,21 @@ mod test { AgeGroupRisk::General, AgeGroupRisk::OldAdult, ]; + let mut people = Vec::::new(); for age in &age_vec { - let _person = context.create_new_person(*age); + people.push(context.create_new_person(*age)); } - for p in 0..context.get_current_population() { - let person = context.get_person_id(p); + for i in 0..people.len() { + let person = people[i]; context.add_plan(365.0, move |context| { schedule_aging(context, person); }); - let age_group = age_groups[p]; - assert_eq!(age_group, context.get_person_property(person, AgeGroupFoi)); + let age_group = age_groups[i]; + assert_eq!( + age_group, + context.get_person_property(people[i], AgeGroupFoi) + ); } // Plan to check in 5 years @@ -342,10 +304,12 @@ mod test { AgeGroupRisk::OldAdult, ]; context.add_plan(years * 365.0, move |context| { - for p in 0..context.get_current_population() { - let person = context.get_person_id(p); - let age_group = future_age_groups[p]; - assert_eq!(age_group, context.get_person_property(person, AgeGroupFoi)); + for i in 0..people.len() { + let age_group = future_age_groups[i]; + assert_eq!( + age_group, + context.get_person_property(people[i], AgeGroupFoi) + ); } }); diff --git a/examples/parameter-loading/transmission_manager.rs b/examples/parameter-loading/transmission_manager.rs index fc994ca..3be015e 100644 --- a/examples/parameter-loading/transmission_manager.rs +++ b/examples/parameter-loading/transmission_manager.rs @@ -13,8 +13,7 @@ define_rng!(TransmissionRng); fn attempt_infection(context: &mut Context) { let population_size: usize = context.get_current_population(); - let person_to_infect = - context.get_person_id(context.sample_range(TransmissionRng, 0..population_size)); + let person_to_infect = context.sample_person(TransmissionRng).unwrap(); let person_status: InfectionStatus = context.get_person_property(person_to_infect, InfectionStatusType); let parameters = context diff --git a/examples/time-varying-infection/exposure_manager.rs b/examples/time-varying-infection/exposure_manager.rs index 478bdb5..585880d 100644 --- a/examples/time-varying-infection/exposure_manager.rs +++ b/examples/time-varying-infection/exposure_manager.rs @@ -103,14 +103,11 @@ mod test { .clone(); context.init_random(parameters.seed); init(&mut context); - context.add_person(()).unwrap(); + let person = context.add_person(()).unwrap(); context.execute(); - let person_status = - context.get_person_property(context.get_person_id(0), DiseaseStatusType); + let person_status = context.get_person_property(person, DiseaseStatusType); assert_eq!(person_status, DiseaseStatus::I); - let infection_time = context - .get_person_property(context.get_person_id(0), InfectionTime) - .unwrap(); + let infection_time = context.get_person_property(person, InfectionTime).unwrap(); assert_eq!(infection_time, context.get_current_time()); } diff --git a/examples/time-varying-infection/infection_manager.rs b/examples/time-varying-infection/infection_manager.rs index 74b37ee..aa18c10 100644 --- a/examples/time-varying-infection/infection_manager.rs +++ b/examples/time-varying-infection/infection_manager.rs @@ -15,22 +15,15 @@ fn recovery_cdf(context: &mut Context, time_spent_infected: f64) -> f64 { 1.0 - f64::exp(-time_spent_infected * n_eff_inv_infec(context)) } +#[allow(clippy::cast_precision_loss)] fn n_eff_inv_infec(context: &mut Context) -> f64 { let parameters = context .get_global_property_value(Parameters) .unwrap() .clone(); // get number of infected people - let mut n_infected = 0; - for usize_id in 0..context.get_current_population() { - if matches!( - context.get_person_property(context.get_person_id(usize_id), DiseaseStatusType), - DiseaseStatus::I - ) { - n_infected += 1; - } - } - (1.0 / parameters.infection_duration) / f64::from(n_infected) + let n_infected = context.query_people_count((DiseaseStatusType, DiseaseStatus::I)); + (1.0 / parameters.infection_duration) / (n_infected as f64) } fn evaluate_recovery( @@ -163,12 +156,15 @@ mod test { output_dir: ".".to_string(), output_file: ".".to_string(), }; + let mut people = Vec::new(); + context .set_global_property_value(Parameters, parameters.clone()) .unwrap(); context.init_random(parameters.seed); for _ in 0..parameters.population { let person_id = context.add_person(()).unwrap(); + people.push(person_id); context.set_person_property(person_id, DiseaseStatusType, DiseaseStatus::I); } assert_eq!( @@ -178,12 +174,8 @@ mod test { let time_spent_infected = 0.5; let cdf_value_many_infected = recovery_cdf(&mut context, time_spent_infected); // now make it so that all but 1 person becomes recovered - for person_id in 1..parameters.population { - context.set_person_property( - context.get_person_id(person_id), - DiseaseStatusType, - DiseaseStatus::R, - ); + for i in 1..parameters.population { + context.set_person_property(people[i], DiseaseStatusType, DiseaseStatus::R); } assert_eq!( n_eff_inv_infec(&mut context), diff --git a/src/people.rs b/src/people.rs index 7403b15..08cd4f7 100644 --- a/src/people.rs +++ b/src/people.rs @@ -71,8 +71,10 @@ use crate::{ context::{Context, IxaEvent}, define_data_plugin, error::IxaError, + random::{ContextRandomExt, RngId}, }; use ixa_derive::IxaEvent; +use rand::Rng; use seq_macro::seq; use serde::{Deserialize, Serialize}; use std::{ @@ -118,6 +120,13 @@ pub trait Query { fn get_query(&self) -> Vec<(TypeId, IndexValue)>; } +impl Query for () { + fn setup(_: &Context) {} + + fn get_query(&self) -> Vec<(TypeId, IndexValue)> { + vec![] + } +} // Implement the query version with one parameter. impl Query for (T1, T1::Value) { fn setup(context: &Context) { @@ -846,10 +855,6 @@ pub trait ContextPeopleExt { value: T::Value, ); - // Returns a PersonId for a usize - #[doc(hidden)] - fn get_person_id(&self, person_id: usize) -> PersonId; - /// Create an index for property `T`. /// /// If an index is available [`Context::query_people()`] will use it, so this is @@ -888,6 +893,17 @@ pub trait ContextPeopleExt { fn tabulate_person_properties(&self, tabulator: &T, print_fn: F) where F: Fn(&Context, &[String], usize); + + /// Randomly sample a person from the population. + /// + /// This is currently implemented by sampling in `0..current_population` + /// but in the future we might have holes where people were removed. + /// + /// # Errors + /// Returns `IxaError` if population is 0. + fn sample_person(&self, rng_id: R) -> Result + where + R::RngType: Rng; } fn process_indices( @@ -1061,14 +1077,6 @@ impl ContextPeopleExt for Context { } } - fn get_person_id(&self, person_id: usize) -> PersonId { - assert!( - person_id < self.get_current_population(), - "Person does not exist" - ); - PersonId { id: person_id } - } - fn index_property(&mut self, _property: T) { // Ensure that the data container exists { @@ -1195,6 +1203,17 @@ impl ContextPeopleExt for Context { &print_fn, ); } + + fn sample_person(&self, rng_id: R) -> Result + where + R::RngType: Rng, + { + if self.get_current_population() == 0 { + return Err(IxaError::IxaError(String::from("Empty population"))); + } + let result = self.sample_range(rng_id, 0..self.get_current_population()); + Ok(PersonId { id: result }) + } } trait ContextPeopleExtInternal { @@ -1601,14 +1620,6 @@ mod test { context.get_person_property(person_id, RiskCategoryType); } - #[test] - #[should_panic(expected = "Person does not exist")] - fn dont_return_person_id() { - let mut context = Context::new(); - context.add_person(()).unwrap(); - context.get_person_id(1); - } - #[test] fn get_person_property_returns_correct_value() { let mut context = Context::new(); @@ -2116,4 +2127,19 @@ mod test { &expected, ); } + + use crate::random::{define_rng, ContextRandomExt}; + + #[test] + fn test_sample_person() { + define_rng!(SampleRng1); + let mut context = Context::new(); + context.init_random(42); + assert!(matches!( + context.sample_person(SampleRng1), + Err(IxaError::IxaError(_)) + )); + let person = context.add_person(()).unwrap(); + assert_eq!(context.sample_person(SampleRng1).unwrap(), person); + } }