Skip to content

Commit

Permalink
Adding function to sample from contact matrix more efficiently
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Apr 12, 2024
1 parent 67ce649 commit f5fe2cb
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 8 deletions.
2 changes: 1 addition & 1 deletion examples/11-entities/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ int main() {

epimodels::ModelSEIREntitiesConn model(
"Flu", // std::string vname,
10000, // epiworld_fast_uint n,
100000, // epiworld_fast_uint n,
0.01,// epiworld_double prevalence,
4.0,// epiworld_double contact_rate,
0.1,// epiworld_double transmission_rate,
Expand Down
10 changes: 5 additions & 5 deletions include/epiworld/agentssample-bones.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,6 @@ inline AgentsSample<TSeq>::AgentsSample(
agents = &agent_.sampled_agents;
agents_n = &agent_.sampled_agents_n;

agents_left = &agent_.sampled_agents_left;
agents_left_n = &agent_.sampled_agents_left_n;

// Computing the cumulative sum of counts across entities
size_t agents_in_entities = 0;
Entities<TSeq> entities_a = agent->get_entities();
Expand Down Expand Up @@ -249,14 +246,17 @@ inline AgentsSample<TSeq>::AgentsSample(
else
agent_idx = entities_a[e][jth - cum_agents_count[e - 1]]->get_id();

// Getting the state
size_t state = model->population[agent_idx].get_state();

// Checking if states was specified
if (states.size())
{

// Getting the state
size_t state = model->population[agent_idx].get_state();

if (std::find(states.begin(), states.end(), state) != states.end())
continue;

}

agents->operator[](i_obs++) = &(model->population[agent_idx]);
Expand Down
107 changes: 105 additions & 2 deletions include/epiworld/models/seirentitiesconnected.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,110 @@ class ModelSEIREntitiesConn : public epiworld::Model<TSeq>

};

// Global event that moves agents between states
template<typename TSeq>
class GroupSampler {

private:

epiworld::Model<TSeq> & model;
const std::vector< double > & contact_matrix; ///< Contact matrix between groups
const std::vector< size_t > & group_sizes; ///< Sizes of the groups
std::vector< double > cumulate; ///< Cumulative sum of the contact matrix (row-major for faster access)

/**
* @brief Get the index of the contact matrix
*
* The matrix is a vector stored in column-major order.
*
* @param i Index of the row
* @param j Index of the column
* @return Index of the contact matrix
*/
inline int idx(const int i, const int j, bool rowmajor = false) const
{

if (rowmajor)
return i * group_sizes.size() + j;

return j * group_sizes.size() + i;

}

public:

GroupSampler(
epiworld::Model<TSeq> & model,
const std::vector< double > & contact_matrix,
const std::vector< size_t > & group_sizes
): model(model), contact_matrix(contact_matrix), group_sizes(group_sizes) {

this->cumulate.resize(contact_matrix.size());
std::fill(cumulate.begin(), cumulate.end(), 0.0);

// Cumulative sum
for (size_t j = 1; j < group_sizes.size(); ++j)
{
for (size_t i = 0; i < group_sizes.size(); ++i)
cumulate[idx(i, j, true)] +=
cumulate[idx(i, j - 1, true)] +
contact_matrix[idx(i, j)];
}

};

int sample_1(const int origin_group);

void sample_n(
std::vector< size_t > & sample,
const int origin_group,
const int nsamples
);

};

template<typename TSeq>
int GroupSampler<TSeq>::sample_1(const int origin_group)
{

// Random number
double r = model.runif();

// Finding the group
size_t j = 0;
while (r > cumulate[idx(origin_group, j, true)])
++j;

// Adjusting the prob
r = r - (j == 0 ? 0.0 : cumulate[idx(origin_group, j - 1, true)]);

int res = static_cast<int>(
std::floor(r * group_sizes[j])
);

// Making sure we are not picling outside of the group
if (res >= static_cast<int>(group_sizes[j]))
res = static_cast<int>(group_sizes[j]) - 1;

return res;

}

template<typename TSeq>
void GroupSampler<TSeq>::sample_n(
std::vector< size_t > & sample,
const int origin_group,
const int nsamples
)
{

for (int i = 0; i < nsamples; ++i)
sample[i] = sample_1(origin_group);

return;

}




template<typename TSeq>
Expand Down Expand Up @@ -182,7 +285,7 @@ inline ModelSEIREntitiesConn<TSeq>::ModelSEIREntitiesConn(
return;

// Sampling from the agent's entities
auto sample = epiworld::AgentsSample<TSeq>(m, *p, ndraw, {}, true);
epiworld::AgentsSample<TSeq> sample(m, *p, ndraw, {}, true);

// Drawing from the set
int nviruses_tmp = 0;
Expand Down

0 comments on commit f5fe2cb

Please sign in to comment.