Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Input the contacts in parallel from numpy #2

Merged
merged 10 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions bp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,60 @@ void FactorGraph::drop_contacts(times_t t)
}
}

void FactorGraph::check_neighbors(int i, int j){
if (i == j)
throw invalid_argument("self loops are not allowed");
add_node(i);
add_node(j);
Node & fi = nodes[i];
Node & fj = nodes[j];
int ki = find_neighbor(i, j);
int kj = find_neighbor(j, i);
//check neighbors are mutual
if (ki == int(fi.neighs.size())) {
assert(kj == int(fj.neighs.size()));
fi.neighs.push_back(Neigh(j, kj));
fj.neighs.push_back(Neigh(i, ki));
}
}

void FactorGraph::add_contact_single(int i, int j, times_t t, real_t lambdaij){
Node & fi = nodes[i];
int qi = fi.times.size();
if (fi.times[qi - 2] > t)
throw invalid_argument("time of contacts should be ordered");
int ki = find_neighbor(i, j);
//add contact times for i & j
Neigh & ni = fi.neighs[ki];
if (fi.times[qi - 2] < t) {
fi.push_back_time(t);
++qi;
}
if (ni.t.size() < 2 || ni.t[ni.t.size() - 2] < qi - 2) {
//the time are not in the times
ni.t.back() = qi - 2;
ni.t.push_back(qi - 1);
if (lambdaij != DO_NOT_OVERWRITE)
ni.lambdas.back() = lambdaij;

ni.lambdas.push_back(0.0);
//expand the messages
++ni.msg;
} else if (ni.t[ni.t.size() - 2] == qi - 2) {
//times are already done, write the lambdas
if (lambdaij != DO_NOT_OVERWRITE)
ni.lambdas[ni.t.size() - 2] = lambdaij;

} else {
throw invalid_argument("time of contacts should be ordered");
}
// adjust infinite times
for (int k = 0; k < int(fi.neighs.size()); ++k) {
fi.neighs[k].t.back() = qi - 1;
}

}

void FactorGraph::append_contact(int i, int j, times_t t, real_t lambdaij, real_t lambdaji)
{
if (i == j)
Expand Down Expand Up @@ -640,6 +694,7 @@ real_t FactorGraph::iterate(int maxit, real_t tol, real_t damping, bool learn)
return err;
}


