Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable assignment & visualisation with partial query graphs #332

Merged
merged 39 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
c1ca50a
Expose save partial query to CLI
nickjcroucher Oct 9, 2024
395c2b9
Add test for --save-partial-query-graph
nickjcroucher Oct 9, 2024
7982406
Update beebop tests
nickjcroucher Oct 9, 2024
35fc89d
Update update tests
nickjcroucher Oct 9, 2024
9feb957
Update test update files
nickjcroucher Oct 10, 2024
1079fea
Correct update tests
nickjcroucher Oct 10, 2024
decf58c
Enable use of partial graphs
nickjcroucher Oct 10, 2024
5a14b8d
Fix lineage querying functions
nickjcroucher Oct 11, 2024
eb2249d
Add full network variable to lineage assignment call
nickjcroucher Oct 11, 2024
a97e75b
Move parenthesis
nickjcroucher Oct 11, 2024
df4dfec
Fix positional argument
nickjcroucher Oct 11, 2024
52dde86
Enable recalculation of distances for visualisation
nickjcroucher Oct 11, 2024
948f468
Add network argument to function call
nickjcroucher Oct 11, 2024
36291b2
Enable recalculation of distances for visualisation
nickjcroucher Oct 11, 2024
5cfef05
Add distance recalculation test
nickjcroucher Oct 11, 2024
79dd044
Update external clustering file
nickjcroucher Oct 11, 2024
396b3c2
Describe use of partial query graphs
nickjcroucher Oct 11, 2024
ef95674
Update external clustering files
nickjcroucher Oct 11, 2024
8fc6001
Add test for external clustering
nickjcroucher Oct 12, 2024
e7c94a5
Update external clustering files
nickjcroucher Oct 12, 2024
33c2500
Bump version
nickjcroucher Oct 12, 2024
a1d088d
Restructure to enable selection by cluster
nickjcroucher Oct 14, 2024
e399db1
Speed up network pruning
nickjcroucher Oct 14, 2024
dfc2300
More efficient database use
nickjcroucher Oct 17, 2024
f40a828
Fix indexing error
nickjcroucher Oct 17, 2024
9cc6c13
Quicker network plotting
nickjcroucher Oct 17, 2024
3ab5079
Fix logic of network file construction
nickjcroucher Oct 23, 2024
42b1c20
Make logic consistent for partial graphs
nickjcroucher Nov 4, 2024
8897960
Faster graph processing
nickjcroucher Nov 4, 2024
c1ae7b6
Fix communication between functions
nickjcroucher Nov 4, 2024
118861b
Restore earlier version
nickjcroucher Nov 6, 2024
16551de
Correct indentation
nickjcroucher Nov 6, 2024
808f0be
Clarify logic of remove_non_query_components
nickjcroucher Nov 6, 2024
8f28b0f
Remove debugging statement
nickjcroucher Nov 6, 2024
0b94b2d
Improve error messaging
nickjcroucher Nov 6, 2024
e8f28d6
Fix function renaming
nickjcroucher Nov 6, 2024
6d32dcc
Add explanation for max search depth
nickjcroucher Nov 6, 2024
727776d
Align statement
nickjcroucher Nov 6, 2024
6464fcd
Rename function consistently
nickjcroucher Nov 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion PopPUNK/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

'''PopPUNK (POPulation Partitioning Using Nucleotide Kmers)'''

__version__ = '2.7.2'
__version__ = '2.7.1'

# Minimum sketchlib version
SKETCHLIB_MAJOR = 2
Expand Down
27 changes: 20 additions & 7 deletions PopPUNK/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def get_options():
oGroup.add_argument('--update-db', help='Update reference database with query sequences', default=False, action='store_true')
oGroup.add_argument('--overwrite', help='Overwrite any existing database files', default=False, action='store_true')
oGroup.add_argument('--graph-weights', help='Save within-strain Euclidean distances into the graph', default=False, action='store_true')
oGroup.add_argument('--save-partial-query-graph', help='Save the network components to which queries are assigned', default=False, action='store_true')

# comparison metrics
kmerGroup = parser.add_argument_group('Kmer comparison options')
Expand Down Expand Up @@ -106,6 +107,9 @@ def get_options():
queryingGroup.add_argument('--accessory', help='(with a \'refine\' or \'lineage\' model) '
'Use an accessory-distance only model for assigning queries '
'[default = False]', default=False, action='store_true')
queryingGroup.add_argument('--use-full-network', help='Use full network rather than reference network for querying [default = False]',
nickjcroucher marked this conversation as resolved.
Show resolved Hide resolved
default = False,
action = 'store_true')

