Skip to content

Commit

Permalink
Refactoring virus dist - expected to fail [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Jun 6, 2024
1 parent cf1c574 commit 5cc1bcd
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 288 deletions.
279 changes: 135 additions & 144 deletions epiworld.hpp

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion include/epiworld/epiworld.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#ifndef EPIWORLD_HPP
#define EPIWORLD_HPP


/* Versioning */
#define EPIWORLD_VERSION_MAJOR 0
#define EPIWORLD_VERSION_MINOR 3
Expand Down
5 changes: 0 additions & 5 deletions include/epiworld/model-bones.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,6 @@ class Model {
bool directed = false;

std::vector< VirusPtr<TSeq> > viruses = {};
std::vector< epiworld_double > prevalence_virus = {}; ///< Initial prevalence_virus of each virus
std::vector< bool > prevalence_virus_as_proportion = {};
std::vector< VirusToAgentFun<TSeq> > viruses_dist_funs = {};

std::vector< ToolPtr<TSeq> > tools = {};
std::vector< epiworld_double > prevalence_tool = {};
Expand Down Expand Up @@ -704,8 +701,6 @@ class Model {
///@}

const std::vector< VirusPtr<TSeq> > & get_viruses() const;
const std::vector< epiworld_double > & get_prevalence_virus() const;
const std::vector< bool > & get_prevalence_virus_as_proportion() const;
const std::vector< ToolPtr<TSeq> > & get_tools() const;
Virus<TSeq> & get_virus(size_t id);
Tool<TSeq> & get_tool(size_t id);
Expand Down
14 changes: 8 additions & 6 deletions include/epiworld/model-meat-print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ inline const Model<TSeq> & Model<TSeq>::print(bool lite) const
for (size_t i = 0u; i < n_viruses_model; ++i)
{


const auto & virus = viruses[i];
if ((n_viruses_model > 10) && (i >= 10))
{
printf_epiworld(" ...and %i more viruses...\n",
Expand All @@ -155,13 +157,13 @@ inline const Model<TSeq> & Model<TSeq>::print(bool lite) const
if (i < n_viruses_model)
{

if (prevalence_virus_as_proportion[i])
if (virus.get_prevalence_as_proportion())
{

printf_epiworld(
" - %s (baseline prevalence: %.2f%%)\n",
viruses[i]->get_name().c_str(),
prevalence_virus[i] * 100.00
virus.get_name().c_str(),
virus.get_prevalence() * 100.00
);

}
Expand All @@ -170,8 +172,8 @@ inline const Model<TSeq> & Model<TSeq>::print(bool lite) const

printf_epiworld(
" - %s (baseline prevalence: %i seeds)\n",
viruses[i]->get_name().c_str(),
static_cast<int>(prevalence_virus[i])
virus.get_name().c_str(),
static_cast<int>(virus.get_prevalence())
);

}
Expand All @@ -180,7 +182,7 @@ inline const Model<TSeq> & Model<TSeq>::print(bool lite) const

printf_epiworld(
" - %s (originated in the model...)\n",
viruses[i]->get_name().c_str()
virus.get_name().c_str()
);

}
Expand Down
108 changes: 2 additions & 106 deletions include/epiworld/model-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,6 @@ inline Model<TSeq>::Model(const Model<TSeq> & model) :
population_backup(model.population_backup),
directed(model.directed),
viruses(model.viruses),
prevalence_virus(model.prevalence_virus),
prevalence_virus_as_proportion(model.prevalence_virus_as_proportion),
viruses_dist_funs(model.viruses_dist_funs),
tools(model.tools),
prevalence_tool(model.prevalence_tool),
prevalence_tool_as_proportion(model.prevalence_tool_as_proportion),
Expand Down Expand Up @@ -452,9 +449,6 @@ inline Model<TSeq>::Model(Model<TSeq> && model) :
directed(std::move(model.directed)),
// Virus
viruses(std::move(model.viruses)),
prevalence_virus(std::move(model.prevalence_virus)),
prevalence_virus_as_proportion(std::move(model.prevalence_virus_as_proportion)),
viruses_dist_funs(std::move(model.viruses_dist_funs)),
// Tools
tools(std::move(model.tools)),
prevalence_tool(std::move(model.prevalence_tool)),
Expand Down Expand Up @@ -528,9 +522,6 @@ inline Model<TSeq> & Model<TSeq>::operator=(const Model<TSeq> & m)
directed = m.directed;

viruses = m.viruses;
prevalence_virus = m.prevalence_virus;
prevalence_virus_as_proportion = m.prevalence_virus_as_proportion;
viruses_dist_funs = m.viruses_dist_funs;

tools = m.tools;
prevalence_tool = m.prevalence_tool;
Expand Down Expand Up @@ -760,62 +751,10 @@ template<typename TSeq>
inline void Model<TSeq>::dist_virus()
{

// Starting first infection
int n = size();
std::vector< size_t > idx(n, 0u);
std::iota(idx.begin(), idx.end(), 0);
int n_left = idx.size();

for (size_t v = 0u; v < viruses.size(); ++v)
for (auto & v: viruses)
{

if (viruses_dist_funs[v])
{

viruses_dist_funs[v](*viruses[v], this);

} else {

// Picking how many
int nsampled;
if (prevalence_virus_as_proportion[v])
{
nsampled = static_cast<int>(std::floor(prevalence_virus[v] * size()));
}
else
{
nsampled = static_cast<int>(prevalence_virus[v]);
}

if (nsampled > static_cast<int>(size()))
throw std::range_error("There are only " + std::to_string(size()) +
" individuals in the population. Cannot add the virus to " + std::to_string(nsampled));


VirusPtr<TSeq> virus = viruses[v];

while (nsampled > 0)
{

int loc = static_cast<epiworld_fast_uint>(floor(runif() * (n_left--)));

Agent<TSeq> & agent = population[idx[loc]];

// Adding action
agent.set_virus(
virus,
const_cast<Model<TSeq> * >(this),
virus->state_init,
virus->queue_init
);

// Adjusting sample
nsampled--;
std::swap(idx[loc], idx[n_left]);

}

}
v.distribute(this);

// Apply the events
events_run();
Expand Down Expand Up @@ -1060,9 +999,6 @@ inline void Model<TSeq>::add_virus(Virus<TSeq> & v, epiworld_double preval)

// Adding new virus
viruses.push_back(std::make_shared< Virus<TSeq> >(v));
prevalence_virus.push_back(preval);
prevalence_virus_as_proportion.push_back(true);
viruses_dist_funs.push_back(nullptr);

}

Expand All @@ -1088,9 +1024,6 @@ inline void Model<TSeq>::add_virus_n(Virus<TSeq> & v, epiworld_fast_uint preval)

// Adding new virus
viruses.push_back(std::make_shared< Virus<TSeq> >(v));
prevalence_virus.push_back(preval);
prevalence_virus_as_proportion.push_back(false);
viruses_dist_funs.push_back(nullptr);

}

Expand All @@ -1117,9 +1050,6 @@ inline void Model<TSeq>::add_virus_fun(Virus<TSeq> & v, VirusToAgentFun<TSeq> fu

// Adding new virus
viruses.push_back(std::make_shared< Virus<TSeq> >(v));
prevalence_virus.push_back(0.0);
prevalence_virus_as_proportion.push_back(false);
viruses_dist_funs.push_back(fun);

}

Expand Down Expand Up @@ -1212,17 +1142,7 @@ inline void Model<TSeq>::rm_virus(size_t virus_pos)

// Flipping with the last one
std::swap(viruses[virus_pos], viruses[viruses.size() - 1]);
std::swap(viruses_dist_funs[virus_pos], viruses_dist_funs[viruses.size() - 1]);
std::swap(prevalence_virus[virus_pos], prevalence_virus[viruses.size() - 1]);
std::vector<bool>::swap(
prevalence_virus_as_proportion[virus_pos],
prevalence_virus_as_proportion[viruses.size() - 1]
);

viruses.pop_back();
viruses_dist_funs.pop_back();
prevalence_virus.pop_back();
prevalence_virus_as_proportion.pop_back();

return;

Expand Down Expand Up @@ -2606,18 +2526,6 @@ inline const std::vector< VirusPtr<TSeq> > & Model<TSeq>::get_viruses() const
return viruses;
}

template<typename TSeq>
inline const std::vector< epiworld_double > & Model<TSeq>::get_prevalence_virus() const
{
return prevalence_virus;
}

template<typename TSeq>
inline const std::vector< bool > & Model<TSeq>::get_prevalence_virus_as_proportion() const
{
return prevalence_virus_as_proportion;
}

template<typename TSeq>
const std::vector< ToolPtr<TSeq> > & Model<TSeq>::get_tools() const
{
Expand Down Expand Up @@ -2746,18 +2654,6 @@ inline bool Model<TSeq>::operator==(const Model<TSeq> & other) const
)

}

VECT_MATCH(
prevalence_virus,
other.prevalence_virus,
"virus prevalence don't match"
)

VECT_MATCH(
prevalence_virus_as_proportion,
other.prevalence_virus_as_proportion,
"virus prevalence as prop don't match"
)

// Tools -------------------------------------------------------------------
EPI_DEBUG_FAIL_AT_TRUE(
Expand Down
40 changes: 16 additions & 24 deletions include/epiworld/models/init-functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,13 @@ inline std::function<void(epiworld::Model<TSeq>*)> create_init_function_sir(

// Figuring out information about the viruses
double tot = 0.0;
const auto & vpreval = model->get_prevalence_virus();
const auto & vprop = model->get_prevalence_virus_as_proportion();
double n = static_cast<double>(model->size());
for (size_t i = 0u; i < model->get_n_viruses(); ++i)
for (const auto & virus: model->get_viruses())
{
if (vprop[i])
tot += vpreval[i];
if (virus.get_prevalence_as_proportion())
tot += virus.get_prevalence();
else
tot += vpreval[i] / n;
tot += virus.get_prevalence() / n;
}

// Putting the total into context
Expand Down Expand Up @@ -105,15 +103,13 @@ inline std::function<void(epiworld::Model<TSeq>*)> create_init_function_sird(

// Figuring out information about the viruses
double tot = 0.0;
const auto & vpreval = model->get_prevalence_virus();
const auto & vprop = model->get_prevalence_virus_as_proportion();
double n = static_cast<double>(model->size());
for (size_t i = 0u; i < model->get_n_viruses(); ++i)
for (const auto & virus: model->get_viruses())
{
if (vprop[i])
tot += vpreval[i];
if (virus.get_prevalence_as_proportion())
tot += virus.get_prevalence();
else
tot += vpreval[i] / n;
tot += virus.get_prevalence() / n;
}

// Putting the total into context
Expand Down Expand Up @@ -185,15 +181,13 @@ inline std::function<void(epiworld::Model<TSeq>*)> create_init_function_seir(

// Figuring out information about the viruses
double tot = 0.0;
const auto & vpreval = model->get_prevalence_virus();
const auto & vprop = model->get_prevalence_virus_as_proportion();
double n = static_cast<double>(model->size());
for (size_t i = 0u; i < model->get_n_viruses(); ++i)
for (const auto & virus: model->get_viruses())
{
if (vprop[i])
tot += vpreval[i];
if (virus.get_prevalence_as_proportion())
tot += virus.get_prevalence();
else
tot += vpreval[i] / n;
tot += virus.get_prevalence() / n;
}

// Putting the total into context
Expand Down Expand Up @@ -269,15 +263,13 @@ inline std::function<void(epiworld::Model<TSeq>*)> create_init_function_seird(

// Figuring out information about the viruses
double tot = 0.0;
const auto & vpreval = model->get_prevalence_virus();
const auto & vprop = model->get_prevalence_virus_as_proportion();
double n = static_cast<double>(model->size());
for (size_t i = 0u; i < model->get_n_viruses(); ++i)
for (const auto & virus: model->get_viruses())
{
if (vprop[i])
tot += vpreval[i];
if (virus.get_prevalence_as_proportion())
tot += virus.get_prevalence();
else
tot += vpreval[i] / n;
tot += virus.get_prevalence() / n;
}

// Putting the total into context
Expand Down
24 changes: 23 additions & 1 deletion include/epiworld/virus-bones.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,18 @@ class Virus {
epiworld_fast_int queue_post = -Queue<TSeq>::Everyone; ///< Change of state when removed from agent.
epiworld_fast_int queue_removed = -99; ///< Change of state when agent is removed

// Information about how distribution works
epiworld_double prevalence = 0.0;
bool prevalence_as_proportion = false;
VirusToAgentFun<TSeq> dist_fun = nullptr;

public:
Virus(std::string name = "unknown virus");
Virus(
std::string name = "unknown virus",
epiworld_double prevalence = 0.0,
bool prevalence_as_proportion = false,
VirusToAgentFun<TSeq> dist_fun = nullptr
);

void mutate(Model<TSeq> * model);
void set_mutation(MutFun<TSeq> fun);
Expand Down Expand Up @@ -156,6 +166,18 @@ class Virus {

void print() const;

/**
* @brief Get information about the prevalence of the virus
*/
///@{
epiworld_double get_prevalence() const;
void set_prevalence(epiworld_double prevalence);
bool get_prevalence_as_proportion() const;
void set_prevalence_as_proportion(bool prevalence_as_proportion);
void distribute(Model<TSeq> * model);
///@}


};

#endif
Loading

0 comments on commit 5cc1bcd

Please sign in to comment.