diff --git a/examples/11-entities/main.cpp b/examples/11-entities/main.cpp index e6ab094d..da3a298f 100644 --- a/examples/11-entities/main.cpp +++ b/examples/11-entities/main.cpp @@ -7,7 +7,7 @@ int main() { epimodels::ModelSEIREntitiesConn model( "Flu", // std::string vname, - 10000, // epiworld_fast_uint n, + 100000, // epiworld_fast_uint n, 0.01,// epiworld_double prevalence, 4.0,// epiworld_double contact_rate, 0.1,// epiworld_double transmission_rate, diff --git a/include/epiworld/agentssample-bones.hpp b/include/epiworld/agentssample-bones.hpp index 3026e3aa..e0d1285a 100644 --- a/include/epiworld/agentssample-bones.hpp +++ b/include/epiworld/agentssample-bones.hpp @@ -191,9 +191,6 @@ inline AgentsSample::AgentsSample( agents = &agent_.sampled_agents; agents_n = &agent_.sampled_agents_n; - agents_left = &agent_.sampled_agents_left; - agents_left_n = &agent_.sampled_agents_left_n; - // Computing the cumulative sum of counts across entities size_t agents_in_entities = 0; Entities entities_a = agent->get_entities(); @@ -249,14 +246,17 @@ inline AgentsSample::AgentsSample( else agent_idx = entities_a[e][jth - cum_agents_count[e - 1]]->get_id(); - // Getting the state - size_t state = model->population[agent_idx].get_state(); // Checking if states was specified if (states.size()) { + + // Getting the state + size_t state = model->population[agent_idx].get_state(); + if (std::find(states.begin(), states.end(), state) != states.end()) continue; + } agents->operator[](i_obs++) = &(model->population[agent_idx]); diff --git a/include/epiworld/models/seirentitiesconnected.hpp b/include/epiworld/models/seirentitiesconnected.hpp index c3e77e8c..3dcd20f9 100644 --- a/include/epiworld/models/seirentitiesconnected.hpp +++ b/include/epiworld/models/seirentitiesconnected.hpp @@ -91,7 +91,110 @@ class ModelSEIREntitiesConn : public epiworld::Model }; -// Global event that moves agents between states +template +class GroupSampler { + +private: + + epiworld::Model & model; + const std::vector< double > & contact_matrix; ///< Contact matrix between groups + const std::vector< size_t > & group_sizes; ///< Sizes of the groups + std::vector< double > cumulate; ///< Cumulative sum of the contact matrix (row-major for faster access) + + /** + * @brief Get the index of the contact matrix + * + * The matrix is a vector stored in column-major order. + * + * @param i Index of the row + * @param j Index of the column + * @return Index of the contact matrix + */ + inline int idx(const int i, const int j, bool rowmajor = false) const + { + + if (rowmajor) + return i * group_sizes.size() + j; + + return j * group_sizes.size() + i; + + } + +public: + + GroupSampler( + epiworld::Model & model, + const std::vector< double > & contact_matrix, + const std::vector< size_t > & group_sizes + ): model(model), contact_matrix(contact_matrix), group_sizes(group_sizes) { + + this->cumulate.resize(contact_matrix.size()); + std::fill(cumulate.begin(), cumulate.end(), 0.0); + + // Cumulative sum + for (size_t j = 1; j < group_sizes.size(); ++j) + { + for (size_t i = 0; i < group_sizes.size(); ++i) + cumulate[idx(i, j, true)] += + cumulate[idx(i, j - 1, true)] + + contact_matrix[idx(i, j)]; + } + + }; + + int sample_1(const int origin_group); + + void sample_n( + std::vector< size_t > & sample, + const int origin_group, + const int nsamples + ); + +}; + +template +int GroupSampler::sample_1(const int origin_group) +{ + + // Random number + double r = model.runif(); + + // Finding the group + size_t j = 0; + while (r > cumulate[idx(origin_group, j, true)]) + ++j; + + // Adjusting the prob + r = r - (j == 0 ? 0.0 : cumulate[idx(origin_group, j - 1, true)]); + + int res = static_cast( + std::floor(r * group_sizes[j]) + ); + + // Making sure we are not picling outside of the group + if (res >= static_cast(group_sizes[j])) + res = static_cast(group_sizes[j]) - 1; + + return res; + +} + +template +void GroupSampler::sample_n( + std::vector< size_t > & sample, + const int origin_group, + const int nsamples +) +{ + + for (int i = 0; i < nsamples; ++i) + sample[i] = sample_1(origin_group); + + return; + +} + + template @@ -182,7 +285,7 @@ inline ModelSEIREntitiesConn::ModelSEIREntitiesConn( return; // Sampling from the agent's entities - auto sample = epiworld::AgentsSample(m, *p, ndraw, {}, true); + epiworld::AgentsSample sample(m, *p, ndraw, {}, true); // Drawing from the set int nviruses_tmp = 0;