# processing
other = parser.add_argument_group('Other options')
Expand Down Expand Up @@ -234,7 +238,8 @@ def main():
args.gpu_dist,
args.gpu_graph,
args.deviceid,
save_partial_query_graph=False)
args.save_partial_query_graph,
args.use_full_network)

sys.stderr.write("\nDone\n")

Expand Down Expand Up @@ -267,7 +272,8 @@ def assign_query(dbFuncs,
gpu_dist,
gpu_graph,
deviceid,
save_partial_query_graph):
save_partial_query_graph,
use_full_network):
"""Code for assign query mode for CLI"""
createDatabaseDir = dbFuncs['createDatabaseDir']
constructDatabase = dbFuncs['constructDatabase']
Expand Down Expand Up @@ -316,7 +322,8 @@ def assign_query(dbFuncs,
accessory,
gpu_dist,
gpu_graph,
save_partial_query_graph)
save_partial_query_graph,
use_full_network)
return(isolateClustering)

def assign_query_hdf5(dbFuncs,
Expand All @@ -341,7 +348,8 @@ def assign_query_hdf5(dbFuncs,
accessory,
gpu_dist,
gpu_graph,
save_partial_query_graph):
save_partial_query_graph,
use_full_network):
"""Code for assign query mode taking hdf5 as input. Written as a separate function so it can be called
by web APIs"""
# Modules imported here as graph tool is very slow to load (it pulls in all of GTK?)
Expand All @@ -359,6 +367,7 @@ def assign_query_hdf5(dbFuncs,
from .network import get_vertex_list
from .network import printExternalClusters
from .network import vertex_betweenness
from .network import remove_non_query_components
from .qc import sketchlibAssemblyQC

from .plot import writeClusterCsv
Expand Down Expand Up @@ -453,7 +462,7 @@ def assign_query_hdf5(dbFuncs,
ref_file_name = os.path.join(model_prefix,
os.path.basename(model_prefix) + file_extension_string + ".refs")
use_ref_graph = \
os.path.isfile(ref_file_name) and not update_db and model.type != 'lineage'
os.path.isfile(ref_file_name) and not update_db and model.type != 'lineage' and not use_full_network
if use_ref_graph:
with open(ref_file_name) as refFile:
for reference in refFile:
Expand Down Expand Up @@ -791,12 +800,16 @@ def assign_query_hdf5(dbFuncs,
output + "/" + os.path.basename(output) + db_suffix)
else:
storePickle(rNames, qNames, False, qrDistMat, dists_out)
if save_partial_query_graph and not serial:
if model.type == 'lineage':
if save_partial_query_graph:
genomeNetwork, pruned_isolate_lists = remove_non_query_components(genomeNetwork, rNames, qNames, use_gpu = gpu_graph)
if model.type == 'lineage' and not serial:
save_network(genomeNetwork[min(model.ranks)], prefix = output, suffix = '_graph', use_gpu = gpu_graph)
else:
graph_suffix = file_extension_string + '_graph'
save_network(genomeNetwork, prefix = output, suffix = graph_suffix, use_gpu = gpu_graph)
with open(f"{output}/{os.path.basename(output)}_query.subset",'w') as pruned_isolate_csv:
for isolate in pruned_isolate_lists:
pruned_isolate_csv.write(isolate + '\n')

return(isolateClustering)

Expand Down
6 changes: 4 additions & 2 deletions PopPUNK/lineages.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,8 @@ def query_db(args):
accessory,
args.gpu_dist,
args.gpu_graph,
save_partial_query_graph = False)
save_partial_query_graph = False,
use_full_network = True) # Use full network - does not make sense to use references for lineages

# Process clustering
query_strains = {}
Expand Down Expand Up @@ -439,7 +440,8 @@ def query_db(args):
accessory,
args.gpu_dist,
args.gpu_graph,
save_partial_query_graph = False)
save_partial_query_graph = False,
use_full_network = True)
overall_lineage[strain] = createOverallLineage(rank_list, lineageClustering)

# Print combined strain and lineage clustering
Expand Down
5 changes: 4 additions & 1 deletion PopPUNK/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,9 @@ def __init__(self, outPrefix, ranks, max_search_depth, reciprocal_only, count_un
ClusterFit.__init__(self, outPrefix)
self.type = 'lineage'
self.preprocess = False
self.max_search_depth = max_search_depth
self.max_search_depth = max_search_depth+5 # Set to highest rank by default in main; need to store additional distances
# when there is redundancy (e.g. reciprocal matching, unique distance counting)
# or other sequences may be pruned out of the database
self.nn_dists = None # stores the unprocessed kNN at the maximum search depth
self.ranks = []
for rank in sorted(ranks):
Expand Down Expand Up @@ -1368,6 +1370,7 @@ def extend(self, qqDists, qrDists):
qrRect,
self.max_search_depth,
self.threads)

