Skip to content

Commit

Permalink
Reset entities properly
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Jun 11, 2024
1 parent 030a51d commit c22f77b
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 88 deletions.
244 changes: 187 additions & 57 deletions epiworld.hpp

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions include/epiworld/agent-bones.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ class Agent {
bool has_virus(epiworld_fast_uint t) const;
bool has_virus(std::string name) const;
bool has_virus(const Virus<TSeq> & v) const;
bool has_entity(epiworld_fast_uint t) const;
bool has_entity(std::string name) const;

void print(Model<TSeq> * model, bool compressed = false) const;

Expand Down
34 changes: 34 additions & 0 deletions include/epiworld/agent-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,10 @@ inline void Agent<TSeq>::reset()
this->tools.clear();
n_tools = 0u;

this->entities.clear();
this->entities_locations.clear();
this->n_entities = 0u;

this->state = 0u;
this->state_prev = 0u;

Expand Down Expand Up @@ -702,6 +706,30 @@ inline bool Agent<TSeq>::has_virus(const Virus<TSeq> & virus) const

}

template<typename TSeq>
inline bool Agent<TSeq>::has_entity(epiworld_fast_uint t) const
{

for (auto & entity : entities)
if (entity == t)
return true;

return false;

}

template<typename TSeq>
inline bool Agent<TSeq>::has_entity(std::string name) const
{

for (auto & entity : entities)
if (model->get_entity(entity).get_name() == name)
return true;

return false;

}

template<typename TSeq>
inline void Agent<TSeq>::print(
Model<TSeq> * model,
Expand Down Expand Up @@ -813,6 +841,9 @@ inline const Entities_const<TSeq> Agent<TSeq>::get_entities() const
template<typename TSeq>
inline const Entity<TSeq> & Agent<TSeq>::get_entity(size_t i) const
{
if (n_entities == 0)
throw std::range_error("Agent id " + std::to_string(id) + " has no entities.");

if (i >= n_entities)
throw std::range_error("Trying to get to an agent's entity outside of the range.");

Expand All @@ -822,6 +853,9 @@ inline const Entity<TSeq> & Agent<TSeq>::get_entity(size_t i) const
template<typename TSeq>
inline Entity<TSeq> & Agent<TSeq>::get_entity(size_t i)
{
if (n_entities == 0)
throw std::range_error("Agent id " + std::to_string(id) + " has no entities.");

if (i >= n_entities)
throw std::range_error("Trying to get to an agent's entity outside of the range.");

Expand Down
6 changes: 3 additions & 3 deletions include/epiworld/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ template<typename TSeq>
struct Event;

template<typename TSeq>
using ActionFun = std::function<void(Event<TSeq>&,Model<TSeq>*)>;
using EventFun = std::function<void(Event<TSeq>&,Model<TSeq>*)>;

/**
* @brief Decides how to distribute viruses at initialization
Expand Down Expand Up @@ -124,7 +124,7 @@ struct Event {
Entity<TSeq> * entity;
epiworld_fast_int new_state;
epiworld_fast_int queue;
ActionFun<TSeq> call;
EventFun<TSeq> call;
int idx_agent;
int idx_object;
public:
Expand All @@ -151,7 +151,7 @@ struct Event {
Entity<TSeq> * entity_,
epiworld_fast_int new_state_,
epiworld_fast_int queue_,
ActionFun<TSeq> call_,
EventFun<TSeq> call_,
int idx_agent_,
int idx_object_
) : agent(agent_), virus(virus_), tool(tool_), entity(entity_),
Expand Down
19 changes: 11 additions & 8 deletions include/epiworld/entity-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,9 @@ inline void Entity<TSeq>::reset()
sampled_agents_left.clear();
sampled_agents_left_n = 0u;

// Removing agents from entities
for (size_t i = 0u; i < n_agents; ++i)
this->rm_agent(i);
this->agents.clear();
this->n_agents = 0u;
this->agents_location.clear();

return;

Expand Down Expand Up @@ -358,7 +358,7 @@ inline void Entity<TSeq>::distribute()

int n_left = n;
std::iota(idx.begin(), idx.end(), 0);
for (int i = 0; i < n_to_assign; ++i)
while (n_to_assign > 0)
{
int loc = static_cast<epiworld_fast_uint>(
floor(model->runif() * n_left--)
Expand All @@ -368,10 +368,13 @@ inline void Entity<TSeq>::distribute()
if ((loc > 0) && (loc >= n_left))
loc = n_left - 1;

model->get_agent(idx[loc]).add_entity(
*this, this->model, this->state_init, this->queue_init
);

auto & agent = model->get_agent(idx[loc]);

if (!agent.has_entity(id))
agent.add_entity(
*this, this->model, this->state_init, this->queue_init
);

std::swap(idx[loc], idx[n_left]);

}
Expand Down
2 changes: 1 addition & 1 deletion include/epiworld/epiworld.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
/* Versioning */
#define EPIWORLD_VERSION_MAJOR 0
#define EPIWORLD_VERSION_MINOR 3
#define EPIWORLD_VERSION_PATCH 0
#define EPIWORLD_VERSION_PATCH 1

static const int epiworld_version_major = EPIWORLD_VERSION_MAJOR;
static const int epiworld_version_minor = EPIWORLD_VERSION_MINOR;
Expand Down
2 changes: 1 addition & 1 deletion include/epiworld/model-bones.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ class Model {
Entity<TSeq> * entity_,
epiworld_fast_int new_state_,
epiworld_fast_int queue_,
ActionFun<TSeq> call_,
EventFun<TSeq> call_,
int idx_agent_,
int idx_object_
);
Expand Down
4 changes: 2 additions & 2 deletions include/epiworld/model-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ inline void Model<TSeq>::events_add(
Entity<TSeq> * entity_,
epiworld_fast_int new_state_,
epiworld_fast_int queue_,
ActionFun<TSeq> call_,
EventFun<TSeq> call_,
int idx_agent_,
int idx_object_
) {
Expand All @@ -166,7 +166,7 @@ inline void Model<TSeq>::events_add(

#ifdef EPI_DEBUG
if (nactions == 0)
throw std::logic_error("Actions cannot be zero!!");
throw std::logic_error("Events cannot be zero!!");
#endif

if (nactions > events.size())
Expand Down
32 changes: 16 additions & 16 deletions tests/05-mixing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,22 +129,22 @@ EPIWORLD_TEST_CASE("SEIRMixing", "[SEIR-mixing]") {
REQUIRE_THAT(totals, Catch::Equals(expected_totals));
#endif

// // If entities don't have a dist function, then it should be
// // OK
// e1.set_dist_fun(nullptr);
// e2.set_dist_fun(nullptr);
// e3.set_dist_fun(nullptr);

// model.rm_entity(0);
// model.rm_entity(1);
// model.rm_entity(2);

// model.add_entity(e1);
// model.add_entity(e2);
// model.add_entity(e3);

// // Running and checking the results
// model.run(50, 123);
// If entities don't have a dist function, then it should be
// OK
e1.set_dist_fun(nullptr);
e2.set_dist_fun(nullptr);
e3.set_dist_fun(nullptr);

model.rm_entity(0);
model.rm_entity(1);
model.rm_entity(2);

model.add_entity(e1);
model.add_entity(e2);
model.add_entity(e3);

// Running and checking the results
model.run(50, 123);


#ifndef CATCH_CONFIG_MAIN
Expand Down

0 comments on commit c22f77b

Please sign in to comment.