From 52dde860afd694c9a9f474fdd4519eef1ceee700 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Fri, 11 Oct 2024 09:35:12 +0100 Subject: [PATCH] Enable recalculation of distances for visualisation --- PopPUNK/visualise.py | 214 ++++++++++++++++++++++++++----------------- 1 file changed, 128 insertions(+), 86 deletions(-) diff --git a/PopPUNK/visualise.py b/PopPUNK/visualise.py index 1427cba4..8df9fca6 100644 --- a/PopPUNK/visualise.py +++ b/PopPUNK/visualise.py @@ -84,6 +84,10 @@ def get_options(): 'minimum spanning tree', default=None, type = str) + iGroup.add_argument('--recalculate-distances', + help='Recalculate pairwise distances rather than read them from a file', + default=False, + action = 'store_true') iGroup.add_argument('--network-file', help='Specify a file to use for any graph visualisations', type = str) @@ -194,6 +198,7 @@ def generate_visualisations(query_db, overwrite, display_cluster, use_partial_query_graph, + recalculate_distances, tmp): from .models import loadClusterFit @@ -256,13 +261,13 @@ def generate_visualisations(query_db, sys.stderr.write("Cannot create output directory\n") sys.exit(1) - #******************************# - #* *# - #* Process dense or sparse *# - #* distances *# - #* *# - #******************************# + #*******************************# + #* *# + #* Extract subset of sequences *# + #* *# + #*******************************# + # Identify distance matrix for ordered names if distances is None: if query_db is None: distances = ref_db + "/" + os.path.basename(ref_db) + ".dists" @@ -271,17 +276,42 @@ def generate_visualisations(query_db, else: distances = distances + # Location and properties of reference database + ref_db_loc = ref_db + "/" + os.path.basename(ref_db) + kmers, sketch_sizes, codon_phased = readDBParams(ref_db) + + # extract subset of distances if requested + combined_seq = read_rlist_from_distance_pickle(distances + '.pkl') + all_seq = combined_seq # all_seq is an immutable record use for network parsing + if include_files is not None or use_partial_query_graph is not None: + viz_subset = set() + subset_file = include_files if include_files is not None else use_partial_query_graph + with open(subset_file, 'r') as assemblyFiles: + for assembly in assemblyFiles: + viz_subset.add(assembly.rstrip()) + if len(viz_subset.difference(combined_seq)) > 0: + sys.stderr.write("--include-files contains names not in --distances\n") + sys.stderr.write("Please assign distances before subsetting the database\n") + else: + viz_subset = None + + #******************************# + #* *# + #* Process dense or sparse *# + #* distances *# + #* *# + #******************************# + # Determine whether to use sparse distances - combined_seq = None use_sparse = False use_dense = False + if (tree == "mst" or tree == "both") and rank_fit is not None: # Set flag use_sparse = True # Read list of sequence names and sparse distance matrix - rlist = read_rlist_from_distance_pickle(distances + '.pkl') + rlist = combined_seq sparse_mat = sparse.load_npz(rank_fit) - combined_seq = rlist # Check previous distances have been supplied if building on a previous MST old_rlist = None if previous_distances is not None: @@ -289,48 +319,63 @@ def generate_visualisations(query_db, elif previous_mst is not None: sys.stderr.write('The prefix of the distance files used to create the previous MST' ' is needed to use the network') + if (tree == "nj" or tree == "both") or rank_fit == None: use_dense = True - # Process dense distance matrix - rlist, qlist, self, complete_distMat = readPickle(distances) - if not self: - qr_distMat = complete_distMat - combined_seq = rlist + qlist + + # Either calculate or read distances + if recalculate_distances: + sys.stderr.write("Recalculating pairwise distances for tree construction\n") + + # Generate distances + self = True + sequences_to_analyse = viz_subset if viz_subset is not None else combined_seq + qlist = sequences_to_analyse + subset_distMat = pp_sketchlib.queryDatabase(ref_db_name=ref_db_loc, + query_db_name=ref_db_loc, + rList=sequences_to_analyse, + qList=sequences_to_analyse, + klist=kmers.tolist(), + random_correct=True, + jaccard=False, + num_threads=threads, + use_gpu = gpu_dist, + device_id = deviceid) + + # Convert distance matrix format + combined_seq, core_distMat, acc_distMat = \ + update_distance_matrices(sequences_to_analyse, + subset_distMat, + threads = threads) + else: - rr_distMat = complete_distMat - combined_seq = rlist - - # Fill in qq-distances if required - if self == False: - sys.stderr.write("Note: Distances in " + distances + " are from assign mode\n" - "Note: Distance will be extended to full all-vs-all distances\n" - "Note: Re-run poppunk_assign with --update-db to avoid this\n") - ref_db_loc = ref_db + "/" + os.path.basename(ref_db) - rlist_original, qlist_original, self_ref, rr_distMat = readPickle(ref_db_loc + ".dists") - if not self_ref: - sys.stderr.write("Distances in " + ref_db + " not self all-vs-all either\n") - sys.exit(1) - kmers, sketch_sizes, codon_phased = readDBParams(query_db) - addRandom(query_db, qlist, kmers, - strand_preserved = strand_preserved, threads = threads) - query_db_loc = query_db + "/" + os.path.basename(query_db) - qq_distMat = pp_sketchlib.queryDatabase(ref_db_name=query_db_loc, - query_db_name=query_db_loc, - rList=qlist, - qList=qlist, - klist=kmers, - random_correct=True, - jaccard=False, - num_threads=threads, - use_gpu=gpu_dist, - device_id=deviceid) - - # If the assignment was run with references, qrDistMat will be incomplete - if rlist != rlist_original: - rlist = rlist_original - qr_distMat = pp_sketchlib.queryDatabase(ref_db_name=ref_db_loc, + sys.stderr.write("Reading pairwise distances for tree construction\n") + + # Process dense distance matrix + rlist, qlist, self, complete_distMat = readPickle(distances) + if not self: + qr_distMat = complete_distMat + combined_seq = rlist + qlist + else: + rr_distMat = complete_distMat + combined_seq = rlist + + # Fill in qq-distances if required + if self == False: + sys.stderr.write("Note: Distances in " + distances + " are from assign mode\n" + "Note: Distance will be extended to full all-vs-all distances\n" + "Note: Re-run poppunk_assign with --update-db to avoid this\n") + rlist_original, qlist_original, self_ref, rr_distMat = readPickle(ref_db_loc + ".dists") + if not self_ref: + sys.stderr.write("Distances in " + ref_db + " not self all-vs-all either\n") + sys.exit(1) + kmers, sketch_sizes, codon_phased = readDBParams(query_db) + addRandom(query_db, qlist, kmers, + strand_preserved = strand_preserved, threads = threads) + query_db_loc = query_db + "/" + os.path.basename(query_db) + qq_distMat = pp_sketchlib.queryDatabase(ref_db_name=query_db_loc, query_db_name=query_db_loc, - rList=rlist, + rList=qlist, qList=qlist, klist=kmers, random_correct=True, @@ -339,46 +384,42 @@ def generate_visualisations(query_db, use_gpu=gpu_dist, device_id=deviceid) - else: - qlist = None - qr_distMat = None - qq_distMat = None - - # Turn long form matrices into square form - combined_seq, core_distMat, acc_distMat = \ - update_distance_matrices(rlist, rr_distMat, - qlist, qr_distMat, qq_distMat, - threads = threads) - - #*******************************# - #* *# - #* Extract subset of sequences *# - #* *# - #*******************************# - - # extract subset of distances if requested - all_seq = combined_seq - if include_files is not None or use_partial_query_graph is not None: - viz_subset = set() - subset_file = include_files if include_files is not None else use_partial_query_graph - with open(subset_file, 'r') as assemblyFiles: - for assembly in assemblyFiles: - viz_subset.add(assembly.rstrip()) - if len(viz_subset.difference(combined_seq)) > 0: - sys.stderr.write("--include-files contains names not in --distances\n") + # If the assignment was run with references, qrDistMat will be incomplete + if rlist != rlist_original: + rlist = rlist_original + qr_distMat = pp_sketchlib.queryDatabase(ref_db_name=ref_db_loc, + query_db_name=query_db_loc, + rList=rlist, + qList=qlist, + klist=kmers, + random_correct=True, + jaccard=False, + num_threads=threads, + use_gpu=gpu_dist, + device_id=deviceid) - # Only keep found rows - row_slice = [True if name in viz_subset else False for name in combined_seq] - combined_seq = [name for name in combined_seq if name in viz_subset] - if use_sparse: - sparse_mat = sparse_mat[np.ix_(row_slice, row_slice)] - if use_dense: - if qlist != None: - qlist = list(viz_subset.intersection(qlist)) - core_distMat = core_distMat[np.ix_(row_slice, row_slice)] - acc_distMat = acc_distMat[np.ix_(row_slice, row_slice)] - else: - viz_subset = None + else: + qlist = None + qr_distMat = None + qq_distMat = None + + # Turn long form matrices into square form + combined_seq, core_distMat, acc_distMat = \ + update_distance_matrices(rlist, rr_distMat, + qlist, qr_distMat, qq_distMat, + threads = threads) + + # Prune distance matrix if subsetting data + if viz_subset is not None: + row_slice = [True if name in viz_subset else False for name in combined_seq] + combined_seq = [name for name in combined_seq if name in viz_subset] + if use_sparse: + sparse_mat = sparse_mat[np.ix_(row_slice, row_slice)] + if use_dense: + if qlist != None: + qlist = list(viz_subset.intersection(qlist)) + core_distMat = core_distMat[np.ix_(row_slice, row_slice)] + acc_distMat = acc_distMat[np.ix_(row_slice, row_slice)] #**********************************# #* *# @@ -670,6 +711,7 @@ def main(): args.overwrite, args.display_cluster, args.use_partial_query_graph, + args.recalculate_distances, args.tmp) if __name__ == '__main__':