# Update NN dist associated with model
self.__save_sparse__(higher_rank[2], higher_rank[0], higher_rank[1],
self.max_search_depth, n_ref + n_query, self.nn_dists.dtype,
Expand Down
109 changes: 88 additions & 21 deletions PopPUNK/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1920,32 +1920,99 @@ def prune_graph(prefix, reflist, samples_to_keep, output_db_name, threads, use_g
if os.path.exists(network_fn):
network_found = True
sys.stderr.write("Loading network from " + network_fn + "\n")
samples_to_keep_set = frozenset(samples_to_keep)
G = load_network_file(network_fn, use_gpu = use_gpu)
if use_gpu:
# Identify indices
reference_indices = [i for (i,name) in enumerate(reflist) if name in samples_to_keep_set]
# Generate data frame
G_df = G.view_edge_list()
if 'src' in G_df.columns:
G_df.rename(columns={'src': 'source','dst': 'destination'}, inplace=True)
# Filter data frame
G_new_df = G_df[G_df['source'].isin(reference_indices) & G_df['destination'].isin(reference_indices)]
# Translate network indices to match name order
G_new = translate_network_indices(G_new_df, reference_indices)
else:
reference_vertex = G.new_vertex_property('bool')
for n, vertex in enumerate(G.vertices()):
if reflist[n] in samples_to_keep_set:
reference_vertex[vertex] = True
else:
reference_vertex[vertex] = False
G_new = gt.GraphView(G, vfilt = reference_vertex)
G_new = gt.Graph(G_new, prune = True)
G_new = remove_nodes_from_graph(G, reflist, samples_to_keep, use_gpu)
save_network(G_new,
prefix = output_db_name,
suffix = '_graph',
use_graphml = False,
use_gpu = use_gpu)
if not network_found:
sys.stderr.write('No network file found for pruning\n')

def remove_nodes_from_graph(G,reflist, samples_to_keep, use_gpu):
"""Return a modified graph containing only the requested nodes

Args:
reflist (list)
Ordered list of sequences of database
samples_to_keep (list)
The names of samples to be retained in the graph
use_gpu (bool)
Whether graph is a cugraph or not
[default = False]

Returns:
G_new (graph)
Pruned graph
"""
samples_to_keep_set = frozenset(samples_to_keep)
if use_gpu:
# Identify indices
reference_indices = [i for (i,name) in enumerate(reflist) if name in samples_to_keep_set]
# Generate data frame
G_df = G.view_edge_list()
if 'src' in G_df.columns:
G_df.rename(columns={'src': 'source','dst': 'destination'}, inplace=True)
# Filter data frame
G_new_df = G_df[G_df['source'].isin(reference_indices) & G_df['destination'].isin(reference_indices)]
# Translate network indices to match name order
G_new = translate_network_indices(G_new_df, reference_indices)
else:
reference_vertex = G.new_vertex_property('bool')
for n, vertex in enumerate(G.vertices()):
if reflist[n] in samples_to_keep_set:
reference_vertex[vertex] = True
else:
reference_vertex[vertex] = False
G_new = gt.GraphView(G, vfilt = reference_vertex)
G_new = gt.Graph(G_new, prune = True)
return G_new

def remove_non_query_components(G, rlist, qlist, use_gpu = False):
"""
Removes all components that do not contain a query sequence.

Args:
G (graph)
Network of queries linked to reference sequences
rlist (list)
List of reference sequence labels
qlist (list)
List of query sequence labels
use_gpu (bool)
Whether to use GPUs for network construction

Returns:
G (graph)
The resulting network
pruned_names (list)
The labels of the sequences in the pruned network
"""
components_with_query = []
nickjcroucher marked this conversation as resolved.
Show resolved Hide resolved
combined_names = rlist + qlist
pruned_names = []
if use_gpu:
sys.stderr.write('Saving partial query graphs is not compatible with GPU networks yet\n')
sys.exit(1)
else:
# Identify network components containing queries
component_dict = gt.label_components(G)[0]
components_with_query = set()
# The number of reference sequences is len(rlist)
# These are the first len(rlist) vertices in the graph
# Queries that have been added have indices >len(rlist)
# Therefore these are the components to retain
for i in range(len(rlist),G.num_vertices()):
v = G.vertex(i) # Access vertex by index
components_with_query.add(component_dict[v])
# Create a boolean filter based on the list of component IDs
query_filter = G.new_vertex_property("bool")
for v in G.vertices():
query_filter[int(v)] = (component_dict[v] in components_with_query)
if query_filter[int(v)]:
pruned_names.append(combined_names[int(v)])
# Create a filtered graph with only the specified components
query_subgraph = gt.GraphView(G, vfilt=query_filter)

return query_subgraph, pruned_names
14 changes: 8 additions & 6 deletions PopPUNK/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def drawMST(mst, outPrefix, isolate_clustering, clustering_name, overwrite):
output=graph2_file_name, output_size=(3000, 3000))

def outputsForCytoscape(G, G_mst, isolate_names, clustering, outPrefix, epiCsv, queryList = None,
suffix = None, writeCsv = True):
suffix = None, writeCsv = True, use_partial_query_graph = None):
"""Write outputs for cytoscape. A graphml of the network, and CSV with metadata

Args:
Expand All @@ -536,6 +536,8 @@ def outputsForCytoscape(G, G_mst, isolate_names, clustering, outPrefix, epiCsv,
(default = None)
writeCsv (bool)
Whether to print CSV file to accompany network
use_partial_query_graph (str)
File listing sequences to be included in output graph
"""

