Skip to content

Commit

Permalink
Enable use of partial graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
nickjcroucher committed Oct 10, 2024
1 parent 1079fea commit decf58c
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 38 deletions.
26 changes: 19 additions & 7 deletions PopPUNK/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ def get_options():
queryingGroup.add_argument('--accessory', help='(with a \'refine\' or \'lineage\' model) '
'Use an accessory-distance only model for assigning queries '
'[default = False]', default=False, action='store_true')
queryingGroup.add_argument('--use-full-network', help='Use full network rather than reference network for querying [default = False]',
default = False,
action = 'store_true')

# processing
other = parser.add_argument_group('Other options')
Expand Down Expand Up @@ -235,7 +238,8 @@ def main():
args.gpu_dist,
args.gpu_graph,
args.deviceid,
args.save_partial_query_graph)
args.save_partial_query_graph,
args.use_full_network)

sys.stderr.write("\nDone\n")

Expand Down Expand Up @@ -268,7 +272,8 @@ def assign_query(dbFuncs,
gpu_dist,
gpu_graph,
deviceid,
save_partial_query_graph):
save_partial_query_graph,
use_full_network):
"""Code for assign query mode for CLI"""
createDatabaseDir = dbFuncs['createDatabaseDir']
constructDatabase = dbFuncs['constructDatabase']
Expand Down Expand Up @@ -317,7 +322,8 @@ def assign_query(dbFuncs,
accessory,
gpu_dist,
gpu_graph,
save_partial_query_graph)
save_partial_query_graph,
use_full_network)
return(isolateClustering)