ostream & operator<<(ostream & ost, FactorGraph const & f)
{
int nasym = 0;
Expand Down
3 changes: 3 additions & 0 deletions bp.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct Message : public std::vector<T>
Message(size_t qj) : std::vector<T>(qj*qj), qj(qj) {}
void clear() { for (int i = 0; i < int(std::vector<T>::size()); ++i) std::vector<T>::operator[](i)*=0.0; }
size_t dim() const { return qj;}
//map reference from 2D to 1D
inline T & operator()(int sji, int sij) { return std::vector<T>::operator[](qj * sij + sji); }
inline T const & operator()(int sji, int sij) const { return std::vector<T>::operator[](qj * sij + sji); }
size_t qj;
Expand Down Expand Up @@ -102,6 +103,8 @@ class FactorGraph {
std::vector<std::tuple<int, std::shared_ptr<Proba>, std::shared_ptr<Proba>, std::shared_ptr<Proba>, std::shared_ptr<Proba>> > const & individuals = std::vector<std::tuple<int, std::shared_ptr<Proba>, std::shared_ptr<Proba>, std::shared_ptr<Proba>, std::shared_ptr<Proba>>>());
int find_neighbor(int i, int j) const;
void append_contact(int i, int j, times_t t, real_t lambdaij, real_t lambdaji = DO_NOT_OVERWRITE);
void check_neighbors(int i, int j);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

l'aggiunta di queste due funzioni sembrano gli unici cambiamenti "veri" in bp.h e bp.cpp. Puoi lasciare solo questo in questo PR?

grazie mille fabio

void add_contact_single(int i, int j, times_t t, real_t lambdaij);
void drop_contacts(times_t t);
void append_observation(int i, std::shared_ptr<Test> const & o, times_t t);
void append_time(int i, times_t t);
Expand Down
92 changes: 92 additions & 0 deletions pysib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <boost/lexical_cast.hpp>
#include <iterator>
#include <exception>
#include <unordered_map>
#include "bp.h"
#include "drop.h"

Expand All @@ -27,7 +28,91 @@ namespace py = pybind11;
using namespace std;
using boost::lexical_cast;

template <typename T>
struct NpyArrayC{
typedef py::array_t<T,py::array::c_style | py::array::forcecast> typ;
};

void append_contacts_numpy(FactorGraph &G, NpyArrayC<int>::typ &from, NpyArrayC<int>::typ &to,
NpyArrayC<int>::typ &times, NpyArrayC<real_t>::typ &lambdas){

auto buf_i = from.request();
auto buf_j = to.request();
auto buf_t = times.request();

auto buf_lam = lambdas.request();

if (buf_i.ndim !=1 || buf_j.ndim !=1 || buf_t.ndim != 1 || buf_lam.ndim != 1)
{
throw std::runtime_error("Provide vectors of single dimension");
}

auto mlen = buf_i.shape[0];
if(buf_j.shape[0]!=mlen || buf_t.shape[0]!=mlen || buf_lam.shape[0]!=mlen){
throw std::runtime_error("Vectors have to be equal in length");
}

//pointers to memory
auto ptr_i = static_cast<int*>(buf_i.ptr);
auto ptr_j = static_cast<int*>(buf_j.ptr);
auto ptr_t = static_cast<int*>(buf_t.ptr);
auto ptr_lam = static_cast<real_t*>(buf_lam.ptr);
// first loop -> add nodes to the graph if needed
// second loop in parallel -> expand times and messages
// check for uniqueness of (i,j)
typedef std::unordered_map<int, vector<tuple<int, int, real_t> > > mapType;
unordered_map<int, vector<tuple<int, int, real_t> > > itolistmap;
for(int k=0; k<mlen; k++){
//cerr << ptr_i[k] << " -> "<<ptr_j[k] <<", t: "<<ptr_t[k]<<", lam: "<<ptr_lam[k]<<endl;
//G.append_contact(ptr_i[k], ptr_j[k], ptr_t[k], ptr_lam[k]);
auto i = ptr_i[k];
auto j = ptr_j[k];
auto t = ptr_t[k];
auto lam = ptr_lam[k];
G.check_neighbors(ptr_i[k], ptr_j[k]);
//G.add_contact_single(ptr_i[k], ptr_j[k], ptr_t[k],ptr_lam[k]);
//G.add_contact_single(ptr_j[k], ptr_i[k],ptr_t[k], FactorGraph::DO_NOT_OVERWRITE);
// find vector
auto vec_l = itolistmap.find(i);
if( vec_l != itolistmap.end()){
//append
vec_l->second.push_back( make_tuple(j, t, lam));
}else{
vector<tuple<int, int, real_t> > mvec {make_tuple(j,t,lam)};
itolistmap.emplace(make_pair(i,mvec));
}
//REVERSE
vec_l = itolistmap.find(j);
if( vec_l != itolistmap.end()){
//append
vec_l->second.push_back( make_tuple(i, t, FactorGraph::DO_NOT_OVERWRITE));
}else{
vector<tuple<int, int, real_t> > mvec{make_tuple(i,t,FactorGraph::DO_NOT_OVERWRITE)};
itolistmap.emplace(make_pair(j,mvec));
}

}

mapType::iterator mapIter;
#pragma omp parallel
{
#pragma omp single nowait
{
for(mapIter=itolistmap.begin();mapIter!=itolistmap.end();++mapIter) //construct the distance matrix
{
#pragma omp task firstprivate(mapIter)
{
int i = mapIter-> first;
auto listC = mapIter -> second;
for(auto vecIt = listC.begin(); vecIt!=listC.end(); ++vecIt){
G.add_contact_single(i, get<0>(*vecIt), get<1>(*vecIt), get<2>(*vecIt));
}
}
}
}
}

}



Expand Down Expand Up @@ -237,6 +322,13 @@ PYBIND11_MODULE(_sib, m) {
py::arg("lambdaij"),
py::arg("lambdaji") = real_t(FactorGraph::DO_NOT_OVERWRITE),
"appends a new contact from i to j at time t with transmission probabilities lambdaij, lambdaji")
.def("append_contacts_npy", &append_contacts_numpy,
py::arg("arr_i"),
py::arg("arr_j"),
py::arg("arr_t"),
py::arg("arr_lambs"),
"Append many contacts from numpy arrays"
)
.def("reset_observations", &FactorGraph::reset_observations,
py::arg("obs"),
"resets all observations")
Expand Down
Binary file added test/data/large_tree_data.npz
Binary file not shown.
5 changes: 5 additions & 0 deletions test/data/large_tree_pars.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"N": 19531,
"t_limit": 25,
"mu": 0.01
}
2 changes: 1 addition & 1 deletion test/data_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def save_json(obj,file_,indent=1):
ALL_EPI = "all_epidemies"

NAMES_COLS_CONTACTS = ["t","i","j","lambda"]
CONTACTS_DTYPES = dict(zip(NAMES_COLS_CONTACTS,(np.int,np.int,np.int,np.float) ))
fabmazz marked this conversation as resolved.
Show resolved Hide resolved
CONTACTS_DTYPES = dict(zip(NAMES_COLS_CONTACTS,(int, int, int, float) ))



Expand Down
69 changes: 69 additions & 0 deletions test/test_large.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/usr/bin/env python
# coding: utf-8
import sys
sys.path.insert(0,"..")

import numpy as np
import json
import sib

with open("data/large_tree_pars.json") as f:
params = json.load(f)

