Skip to content

Commit

Permalink
Enable recalculation of distances for visualisation
Browse files Browse the repository at this point in the history
  • Loading branch information
nickjcroucher committed Oct 11, 2024
1 parent df4dfec commit 52dde86
Showing 1 changed file with 128 additions and 86 deletions.
214 changes: 128 additions & 86 deletions PopPUNK/visualise.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def get_options():
'minimum spanning tree',
default=None,
type = str)
iGroup.add_argument('--recalculate-distances',
help='Recalculate pairwise distances rather than read them from a file',
default=False,
action = 'store_true')
iGroup.add_argument('--network-file',
help='Specify a file to use for any graph visualisations',
type = str)
Expand Down Expand Up @@ -194,6 +198,7 @@ def generate_visualisations(query_db,
overwrite,
display_cluster,
use_partial_query_graph,
recalculate_distances,
tmp):

from .models import loadClusterFit
Expand Down Expand Up @@ -256,13 +261,13 @@ def generate_visualisations(query_db,
sys.stderr.write("Cannot create output directory\n")
sys.exit(1)

#******************************#
#* *#
#* Process dense or sparse *#
#* distances *#
#* *#
#******************************#
#*******************************#
#* *#
#* Extract subset of sequences *#
#* *#
#*******************************#

# Identify distance matrix for ordered names
if distances is None:
if query_db is None:
distances = ref_db + "/" + os.path.basename(ref_db) + ".dists"
Expand All @@ -271,66 +276,106 @@ def generate_visualisations(query_db,
else:
distances = distances

# Location and properties of reference database
ref_db_loc = ref_db + "/" + os.path.basename(ref_db)
kmers, sketch_sizes, codon_phased = readDBParams(ref_db)

# extract subset of distances if requested
combined_seq = read_rlist_from_distance_pickle(distances + '.pkl')
all_seq = combined_seq # all_seq is an immutable record use for network parsing
if include_files is not None or use_partial_query_graph is not None:
viz_subset = set()
subset_file = include_files if include_files is not None else use_partial_query_graph
with open(subset_file, 'r') as assemblyFiles:
for assembly in assemblyFiles:
viz_subset.add(assembly.rstrip())
if len(viz_subset.difference(combined_seq)) > 0:
sys.stderr.write("--include-files contains names not in --distances\n")
sys.stderr.write("Please assign distances before subsetting the database\n")
else:
viz_subset = None

#******************************#
#* *#
#* Process dense or sparse *#
#* distances *#
#* *#
#******************************#

# Determine whether to use sparse distances
combined_seq = None
use_sparse = False
use_dense = False

if (tree == "mst" or tree == "both") and rank_fit is not None:
# Set flag
use_sparse = True
# Read list of sequence names and sparse distance matrix
rlist = read_rlist_from_distance_pickle(distances + '.pkl')
rlist = combined_seq
sparse_mat = sparse.load_npz(rank_fit)
combined_seq = rlist
# Check previous distances have been supplied if building on a previous MST
old_rlist = None
if previous_distances is not None:
old_rlist = read_rlist_from_distance_pickle(previous_distances + '.pkl')
elif previous_mst is not None:
sys.stderr.write('The prefix of the distance files used to create the previous MST'
' is needed to use the network')

if (tree == "nj" or tree == "both") or rank_fit == None:
use_dense = True
# Process dense distance matrix
rlist, qlist, self, complete_distMat = readPickle(distances)
if not self:
qr_distMat = complete_distMat
combined_seq = rlist + qlist

# Either calculate or read distances
if recalculate_distances:
sys.stderr.write("Recalculating pairwise distances for tree construction\n")

# Generate distances
self = True
sequences_to_analyse = viz_subset if viz_subset is not None else combined_seq
qlist = sequences_to_analyse
subset_distMat = pp_sketchlib.queryDatabase(ref_db_name=ref_db_loc,
query_db_name=ref_db_loc,
rList=sequences_to_analyse,
qList=sequences_to_analyse,
klist=kmers.tolist(),
random_correct=True,
jaccard=False,
num_threads=threads,
use_gpu = gpu_dist,
device_id = deviceid)

# Convert distance matrix format
combined_seq, core_distMat, acc_distMat = \
update_distance_matrices(sequences_to_analyse,
subset_distMat,
threads = threads)

else:
rr_distMat = complete_distMat
combined_seq = rlist

# Fill in qq-distances if required
if self == False:
sys.stderr.write("Note: Distances in " + distances + " are from assign mode\n"
"Note: Distance will be extended to full all-vs-all distances\n"
"Note: Re-run poppunk_assign with --update-db to avoid this\n")
ref_db_loc = ref_db + "/" + os.path.basename(ref_db)
rlist_original, qlist_original, self_ref, rr_distMat = readPickle(ref_db_loc + ".dists")
if not self_ref:
sys.stderr.write("Distances in " + ref_db + " not self all-vs-all either\n")
sys.exit(1)
kmers, sketch_sizes, codon_phased = readDBParams(query_db)
addRandom(query_db, qlist, kmers,
strand_preserved = strand_preserved, threads = threads)
query_db_loc = query_db + "/" + os.path.basename(query_db)
qq_distMat = pp_sketchlib.queryDatabase(ref_db_name=query_db_loc,
query_db_name=query_db_loc,
rList=qlist,
qList=qlist,
klist=kmers,
random_correct=True,
jaccard=False,
num_threads=threads,
use_gpu=gpu_dist,
device_id=deviceid)

# If the assignment was run with references, qrDistMat will be incomplete
if rlist != rlist_original:
rlist = rlist_original
qr_distMat = pp_sketchlib.queryDatabase(ref_db_name=ref_db_loc,
sys.stderr.write("Reading pairwise distances for tree construction\n")

# Process dense distance matrix
rlist, qlist, self, complete_distMat = readPickle(distances)
if not self:
qr_distMat = complete_distMat
combined_seq = rlist + qlist
else:
rr_distMat = complete_distMat
combined_seq = rlist

# Fill in qq-distances if required
if self == False:
sys.stderr.write("Note: Distances in " + distances + " are from assign mode\n"
"Note: Distance will be extended to full all-vs-all distances\n"
"Note: Re-run poppunk_assign with --update-db to avoid this\n")
rlist_original, qlist_original, self_ref, rr_distMat = readPickle(ref_db_loc + ".dists")
if not self_ref:
sys.stderr.write("Distances in " + ref_db + " not self all-vs-all either\n")
sys.exit(1)
kmers, sketch_sizes, codon_phased = readDBParams(query_db)
addRandom(query_db, qlist, kmers,
strand_preserved = strand_preserved, threads = threads)
query_db_loc = query_db + "/" + os.path.basename(query_db)
qq_distMat = pp_sketchlib.queryDatabase(ref_db_name=query_db_loc,
query_db_name=query_db_loc,
rList=rlist,
rList=qlist,
qList=qlist,
klist=kmers,
random_correct=True,
Expand All @@ -339,46 +384,42 @@ def generate_visualisations(query_db,
use_gpu=gpu_dist,
device_id=deviceid)

else:
qlist = None
qr_distMat = None
qq_distMat = None

# Turn long form matrices into square form
combined_seq, core_distMat, acc_distMat = \
update_distance_matrices(rlist, rr_distMat,
qlist, qr_distMat, qq_distMat,
threads = threads)

#*******************************#
#* *#
#* Extract subset of sequences *#
#* *#
#*******************************#

# extract subset of distances if requested
all_seq = combined_seq
if include_files is not None or use_partial_query_graph is not None:
viz_subset = set()
subset_file = include_files if include_files is not None else use_partial_query_graph
with open(subset_file, 'r') as assemblyFiles:
for assembly in assemblyFiles:
viz_subset.add(assembly.rstrip())
if len(viz_subset.difference(combined_seq)) > 0:
sys.stderr.write("--include-files contains names not in --distances\n")
# If the assignment was run with references, qrDistMat will be incomplete
if rlist != rlist_original:
rlist = rlist_original
qr_distMat = pp_sketchlib.queryDatabase(ref_db_name=ref_db_loc,
query_db_name=query_db_loc,
rList=rlist,
qList=qlist,
klist=kmers,
random_correct=True,
jaccard=False,
num_threads=threads,
use_gpu=gpu_dist,
device_id=deviceid)

# Only keep found rows
row_slice = [True if name in viz_subset else False for name in combined_seq]
combined_seq = [name for name in combined_seq if name in viz_subset]
if use_sparse:
sparse_mat = sparse_mat[np.ix_(row_slice, row_slice)]
if use_dense:
if qlist != None:
qlist = list(viz_subset.intersection(qlist))
core_distMat = core_distMat[np.ix_(row_slice, row_slice)]
acc_distMat = acc_distMat[np.ix_(row_slice, row_slice)]
else:
viz_subset = None
else:
qlist = None
qr_distMat = None
qq_distMat = None

# Turn long form matrices into square form
combined_seq, core_distMat, acc_distMat = \
update_distance_matrices(rlist, rr_distMat,
qlist, qr_distMat, qq_distMat,
threads = threads)

# Prune distance matrix if subsetting data
if viz_subset is not None:
row_slice = [True if name in viz_subset else False for name in combined_seq]
combined_seq = [name for name in combined_seq if name in viz_subset]
if use_sparse:
sparse_mat = sparse_mat[np.ix_(row_slice, row_slice)]
if use_dense:
if qlist != None:
qlist = list(viz_subset.intersection(qlist))
core_distMat = core_distMat[np.ix_(row_slice, row_slice)]
acc_distMat = acc_distMat[np.ix_(row_slice, row_slice)]

#**********************************#
#* *#
Expand Down Expand Up @@ -670,6 +711,7 @@ def main():
args.overwrite,
args.display_cluster,
args.use_partial_query_graph,
args.recalculate_distances,
args.tmp)

if __name__ == '__main__':
Expand Down

0 comments on commit 52dde86

Please sign in to comment.