diff --git a/Cargo.toml b/Cargo.toml index 8ab8798..2c940b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,3 +9,5 @@ csv = "1.3.1" ixa = { git = "https://github.com/cdcgov/ixa", version = "0.0.1" } ixa-derive = { git = "https://github.com/cdcgov/ixa", version = "0.0.0" } serde = "1.0.215" +serde_derive = "1.0.215" +tempfile = "3.14.0" diff --git a/src/parameters_loader.rs b/src/parameters_loader.rs index c874a95..ad45453 100644 --- a/src/parameters_loader.rs +++ b/src/parameters_loader.rs @@ -9,7 +9,6 @@ use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ParametersValues { - pub population: usize, pub max_time: f64, pub seed: u64, pub r_0: f64, diff --git a/src/population_loader.rs b/src/population_loader.rs index 2dca1e3..c7b75c7 100644 --- a/src/population_loader.rs +++ b/src/population_loader.rs @@ -1,112 +1,110 @@ use crate::parameters_loader::Parameters; use ixa::{ context::Context, - define_derived_property, define_person_property, define_person_property_with_default, + define_person_property, define_person_property_with_default, error::IxaError, global_properties::ContextGlobalPropertiesExt, people::{ContextPeopleExt, PersonId}, }; use serde::Deserialize; +use std::path::PathBuf; #[derive(Deserialize, Debug)] #[allow(non_snake_case)] -pub struct PeopleRecord { +pub struct PeopleRecord<'a> { age: u8, - homeId: usize, + homeId: &'a [u8], } define_person_property!(Age, u8); define_person_property!(HomeId, usize); define_person_property_with_default!(Alive, bool, true); -#[allow(clippy::cast_possible_truncation)] -#[allow(clippy::cast_sign_loss)] -const CENSUS_MAX: usize = 1e15 as usize; - -#[allow(clippy::cast_possible_truncation)] -#[allow(clippy::cast_sign_loss)] -const CENSUS_MIN: usize = 1e14 as usize; - -define_derived_property!(CensusTract, usize, [HomeId], |home_id| { - if (CENSUS_MIN..CENSUS_MAX).contains(&home_id) { - home_id / 10_000 - } else { - 0 //Err(IxaError::IxaError(String::from("Census tract invalid from homeId"))) - } -}); +define_person_property!(CensusTract, usize); fn create_person_from_record( context: &mut Context, person_record: &PeopleRecord, ) -> Result { - let person_id = - context.add_person(((Age, person_record.age), (HomeId, person_record.homeId)))?; + let tract: usize = String::from_utf8(person_record.homeId[..11].to_owned()) + .expect("Home id should have 11 digits for tract + home id") + .parse() + .expect("Could not parse census tract"); + let home_id: usize = String::from_utf8(person_record.homeId.to_owned()) + .expect("Could not read home id") + .parse() + .expect("Could not read home id"); + + let person_id = context.add_person(( + (Age, person_record.age), + (HomeId, home_id), + (CensusTract, tract), + ))?; Ok(person_id) } +fn load_synth_population(context: &mut Context, synth_input_file: PathBuf) -> Result<(), IxaError> { + let mut reader = + csv::Reader::from_path(synth_input_file).expect("Failed to open file. No headers found."); + let mut raw_record = csv::ByteRecord::new(); + let headers = reader.byte_headers().unwrap().clone(); + + while reader.read_byte_record(&mut raw_record).unwrap() { + let record: PeopleRecord = raw_record + .deserialize(Some(&headers)) + .expect("Failed to parse record."); + create_person_from_record(context, &record)?; + } + Ok(()) +} + pub fn init(context: &mut Context) -> Result<(), IxaError> { let parameters = context .get_global_property_value(Parameters) .unwrap() .clone(); - let mut reader = - csv::Reader::from_path(parameters.synth_population_file).expect("Failed to open file."); - - for result in reader.deserialize() { - let record: PeopleRecord = result.expect("Failed to parse record."); - create_person_from_record(context, &record)?; - } + load_synth_population(context, PathBuf::from(parameters.synth_population_file))?; context.index_property(Age); context.index_property(CensusTract); Ok(()) } #[cfg(test)] - -mod tests { +mod test { use super::*; - use ixa::{context::Context, people::ContextPeopleExt, random::ContextRandomExt}; + use std::fs::File; + use std::io::Write; + use std::path::PathBuf; + use tempfile::tempdir; #[test] - #[allow(clippy::inconsistent_digit_grouping)] - fn test_create_person_from_record() { + fn check_synth_file_tract() { let mut context = Context::new(); - context.init_random(0); - let record = PeopleRecord { - age: 42, - homeId: 36_09_30_33102_0005, - }; - let person_id = create_person_from_record(&mut context, &record).unwrap(); - assert_eq!(context.get_person_property(person_id, Age), 42); - assert_eq!( - context.get_person_property(person_id, HomeId), - 36_09_30_33102_0005 - ); - assert!(context.get_person_property(person_id, Alive)); + let temp_dir = tempdir().unwrap(); + let path = PathBuf::from(&temp_dir.path()); + let persisted_file = path.join("synth_pop_test.csv"); + let mut file = File::create(persisted_file).unwrap(); + file.write_all( + b" +age,homeId +43,360930331020001 +42,360930331020002", + ) + .unwrap(); + let synth_file = path.join("synth_pop_test.csv"); + load_synth_population(&mut context, synth_file).unwrap(); + let age: u8 = 43; + let tract: usize = 36_093_033_102; + assert_eq!( - context.get_person_property(person_id, CensusTract), - 36_09_30_33102 + age, + context.get_person_property(context.get_person_id(0), Age) ); - } - - #[test] - #[allow(clippy::inconsistent_digit_grouping)] - fn test_create_person_from_record_invalid_census_tract() { - let mut context = Context::new(); - context.init_random(0); - let record = PeopleRecord { - age: 42, - homeId: 36_09_30_33102_0005_0005, - }; - let person_id = create_person_from_record(&mut context, &record).unwrap(); - assert_eq!(context.get_person_property(person_id, Age), 42); assert_eq!( - context.get_person_property(person_id, HomeId), - 36_09_30_33102_0005_0005 + tract, + context.get_person_property(context.get_person_id(0), CensusTract) ); - assert!(context.get_person_property(person_id, Alive)); - assert_eq!(context.get_person_property(person_id, CensusTract), 0); } }