From 8be0819801c30d323221466c7d8bd387d59507df Mon Sep 17 00:00:00 2001 From: "George G. Vega Yon" Date: Tue, 4 Jun 2024 15:48:03 -0600 Subject: [PATCH] Fixing bug in mixing models and implements new entity-related methods (#18) * Fixing bug in mixing models and implements method for reading entity-agent ties from pointers * Agents getter * Adding print method for entity * Updating single-header --- epiworld.hpp | 107 +++++++++++++++++++++---- include/epiworld/entity-bones.hpp | 4 + include/epiworld/entity-meat.hpp | 18 +++++ include/epiworld/model-bones.hpp | 6 ++ include/epiworld/model-meat.hpp | 69 ++++++++++++---- include/epiworld/models/seirmixing.hpp | 5 +- include/epiworld/models/sirmixing.hpp | 5 +- 7 files changed, 182 insertions(+), 32 deletions(-) diff --git a/epiworld.hpp b/epiworld.hpp index 8c2e778e..bf42c046 100644 --- a/epiworld.hpp +++ b/epiworld.hpp @@ -6351,6 +6351,12 @@ class Model { const std::vector & entities_ids ); + void load_agents_entities_ties( + const int * agents_id, + const int * entities_id, + size_t n + ); + /** * @name Accessing population of the model * @@ -7955,7 +7961,7 @@ inline void Model::rm_entity(size_t entity_id) entity.reset(); // How should - if (entity_pos != (entities.size() - 1)) + if (entity_pos != (static_cast(entities.size()) - 1)) std::swap(entities[entity_pos], entities[entities.size() - 1]); entities.pop_back(); @@ -8101,46 +8107,87 @@ inline void Model::load_agents_entities_ties( const std::vector< int > & entities_ids ) { + // Checking the size if (agents_ids.size() != entities_ids.size()) throw std::length_error( - std::string("agents_ids (") + + std::string("The size of agents_ids (") + std::to_string(agents_ids.size()) + std::string(") and entities_ids (") + std::to_string(entities_ids.size()) + - std::string(") should match.") + std::string(") must be the same.") ); + return this->load_agents_entities_ties( + agents_ids.data(), + entities_ids.data(), + agents_ids.size() + ); + +} + +template +inline void Model::load_agents_entities_ties( + const int * agents_ids, + const int * entities_ids, + size_t n +) { + + auto get_agent = [agents_ids](int i) -> int { + return *(agents_ids + i); + }; - size_t n_entries = agents_ids.size(); - for (size_t i = 0u; i < n_entries; ++i) + auto get_entity = [entities_ids](int i) -> int { + return *(entities_ids + i); + }; + + for (size_t i = 0u; i < n; ++i) { - if (agents_ids[i] >= this->population.size()) + if (get_agent(i) < 0) + throw std::length_error( + std::string("agents_ids[") + + std::to_string(i) + + std::string("] = ") + + std::to_string(get_agent(i)) + + std::string(" is negative.") + ); + + if (get_entity(i) < 0) + throw std::length_error( + std::string("entities_ids[") + + std::to_string(i) + + std::string("] = ") + + std::to_string(get_entity(i)) + + std::string(" is negative.") + ); + + int pop_size = static_cast(this->population.size()); + if (get_agent(i) >= pop_size) throw std::length_error( std::string("agents_ids[") + std::to_string(i) + std::string("] = ") + - std::to_string(agents_ids[i]) + + std::to_string(get_agent(i)) + std::string(" is out of range (population size: ") + - std::to_string(this->population.size()) + + std::to_string(pop_size) + std::string(").") ); - - if (entities_ids[i] >= this->entities.size()) + int ent_size = static_cast(this->entities.size()); + if (get_entity(i) >= ent_size) throw std::length_error( std::string("entities_ids[") + std::to_string(i) + std::string("] = ") + - std::to_string(entities_ids[i]) + + std::to_string(get_entity(i)) + std::string(" is out of range (entities size: ") + - std::to_string(this->entities.size()) + + std::to_string(ent_size) + std::string(").") ); // Adding the entity to the agent - this->population[agents_ids[i]].add_entity( - this->entities[entities_ids[i]], + this->population[get_agent(i)].add_entity( + this->entities[get_entity(i)], nullptr /* Immediately add it to the agent */ ); @@ -12088,6 +12135,10 @@ class Entity { void distribute(); + std::vector< size_t > & get_agents(); + + void print() const; + }; @@ -12392,6 +12443,24 @@ inline void Entity::distribute() } +template +inline std::vector< size_t > & Entity::get_agents() +{ + return agents; +} + +template +inline void Entity::print() const +{ + + printf_epiworld( + "Entity '%s' (id %i) with %i agents.\n", + this->entity_name.c_str(), + static_cast(id), + static_cast(n_agents) + ); +} + #endif /*////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// @@ -19657,7 +19726,10 @@ inline void ModelSEIRMixing::update_infected() { if (a.get_state() == ModelSEIRMixing::INFECTED) - infected[a.get_entity(0u).get_id()].push_back(&a); + { + if (a.get_n_entities() > 0u) + infected[a.get_entity(0u).get_id()].push_back(&a); + } } @@ -20190,7 +20262,10 @@ inline void ModelSIRMixing::update_infected_list() { if (a.get_state() == ModelSIRMixing::INFECTED) - infected[a.get_entity(0u).get_id()].push_back(&a); + { + if (a.get_n_entities() > 0u) + infected[a.get_entity(0u).get_id()].push_back(&a); + } } diff --git a/include/epiworld/entity-bones.hpp b/include/epiworld/entity-bones.hpp index ba992700..fa956dab 100644 --- a/include/epiworld/entity-bones.hpp +++ b/include/epiworld/entity-bones.hpp @@ -103,6 +103,10 @@ class Entity { void distribute(); + std::vector< size_t > & get_agents(); + + void print() const; + }; diff --git a/include/epiworld/entity-meat.hpp b/include/epiworld/entity-meat.hpp index 68d6902c..df7c3064 100644 --- a/include/epiworld/entity-meat.hpp +++ b/include/epiworld/entity-meat.hpp @@ -280,4 +280,22 @@ inline void Entity::distribute() } +template +inline std::vector< size_t > & Entity::get_agents() +{ + return agents; +} + +template +inline void Entity::print() const +{ + + printf_epiworld( + "Entity '%s' (id %i) with %i agents.\n", + this->entity_name.c_str(), + static_cast(id), + static_cast(n_agents) + ); +} + #endif \ No newline at end of file diff --git a/include/epiworld/model-bones.hpp b/include/epiworld/model-bones.hpp index 033c0d42..d8b11694 100644 --- a/include/epiworld/model-bones.hpp +++ b/include/epiworld/model-bones.hpp @@ -372,6 +372,12 @@ class Model { const std::vector & entities_ids ); + void load_agents_entities_ties( + const int * agents_id, + const int * entities_id, + size_t n + ); + /** * @name Accessing population of the model * diff --git a/include/epiworld/model-meat.hpp b/include/epiworld/model-meat.hpp index 4f1caea3..523aaece 100644 --- a/include/epiworld/model-meat.hpp +++ b/include/epiworld/model-meat.hpp @@ -1213,7 +1213,7 @@ inline void Model::rm_entity(size_t entity_id) entity.reset(); // How should - if (entity_pos != (entities.size() - 1)) + if (entity_pos != (static_cast(entities.size()) - 1)) std::swap(entities[entity_pos], entities[entities.size() - 1]); entities.pop_back(); @@ -1359,46 +1359,87 @@ inline void Model::load_agents_entities_ties( const std::vector< int > & entities_ids ) { + // Checking the size if (agents_ids.size() != entities_ids.size()) throw std::length_error( - std::string("agents_ids (") + + std::string("The size of agents_ids (") + std::to_string(agents_ids.size()) + std::string(") and entities_ids (") + std::to_string(entities_ids.size()) + - std::string(") should match.") + std::string(") must be the same.") ); + return this->load_agents_entities_ties( + agents_ids.data(), + entities_ids.data(), + agents_ids.size() + ); + +} - size_t n_entries = agents_ids.size(); - for (size_t i = 0u; i < n_entries; ++i) +template +inline void Model::load_agents_entities_ties( + const int * agents_ids, + const int * entities_ids, + size_t n +) { + + auto get_agent = [agents_ids](int i) -> int { + return *(agents_ids + i); + }; + + auto get_entity = [entities_ids](int i) -> int { + return *(entities_ids + i); + }; + + for (size_t i = 0u; i < n; ++i) { - if (agents_ids[i] >= this->population.size()) + if (get_agent(i) < 0) + throw std::length_error( + std::string("agents_ids[") + + std::to_string(i) + + std::string("] = ") + + std::to_string(get_agent(i)) + + std::string(" is negative.") + ); + + if (get_entity(i) < 0) + throw std::length_error( + std::string("entities_ids[") + + std::to_string(i) + + std::string("] = ") + + std::to_string(get_entity(i)) + + std::string(" is negative.") + ); + + int pop_size = static_cast(this->population.size()); + if (get_agent(i) >= pop_size) throw std::length_error( std::string("agents_ids[") + std::to_string(i) + std::string("] = ") + - std::to_string(agents_ids[i]) + + std::to_string(get_agent(i)) + std::string(" is out of range (population size: ") + - std::to_string(this->population.size()) + + std::to_string(pop_size) + std::string(").") ); - - if (entities_ids[i] >= this->entities.size()) + int ent_size = static_cast(this->entities.size()); + if (get_entity(i) >= ent_size) throw std::length_error( std::string("entities_ids[") + std::to_string(i) + std::string("] = ") + - std::to_string(entities_ids[i]) + + std::to_string(get_entity(i)) + std::string(" is out of range (entities size: ") + - std::to_string(this->entities.size()) + + std::to_string(ent_size) + std::string(").") ); // Adding the entity to the agent - this->population[agents_ids[i]].add_entity( - this->entities[entities_ids[i]], + this->population[get_agent(i)].add_entity( + this->entities[get_entity(i)], nullptr /* Immediately add it to the agent */ ); diff --git a/include/epiworld/models/seirmixing.hpp b/include/epiworld/models/seirmixing.hpp index 91c86a93..50304a0e 100644 --- a/include/epiworld/models/seirmixing.hpp +++ b/include/epiworld/models/seirmixing.hpp @@ -164,7 +164,10 @@ inline void ModelSEIRMixing::update_infected() { if (a.get_state() == ModelSEIRMixing::INFECTED) - infected[a.get_entity(0u).get_id()].push_back(&a); + { + if (a.get_n_entities() > 0u) + infected[a.get_entity(0u).get_id()].push_back(&a); + } } diff --git a/include/epiworld/models/sirmixing.hpp b/include/epiworld/models/sirmixing.hpp index 74c2169a..d87a7eae 100644 --- a/include/epiworld/models/sirmixing.hpp +++ b/include/epiworld/models/sirmixing.hpp @@ -159,7 +159,10 @@ inline void ModelSIRMixing::update_infected_list() { if (a.get_state() == ModelSIRMixing::INFECTED) - infected[a.get_entity(0u).get_id()].push_back(&a); + { + if (a.get_n_entities() > 0u) + infected[a.get_entity(0u).get_id()].push_back(&a); + } }