-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Input the contacts in parallel from numpy #2
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
d324726
Add comments
fabmazz 4f98e06
Fix deprecation of numpy types in tests
fabmazz 1fc4a6a
add python contacts read function
fabmazz 05ed881
split method to add contacts
fabmazz a4327e2
split into functions
fabmazz 46a4a31
Use map to preprocess contacts
fabmazz 8c84f7c
second try in parallel
fabmazz 0c86c7d
Add test for large tree
fabmazz a70728f
Fix spaces
fabmazz 7969a8b
reset spaces
fabmazz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
{ | ||
"N": 19531, | ||
"t_limit": 25, | ||
"mu": 0.01 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#!/usr/bin/env python | ||
# coding: utf-8 | ||
import sys | ||
sys.path.insert(0,"..") | ||
|
||
import numpy as np | ||
import json | ||
import sib | ||
|
||
with open("data/large_tree_pars.json") as f: | ||
params = json.load(f) | ||
|
||
cts_beliefs_f = np.load("data/large_tree_data.npz") | ||
|
||
cts = cts_beliefs_f["cts"] | ||
|
||
beliefs_all = cts_beliefs_f["beliefs"] | ||
|
||
obs_all = {} | ||
for k in cts_beliefs_f.files: | ||
if "obs_" in k: | ||
u=int(k.split("_")[-1]) | ||
#print(k, u) | ||
obs_all[u] = cts_beliefs_f[k] | ||
|
||
#close file | ||
cts_beliefs_f.close() | ||
|
||
sib_pars = sib.Params(prob_r=sib.Gamma(mu=params["mu"])) | ||
|
||
N = params["N"] | ||
t_limit = params["t_limit"] | ||
cts_sib = [(int(r["i"]),int(r["j"]),int(r["t"]),r["lam"]) for r in cts] | ||
|
||
tests = [sib.Test(s==0,s==1,s==2) for s in range(3)] | ||
def make_obs_sib(N, t_limit,obs, tests): | ||
obs_list_sib =[(i,-1,t) for t in [t_limit] for i in range(N) ] | ||
obs_list_sib.extend([(r["i"],tests[r["st"]],r["t"]) for r in obs]) | ||
|
||
obs_list_sib.sort(key=lambda x: x[-1]) | ||
|
||
return obs_list_sib | ||
|
||
callback = lambda t, err, fg: print(f"iter: {t:6}, err: {err:.5e} ", end="\r") | ||
|
||
for ii,obs in obs_all.items(): | ||
fg = sib.FactorGraph(params=sib_pars) | ||
beliefs = beliefs_all[ii] | ||
#print(f"Instance {ii}") | ||
#for c in cts_sib: | ||
# fg.append_contact(*c) | ||
fg.append_contacts_npy(cts["i"], cts["j"], cts["t"], cts["lam"]) | ||
obs_list_sib = make_obs_sib(N,t_limit, obs, tests) | ||
for o in obs_list_sib: | ||
fg.append_observation(*o) | ||
|
||
sib.iterate(fg,200,1e-20,callback=callback ) | ||
print("") | ||
s=0. | ||
for i in range(len(fg.nodes)): | ||
s+=np.abs(np.array(fg.nodes[i].bt)-beliefs[i][0]).sum() | ||
s+=np.abs(np.array(fg.nodes[i].bg)-beliefs[i][1]).sum() | ||
#fg.nodes[i].bg])) | ||
print(f"instance {ii}: {s:4.3e} {s < 1e-10}") | ||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,5 +81,4 @@ instance 45: True | |
instance 46: True | ||
instance 47: True | ||
instance 48: True | ||
instance 49: True | ||
|
||
instance 49: True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
>>> from pathlib import Path | ||
>>> import sys | ||
>>> sys.path.insert(0,'test/') | ||
>>> import numpy as np | ||
>>> import sib | ||
>>> import data_load | ||
|
||
### LOAD DATA | ||
>>> folder_data = Path("test/data/tree_check/") | ||
>>> params,contacts,observ,epidem = data_load.load_exported_data(folder_data) | ||
>>> contacts = contacts[["i","j","t","lambda"]] | ||
>>> obs_all_df = [] | ||
>>> for obs in observ: | ||
... obs_df = data_load.convert_obs_to_df(obs) | ||
... obs_all_df.append(obs_df[["i","st","t"]]) | ||
>>> n_inst = len(observ) | ||
>>> print(f"Number of instances: {n_inst}") | ||
Number of instances: 50 | ||
|
||
### TEST RESULTS | ||
>>> beliefs = np.load(folder_data / "beliefs_tree.npz") | ||
>>> tests = [sib.Test(s==0,s==1,s==2) for s in range(3)] | ||
>>> sib_pars = sib.Params(prob_r=sib.Gamma(mu=params["mu"])) | ||
>>> cts = contacts.to_records(index=False) | ||
>>> for inst in range(n_inst): | ||
... obs = list(obs_all_df[inst].to_records(index=False)) | ||
... obs = [(i,tests[s],t) for (i,s,t) in obs] | ||
... fg = sib.FactorGraph(params=sib_pars) | ||
... fg.append_contacts_npy(cts["i"], cts["j"], cts["t"], cts["lambda"]) | ||
... for o in obs: | ||
... fg.append_observation(*o) | ||
... sib.iterate(fg,200,1e-20,callback=None) | ||
... s = 0.0 | ||
... for i in range(len(fg.nodes)): | ||
... s += sum(abs(beliefs[f"{inst}_{i}"][0]-np.array(fg.nodes[i].bt))) | ||
... print(f"instance {inst}: {s < 1e-10}") | ||
instance 0: True | ||
instance 1: True | ||
instance 2: True | ||
instance 3: True | ||
instance 4: True | ||
instance 5: True | ||
instance 6: True | ||
instance 7: True | ||
instance 8: True | ||
instance 9: True | ||
instance 10: True | ||
instance 11: True | ||
instance 12: True | ||
instance 13: True | ||
instance 14: True | ||
instance 15: True | ||
instance 16: True | ||
instance 17: True | ||
instance 18: True | ||
instance 19: True | ||
instance 20: True | ||
instance 21: True | ||
instance 22: True | ||
instance 23: True | ||
instance 24: True | ||
instance 25: True | ||
instance 26: True | ||
instance 27: True | ||
instance 28: True | ||
instance 29: True | ||
instance 30: True | ||
instance 31: True | ||
instance 32: True | ||
instance 33: True | ||
instance 34: True | ||
instance 35: True | ||
instance 36: True | ||
instance 37: True | ||
instance 38: True | ||
instance 39: True | ||
instance 40: True | ||
instance 41: True | ||
instance 42: True | ||
instance 43: True | ||
instance 44: True | ||
instance 45: True | ||
instance 46: True | ||
instance 47: True | ||
instance 48: True | ||
instance 49: True |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
l'aggiunta di queste due funzioni sembrano gli unici cambiamenti "veri" in bp.h e bp.cpp. Puoi lasciare solo questo in questo PR?
grazie mille fabio