Skip to content

Commit

Permalink
Faster sampler for connected models (#13)
Browse files Browse the repository at this point in the history
* Replacing how connected model samples

* Updating the rate daily

* Adding feature to seir conn

* Adding faster sampler for SEIRD connected

* Missed updating infected during reset of SEIRDconn

* Fixing testing

* More generous marging for catch
  • Loading branch information
gvegayon authored Apr 25, 2024
1 parent 22ecc64 commit 9d139e1
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 115 deletions.
3 changes: 2 additions & 1 deletion .vscode/c_cpp_properties.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
"includePath": [
"${workspaceFolder}/**",
"/usr/include",
"/usr/local/include"
"/usr/local/include",
"include/epiworld"
],
"defines": [],
"compilerPath": "/usr/bin/gcc",
Expand Down
4 changes: 4 additions & 0 deletions include/epiworld/model-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1578,7 +1578,11 @@ inline void Model<TSeq>::run_multiple(
std::function<void(size_t,Model<TSeq>*)> fun,
bool reset,
bool verbose,
#ifdef _OPENMP
int nthreads
#else
int
#endif
)
{

Expand Down
101 changes: 71 additions & 30 deletions include/epiworld/models/seirconnected.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
template<typename TSeq = EPI_DEFAULT_TSEQ>
class ModelSEIRCONN : public epiworld::Model<TSeq>
{
private:
std::vector< epiworld::Agent<TSeq> * > infected;
void update_infected();

public:

static const int SUSCEPTIBLE = 0;
Expand Down Expand Up @@ -54,8 +58,35 @@ class ModelSEIRCONN : public epiworld::Model<TSeq>
std::vector< int > queue_ = {}
);

size_t get_n_infected() const { return infected.size(); }

};

template<typename TSeq>
inline void ModelSEIRCONN<TSeq>::update_infected()
{

infected.clear();
infected.reserve(this->size());

for (auto & p : this->get_agents())
{
if (p.get_state() == ModelSEIRCONN<TSeq>::INFECTED)
{
infected.push_back(&p);
}
}

Model<TSeq>::set_rand_binom(
this->get_n_infected(),
static_cast<double>(Model<TSeq>::par("Contact rate"))/
static_cast<double>(Model<TSeq>::size())
);

return;

}

template<typename TSeq>
inline ModelSEIRCONN<TSeq> & ModelSEIRCONN<TSeq>::run(
epiworld_fast_uint ndays,
Expand All @@ -74,13 +105,7 @@ inline void ModelSEIRCONN<TSeq>::reset()
{

Model<TSeq>::reset();

Model<TSeq>::set_rand_binom(
Model<TSeq>::size(),
static_cast<double>(
Model<TSeq>::par("Contact rate"))/
static_cast<double>(Model<TSeq>::size())
);
this->update_infected();

return;

Expand Down Expand Up @@ -133,13 +158,16 @@ inline ModelSEIRCONN<TSeq>::ModelSEIRCONN(
if (ndraw == 0)
return;

ModelSEIRCONN<TSeq> * model = dynamic_cast<ModelSEIRCONN<TSeq> *>(m);
size_t ninfected = model->get_n_infected();

// Drawing from the set
int nviruses_tmp = 0;
for (int i = 0; i < ndraw; ++i)
{
// Now selecting who is transmitting the disease
int which = static_cast<int>(
std::floor(m->size() * m->runif())
std::floor(ninfected * m->runif())
);

/* There is a bug in which runif() returns 1.0. It is rare, but
Expand All @@ -149,35 +177,32 @@ inline ModelSEIRCONN<TSeq>::ModelSEIRCONN(
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63176
*
*/
if (which == static_cast<int>(m->size()))
if (which == static_cast<int>(ninfected))
--which;

epiworld::Agent<TSeq> & neighbor = *model->infected[which];

// Can't sample itself
if (which == static_cast<int>(p->get_id()))
if (neighbor.get_id() == p->get_id())
continue;

// If the neighbor is infected, then proceed
auto & neighbor = m->get_agents()[which];
if (neighbor.get_state() == ModelSEIRCONN<TSeq>::INFECTED)
{
// The neighbor is infected by construction
auto & v = neighbor.get_virus();

auto & v = neighbor.get_virus();

#ifdef EPI_DEBUG
if (nviruses_tmp >= static_cast<int>(m->array_virus_tmp.size()))
throw std::logic_error("Trying to add an extra element to a temporal array outside of the range.");
#endif

/* And it is a function of susceptibility_reduction as well */
m->array_double_tmp[nviruses_tmp] =
(1.0 - p->get_susceptibility_reduction(v, m)) *
v->get_prob_infecting(m) *
(1.0 - neighbor.get_transmission_reduction(v, m))
;

m->array_virus_tmp[nviruses_tmp++] = &(*v);
#ifdef EPI_DEBUG
if (nviruses_tmp >= static_cast<int>(m->array_virus_tmp.size()))
throw std::logic_error("Trying to add an extra element to a temporal array outside of the range.");
#endif

/* And it is a function of susceptibility_reduction as well */
m->array_double_tmp[nviruses_tmp] =
(1.0 - p->get_susceptibility_reduction(v, m)) *
v->get_prob_infecting(m) *
(1.0 - neighbor.get_transmission_reduction(v, m))
;

m->array_virus_tmp[nviruses_tmp++] = &(*v);

}
}

// No virus to compute
Expand Down Expand Up @@ -279,6 +304,22 @@ inline ModelSEIRCONN<TSeq>::ModelSEIRCONN(
model.add_state("Infected", update_infected);
model.add_state("Recovered");

// Adding update function
epiworld::GlobalFun<TSeq> update = [](
epiworld::Model<TSeq> * m
) -> void
{

ModelSEIRCONN<TSeq> * model = dynamic_cast<ModelSEIRCONN<TSeq> *>(m);

model->update_infected();

return;

};

model.add_globalevent(update, "Update infected individuals");


// Preparing the virus -------------------------------------------
epiworld::Virus<TSeq> virus(vname);
Expand Down
103 changes: 71 additions & 32 deletions include/epiworld/models/seirdconnected.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
template<typename TSeq = EPI_DEFAULT_TSEQ>
class ModelSEIRDCONN : public epiworld::Model<TSeq>
{
private:
std::vector< epiworld::Agent<TSeq> * > infected;
void update_infected();

public:

static const int SUSCEPTIBLE = 0;
Expand All @@ -12,7 +16,6 @@ class ModelSEIRDCONN : public epiworld::Model<TSeq>
static const int REMOVED = 3;
static const int DECEASED = 4;


ModelSEIRDCONN() {};

ModelSEIRDCONN(
Expand Down Expand Up @@ -58,8 +61,36 @@ class ModelSEIRDCONN : public epiworld::Model<TSeq>
std::vector< int > queue_ = {}
);

size_t get_n_infected() const
{
return infected.size();
}

};

template<typename TSeq>
inline void ModelSEIRDCONN<TSeq>::update_infected()
{
infected.clear();
infected.reserve(this->size());

for (auto & p : this->get_agents())
{
if (p.get_state() == ModelSEIRDCONN<TSeq>::INFECTED)
{
infected.push_back(&p);
}
}

Model<TSeq>::set_rand_binom(
this->get_n_infected(),
static_cast<double>(Model<TSeq>::par("Contact rate"))/
static_cast<double>(Model<TSeq>::size())
);

return;
}

template<typename TSeq>
inline ModelSEIRDCONN<TSeq> & ModelSEIRDCONN<TSeq>::run(
epiworld_fast_uint ndays,
Expand All @@ -79,12 +110,7 @@ inline void ModelSEIRDCONN<TSeq>::reset()

Model<TSeq>::reset();

Model<TSeq>::set_rand_binom(
Model<TSeq>::size(),
static_cast<double>(
Model<TSeq>::par("Contact rate"))/
static_cast<double>(Model<TSeq>::size())
);
this->update_infected();

return;

Expand Down Expand Up @@ -139,13 +165,19 @@ inline ModelSEIRDCONN<TSeq>::ModelSEIRDCONN(
if (ndraw == 0)
return;

ModelSEIRDCONN<TSeq> * model = dynamic_cast<ModelSEIRDCONN<TSeq> *>(
m
);

size_t ninfected = model->get_n_infected();

// Drawing from the set
int nviruses_tmp = 0;
for (int i = 0; i < ndraw; ++i)
{
// Now selecting who is transmitting the disease
int which = static_cast<int>(
std::floor(m->size() * m->runif())
std::floor(ninfected * m->runif())
);

/* There is a bug in which runif() returns 1.0. It is rare, but
Expand All @@ -155,36 +187,31 @@ inline ModelSEIRDCONN<TSeq>::ModelSEIRDCONN(
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63176
*
*/
if (which == static_cast<int>(m->size()))
if (which == static_cast<int>(ninfected))
--which;

epiworld::Agent<TSeq> & neighbor = *model->infected[which];

// Can't sample itself
if (which == static_cast<int>(p->get_id()))
if (neighbor.get_id() == p->get_id())
continue;

// If the neighbor is infected, then proceed
auto & neighbor = m->get_agents()[which];
if (neighbor.get_state() == ModelSEIRDCONN<TSeq>::INFECTED)
{

const auto & v = neighbor.get_virus();


#ifdef EPI_DEBUG
if (nviruses_tmp >= static_cast<int>(m->array_virus_tmp.size()))
throw std::logic_error("Trying to add an extra element to a temporal array outside of the range.");
#endif

/* And it is a function of susceptibility_reduction as well */
m->array_double_tmp[nviruses_tmp] =
(1.0 - p->get_susceptibility_reduction(v, m)) *
v->get_prob_infecting(m) *
(1.0 - neighbor.get_transmission_reduction(v, m))
;

m->array_virus_tmp[nviruses_tmp++] = &(*v);
// All neighbors in this set are infected by construction
const auto & v = neighbor.get_virus();

#ifdef EPI_DEBUG
if (nviruses_tmp >= static_cast<int>(m->array_virus_tmp.size()))
throw std::logic_error("Trying to add an extra element to a temporal array outside of the range.");
#endif

}
/* And it is a function of susceptibility_reduction as well */
m->array_double_tmp[nviruses_tmp] =
(1.0 - p->get_susceptibility_reduction(v, m)) *
v->get_prob_infecting(m) *
(1.0 - neighbor.get_transmission_reduction(v, m))
;

m->array_virus_tmp[nviruses_tmp++] = &(*v);
}

// No virus to compute
Expand Down Expand Up @@ -301,6 +328,18 @@ inline ModelSEIRDCONN<TSeq>::ModelSEIRDCONN(
model.add_state("Deceased");


// Adding update function
epiworld::GlobalFun<TSeq> update = [](epiworld::Model<TSeq> * m) -> void
{
ModelSEIRDCONN<TSeq> * model = dynamic_cast<ModelSEIRDCONN<TSeq> *>(m);
model->update_infected();

return;
};

model.add_globalevent(update, "Update infected individuals");


// Preparing the virus -------------------------------------------
epiworld::Virus<TSeq> virus(vname);
virus.set_state(
Expand Down
Loading

0 comments on commit 9d139e1

Please sign in to comment.