cts_beliefs_f = np.load("data/large_tree_data.npz")

cts = cts_beliefs_f["cts"]

beliefs_all = cts_beliefs_f["beliefs"]

obs_all = {}
for k in cts_beliefs_f.files:
if "obs_" in k:
u=int(k.split("_")[-1])
#print(k, u)
obs_all[u] = cts_beliefs_f[k]

#close file
cts_beliefs_f.close()

sib_pars = sib.Params(prob_r=sib.Gamma(mu=params["mu"]))

N = params["N"]
t_limit = params["t_limit"]
cts_sib = [(int(r["i"]),int(r["j"]),int(r["t"]),r["lam"]) for r in cts]

tests = [sib.Test(s==0,s==1,s==2) for s in range(3)]
def make_obs_sib(N, t_limit,obs, tests):
obs_list_sib =[(i,-1,t) for t in [t_limit] for i in range(N) ]
obs_list_sib.extend([(r["i"],tests[r["st"]],r["t"]) for r in obs])

obs_list_sib.sort(key=lambda x: x[-1])

return obs_list_sib

callback = lambda t, err, fg: print(f"iter: {t:6}, err: {err:.5e} ", end="\r")

for ii,obs in obs_all.items():
fg = sib.FactorGraph(params=sib_pars)
beliefs = beliefs_all[ii]
#print(f"Instance {ii}")
#for c in cts_sib:
# fg.append_contact(*c)
fg.append_contacts_npy(cts["i"], cts["j"], cts["t"], cts["lam"])
obs_list_sib = make_obs_sib(N,t_limit, obs, tests)
for o in obs_list_sib:
fg.append_observation(*o)

sib.iterate(fg,200,1e-20,callback=callback )
print("")
s=0.
for i in range(len(fg.nodes)):
s+=np.abs(np.array(fg.nodes[i].bt)-beliefs[i][0]).sum()
s+=np.abs(np.array(fg.nodes[i].bg)-beliefs[i][1]).sum()
#fg.nodes[i].bg]))
print(f"instance {ii}: {s:4.3e} {s < 1e-10}")





3 changes: 1 addition & 2 deletions test/test_trees.doctest
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,4 @@ instance 45: True
instance 46: True
instance 47: True
instance 48: True
instance 49: True

instance 49: True
86 changes: 86 additions & 0 deletions test/test_trees_np.doctest
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
>>> from pathlib import Path
>>> import sys
>>> sys.path.insert(0,'test/')
>>> import numpy as np
>>> import sib
>>> import data_load

### LOAD DATA
>>> folder_data = Path("test/data/tree_check/")
>>> params,contacts,observ,epidem = data_load.load_exported_data(folder_data)
>>> contacts = contacts[["i","j","t","lambda"]]
>>> obs_all_df = []
>>> for obs in observ:
... obs_df = data_load.convert_obs_to_df(obs)
... obs_all_df.append(obs_df[["i","st","t"]])
>>> n_inst = len(observ)
>>> print(f"Number of instances: {n_inst}")
Number of instances: 50

### TEST RESULTS
>>> beliefs = np.load(folder_data / "beliefs_tree.npz")
>>> tests = [sib.Test(s==0,s==1,s==2) for s in range(3)]
>>> sib_pars = sib.Params(prob_r=sib.Gamma(mu=params["mu"]))
>>> cts = contacts.to_records(index=False)
>>> for inst in range(n_inst):
... obs = list(obs_all_df[inst].to_records(index=False))
... obs = [(i,tests[s],t) for (i,s,t) in obs]
... fg = sib.FactorGraph(params=sib_pars)
... fg.append_contacts_npy(cts["i"], cts["j"], cts["t"], cts["lambda"])
... for o in obs:
... fg.append_observation(*o)
... sib.iterate(fg,200,1e-20,callback=None)
... s = 0.0
... for i in range(len(fg.nodes)):
... s += sum(abs(beliefs[f"{inst}_{i}"][0]-np.array(fg.nodes[i].bt)))
... print(f"instance {inst}: {s < 1e-10}")
instance 0: True
instance 1: True
instance 2: True
instance 3: True
instance 4: True
instance 5: True
instance 6: True
instance 7: True
instance 8: True
instance 9: True
instance 10: True
instance 11: True
instance 12: True
instance 13: True
instance 14: True
instance 15: True
instance 16: True
instance 17: True
instance 18: True
instance 19: True
instance 20: True
instance 21: True
instance 22: True
instance 23: True
instance 24: True
instance 25: True
instance 26: True
instance 27: True
instance 28: True
instance 29: True
instance 30: True
instance 31: True
instance 32: True
instance 33: True
instance 34: True
instance 35: True
instance 36: True
instance 37: True
instance 38: True
instance 39: True
instance 40: True
instance 41: True
instance 42: True
instance 43: True
instance 44: True
instance 45: True
instance 46: True
instance 47: True
instance 48: True
instance 49: True
Loading