def assign_query_hdf5(dbFuncs,
Expand All @@ -342,7 +348,8 @@ def assign_query_hdf5(dbFuncs,
accessory,
gpu_dist,
gpu_graph,
save_partial_query_graph):
save_partial_query_graph,
use_full_network):
"""Code for assign query mode taking hdf5 as input. Written as a separate function so it can be called
by web APIs"""
# Modules imported here as graph tool is very slow to load (it pulls in all of GTK?)
Expand All @@ -360,6 +367,7 @@ def assign_query_hdf5(dbFuncs,
from .network import get_vertex_list
from .network import printExternalClusters
from .network import vertex_betweenness
from .network import retain_only_query_clusters
from .qc import sketchlibAssemblyQC

from .plot import writeClusterCsv
Expand Down Expand Up @@ -454,7 +462,7 @@ def assign_query_hdf5(dbFuncs,
ref_file_name = os.path.join(model_prefix,
os.path.basename(model_prefix) + file_extension_string + ".refs")
use_ref_graph = \
os.path.isfile(ref_file_name) and not update_db and model.type != 'lineage'
os.path.isfile(ref_file_name) and not update_db and model.type != 'lineage' and not use_full_network
if use_ref_graph:
with open(ref_file_name) as refFile:
for reference in refFile:
Expand Down Expand Up @@ -792,12 +800,16 @@ def assign_query_hdf5(dbFuncs,
output + "/" + os.path.basename(output) + db_suffix)
else:
storePickle(rNames, qNames, False, qrDistMat, dists_out)
if save_partial_query_graph and not serial:
if model.type == 'lineage':
if save_partial_query_graph:
genomeNetwork, pruned_isolate_lists = retain_only_query_clusters(genomeNetwork, rNames, qNames, use_gpu = gpu_graph)
if model.type == 'lineage' and not serial:
save_network(genomeNetwork[min(model.ranks)], prefix = output, suffix = '_graph', use_gpu = gpu_graph)
else:
graph_suffix = file_extension_string + '_graph'
save_network(genomeNetwork, prefix = output, suffix = graph_suffix, use_gpu = gpu_graph)
with open(f"{output}/{os.path.basename(output)}_query.subset",'w') as pruned_isolate_csv:
for isolate in pruned_isolate_lists:
pruned_isolate_csv.write(isolate + '\n')

return(isolateClustering)

Expand Down
110 changes: 89 additions & 21 deletions PopPUNK/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1920,32 +1920,100 @@ def prune_graph(prefix, reflist, samples_to_keep, output_db_name, threads, use_g
if os.path.exists(network_fn):
network_found = True
sys.stderr.write("Loading network from " + network_fn + "\n")
samples_to_keep_set = frozenset(samples_to_keep)
G = load_network_file(network_fn, use_gpu = use_gpu)
if use_gpu:
# Identify indices
reference_indices = [i for (i,name) in enumerate(reflist) if name in samples_to_keep_set]
# Generate data frame
G_df = G.view_edge_list()
if 'src' in G_df.columns:
G_df.rename(columns={'src': 'source','dst': 'destination'}, inplace=True)
# Filter data frame
G_new_df = G_df[G_df['source'].isin(reference_indices) & G_df['destination'].isin(reference_indices)]
# Translate network indices to match name order
G_new = translate_network_indices(G_new_df, reference_indices)
else:
reference_vertex = G.new_vertex_property('bool')
for n, vertex in enumerate(G.vertices()):
if reflist[n] in samples_to_keep_set:
reference_vertex[vertex] = True
else:
reference_vertex[vertex] = False
G_new = gt.GraphView(G, vfilt = reference_vertex)
G_new = gt.Graph(G_new, prune = True)
G_new = remove_nodes_from_graph(G, reflist, samples_to_keep, use_gpu)
save_network(G_new,
prefix = output_db_name,
suffix = '_graph',
use_graphml = False,
use_gpu = use_gpu)
if not network_found:
sys.stderr.write('No network file found for pruning\n')

def remove_nodes_from_graph(G,reflist, samples_to_keep, use_gpu):
    """Return a pruned copy of the graph restricted to the requested samples.

    Args:
        G (graph)
            Network to prune (cugraph if use_gpu, otherwise graph-tool graph)
        reflist (list)
            Ordered list of sequences of database
        samples_to_keep (list)
            The names of samples to be retained in the graph
        use_gpu (bool)
            Whether graph is a cugraph or not
            [default = False]
    Returns:
        G_new (graph)
            Pruned graph
    """
    keep = frozenset(samples_to_keep)
    if use_gpu:
        # Vertex indices of the samples that survive pruning
        kept_idx = [idx for idx, label in enumerate(reflist) if label in keep]
        # Normalise edge-list column names across cugraph versions
        edge_df = G.view_edge_list()
        if 'src' in edge_df.columns:
            edge_df.rename(columns={'src': 'source','dst': 'destination'}, inplace=True)
        # Keep only edges whose endpoints are both retained, then renumber
        # the vertices so indices match the retained-name order
        kept_edges = edge_df[edge_df['source'].isin(kept_idx) & edge_df['destination'].isin(kept_idx)]
        G_new = translate_network_indices(kept_edges, kept_idx)
    else:
        # Boolean vertex mask: True where the sample should be retained
        keep_mask = G.new_vertex_property('bool')
        for idx, vertex in enumerate(G.vertices()):
            keep_mask[vertex] = reflist[idx] in keep
        # Materialise the filtered view as an independent pruned graph
        G_new = gt.Graph(gt.GraphView(G, vfilt = keep_mask), prune = True)
    return G_new

def retain_only_query_clusters(G, rlist, qlist, use_gpu = False):
    """
    Removes all components that do not contain a query sequence.
    Args:
        G (graph)
            Network of queries linked to reference sequences;
            vertices 0..len(rlist)-1 are references, vertices from
            len(rlist) onwards are queries
        rlist (list)
            List of reference sequence labels
        qlist (list)
            List of query sequence labels
        use_gpu (bool)
            Whether to use GPUs for network construction
    Returns:
        G (graph)
            The resulting network
        pruned_names (list)
            The labels of the sequences in the pruned network
    """
    num_refs = len(rlist)
    combined_names = rlist + qlist
    pruned_names = []
    if use_gpu:
        sys.stderr.write('Not compatible with GPU networks yet\n')
        query_subgraph = G
        # Graph is returned unpruned, so every sequence is still present;
        # previously this returned an empty list, which made the caller
        # write an empty .subset file
        pruned_names = combined_names
    else:
        # Per-vertex component labels from graph-tool
        components = gt.label_components(G)[0].a
        # A component contains a query iff any vertex with index >= num_refs
        # carries its label (queries are appended after references). Reading
        # the label array directly replaces the previous per-vertex
        # GraphView + max() scan (O(V^2)) with a single O(V) pass, and a set
        # gives O(1) membership tests below.
        components_with_query = {int(components[v]) for v in range(num_refs, len(components))}

        # Create a boolean filter selecting every vertex whose component
        # contains at least one query
        query_filter = G.new_vertex_property("bool")
        for v in G.vertices():
            query_filter[int(v)] = (int(components[int(v)]) in components_with_query)
            if query_filter[int(v)]:
                pruned_names.append(combined_names[int(v)])

        # Create a filtered graph with only the specified components
        query_subgraph = gt.GraphView(G, vfilt=query_filter)

        # Purge the filtered graph to remove the other components permanently
        query_subgraph.purge_vertices()

    return query_subgraph, pruned_names
25 changes: 16 additions & 9 deletions PopPUNK/visualise.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def get_options():
iGroup.add_argument('--display-cluster',
help='Column of clustering CSV to use for plotting',
default=None)
iGroup.add_argument('--use-partial-query-graph',
help='File listing sequences in partial query graph after assignment',
default=None)

# output options
oGroup = parser.add_argument_group('Output options')
Expand Down Expand Up @@ -190,6 +193,7 @@ def generate_visualisations(query_db,
mst_distances,
overwrite,
display_cluster,
use_partial_query_graph,
tmp):

from .models import loadClusterFit
Expand All @@ -200,6 +204,7 @@ def generate_visualisations(query_db,
from .network import cugraph_to_graph_tool
from .network import save_network
from .network import sparse_mat_to_network
from .network import remove_nodes_from_graph

from .plot import drawMST
from .plot import outputsForMicroreact
Expand Down Expand Up @@ -353,9 +358,10 @@ def generate_visualisations(query_db,

# extract subset of distances if requested
all_seq = combined_seq
if include_files is not None:
if include_files is not None or use_partial_query_graph is not None:
viz_subset = set()
with open(include_files, 'r') as assemblyFiles:
subset_file = include_files if include_files is not None else use_partial_query_graph
with open(subset_file, 'r') as assemblyFiles:
for assembly in assemblyFiles:
viz_subset.add(assembly.rstrip())
if len(viz_subset.difference(combined_seq)) > 0:
Expand Down Expand Up @@ -605,20 +611,20 @@ def generate_visualisations(query_db,
if gpu_graph:
genomeNetwork = cugraph_to_graph_tool(genomeNetwork, isolateNameToLabel(all_seq))
# Hard delete from network to remove samples (mask doesn't work neatly)
if viz_subset is not None:
remove_list = []
for keep, idx in enumerate(row_slice):
if not keep:
remove_list.append(idx)
genomeNetwork.remove_vertex(remove_list)
if include_files is not None:
genomeNetwork = remove_nodes_from_graph(genomeNetwork, all_seq, viz_subset, use_gpu = gpu_graph)
elif rank_fit is not None:
genomeNetwork = sparse_mat_to_network(sparse_mat, combined_seq, use_gpu = gpu_graph)
else:
sys.stderr.write('Cytoscape output requires a network file or lineage rank fit to be provided\n')
sys.exit(1)
# If network has been pruned then only use the appropriate subset of names - otherwise use all names
# for full network
node_labels = viz_subset if (use_partial_query_graph is not None or include_files is not None) \
else combined_seq
outputsForCytoscape(genomeNetwork,
mst_graph,
combined_seq,
node_labels,
isolateClustering,
output,
info_csv)
Expand Down Expand Up @@ -663,6 +669,7 @@ def main():
args.mst_distances,
args.overwrite,
args.display_cluster,
args.use_partial_query_graph,
args.tmp)

if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion test/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model dbscan --ref-db batch12 --output batch12 --overwrite", shell=True, check=True)
subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model refine --ref-db batch12 --output batch12 --overwrite", shell=True, check=True)
subprocess.run(python_cmd + " ../poppunk_assign-runner.py --db batch12 --query rfile3.txt --output batch3 --external-clustering batch12_external_clusters.csv --save-partial-query-graph --overwrite", shell=True, check=True)
subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db batch12 --query-db batch3 --output batch123_viz --external-clustering batch12_external_clusters.csv --previous-query-clustering batch3/batch3_external_clusters.csv --cytoscape --rapidnj rapidnj --network-file ./batch12/batch12_graph.gt --overwrite", shell=True, check=True)
subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db batch12 --query-db batch3 --output batch123_viz --external-clustering batch12_external_clusters.csv --previous-query-clustering batch3/batch3_external_clusters.csv --cytoscape --rapidnj rapidnj --network-file ./batch3/batch3_graph.gt --use-partial-query-graph ./batch3/batch3_query.subset --overwrite", shell=True, check=True)

# citations
sys.stderr.write("Printing citations\n")
Expand Down

0 comments on commit decf58c

Please sign in to comment.