# Avoid circular import
Expand All @@ -553,7 +555,8 @@ def outputsForCytoscape(G, G_mst, isolate_names, clustering, outPrefix, epiCsv,
suffix = '_cytoscape'
else:
suffix = suffix + '_cytoscape'
save_network(G, prefix = outPrefix, suffix = suffix, use_graphml = True)
if use_partial_query_graph is None:
save_network(G, prefix = outPrefix, suffix = suffix, use_graphml = True)

# Save each component too (useful for very large graphs)
component_assignments, component_hist = gt.label_components(G)
Expand All @@ -562,10 +565,9 @@ def outputsForCytoscape(G, G_mst, isolate_names, clustering, outPrefix, epiCsv,
for vidx, v_component in enumerate(component_assignments.a):
if v_component != component_idx:
remove_list.append(vidx)
G_copy = G.copy()
G_copy.remove_vertex(remove_list)
save_network(G_copy, prefix = outPrefix, suffix = "_component_" + str(component_idx + 1), use_graphml = True)
del G_copy
G.remove_vertex(remove_list)
G.purge_vertices()
save_network(G, prefix = outPrefix, suffix = "_component_" + str(component_idx + 1), use_graphml = True)

if G_mst != None:
isolate_labels = isolateNameToLabel(G_mst.vp.id)
Expand Down
17 changes: 13 additions & 4 deletions PopPUNK/sketchlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def getSeqsInDb(dbname):

return seqs

def joinDBs(db1, db2, output, update_random = None):
def joinDBs(db1, db2, output, update_random = None, full_names = False):
"""Join two sketch databases with the low-level HDF5 copy interface

Args:
Expand All @@ -226,10 +226,19 @@ def joinDBs(db1, db2, output, update_random = None):
update_random (dict)
Whether to re-calculate the random object. May contain
control arguments strand_preserved and threads (see :func:`addRandom`)
full_names (bool)
If True, db_name and out_name are the full paths to h5 files

"""
join_prefix = output + "/" + os.path.basename(output)
db1_name = db1 + "/" + os.path.basename(db1) + ".h5"
db2_name = db2 + "/" + os.path.basename(db2) + ".h5"

if not full_names:
join_prefix = output + "/" + os.path.basename(output)
db1_name = db1 + "/" + os.path.basename(db1) + ".h5"
db2_name = db2 + "/" + os.path.basename(db2) + ".h5"
else:
db1_name = db1
db2_name = db2
join_prefix = output
nickjcroucher marked this conversation as resolved.
Show resolved Hide resolved

hdf1 = h5py.File(db1_name, 'r')
hdf2 = h5py.File(db2_name, 'r')
Expand Down
6 changes: 5 additions & 1 deletion PopPUNK/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,14 +593,16 @@ def check_and_set_gpu(use_gpu, gpu_lib, quit_on_fail = False):

return use_gpu

def read_rlist_from_distance_pickle(fn, allow_non_self = True):
def read_rlist_from_distance_pickle(fn, allow_non_self = True, include_queries = False):
"""Return the list of reference sequences from a distance pickle.

Args:
fn (str)
Name of distance pickle
allow_non_self (bool)
Whether non-self distance datasets are permissible
include_queries (bool)
Whether queries should be included in the rlist
Returns:
rlist (list)
List of reference sequence names
Expand All @@ -611,6 +613,8 @@ def read_rlist_from_distance_pickle(fn, allow_non_self = True):
sys.stderr.write("Thi analysis requires an all-v-all"
" distance dataset\n")
sys.exit(1)
if include_queries:
rlist = rlist + qlist
return rlist

def get_match_search_depth(rlist,rank_list):
Expand Down
Loading
Loading