Skip to content

Commit

Permalink
Restructure to enable selection by cluster
Browse files: browse the repository at this point in the history
  • Loading branch information
nickjcroucher committed Oct 14, 2024
1 parent 33c2500 commit a1d088d
Showing 1 changed file with 89 additions and 81 deletions.
170 changes: 89 additions & 81 deletions PopPUNK/visualise.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,16 +314,17 @@ def generate_visualisations(query_db,

#******************************#
#* *#
#* Process dense or sparse *#
#* distances *#
#* Determine type of distance *#
#* to use *#
#* *#
#******************************#

# Determine whether to use dense or sparse distances
use_sparse = False
use_dense = False

if (tree == "mst" or tree == "both") and rank_fit is not None:
if (tree == "nj" or tree == "both") or rank_fit == None:
use_dense = True
elif (tree == "mst" or tree == "both") and rank_fit is not None:
# Set flag
use_sparse = True
# Read list of sequence names and sparse distance matrix
Expand All @@ -336,9 +337,91 @@ def generate_visualisations(query_db,
elif previous_mst is not None:
sys.stderr.write('The prefix of the distance files used to create the previous MST'
' is needed to use the network')


#**********************************#
#* *#
#* Process clustering information *#
#* *#
#**********************************#

# identify existing model and cluster files
if model_dir is not None:
model_prefix = model_dir
else:
model_prefix = ref_db
try:
model_file = os.path.join(model_prefix, os.path.basename(model_prefix))
model = loadClusterFit(model_file + '_fit.pkl',
model_file + '_fit.npz')
model.set_threads(threads)
except FileNotFoundError:
sys.stderr.write('Unable to locate previous model fit in ' + model_prefix + '\n')
sys.exit(1)

# Either use strain definitions, lineage assignments or external clustering
isolateClustering = {}
# Use external clustering if specified
if external_clustering:
mode = 'external'
cluster_file = external_clustering
if cluster_file.endswith('_lineages.csv'):
suffix = "_lineages.csv"
else:
suffix = "_clusters.csv"
else:
# Load previous clusters
if previous_clustering is not None:
cluster_file = previous_clustering
mode = "clusters"
suffix = "_clusters.csv"
if cluster_file.endswith('_lineages.csv'):
mode = "lineages"
suffix = "_lineages.csv"
else:
# Identify type of clustering based on model
mode = "clusters"
suffix = "_clusters.csv"
if model.type == "lineage":
mode = "lineages"
suffix = "_lineages.csv"
cluster_file = os.path.join(model_prefix, os.path.basename(model_prefix) + suffix)

isolateClustering = readIsolateTypeFromCsv(cluster_file,
mode = mode,
return_dict = True)

# Add individual refinement clusters if they exist
if model.indiv_fitted:
for type, indiv_suffix in zip(['Core','Accessory'],['_core_clusters.csv','_accessory_clusters.csv']):
indiv_clustering = os.path.join(model_prefix, os.path.basename(model_prefix) + indiv_suffix)
if os.path.isfile(indiv_clustering):
indiv_isolateClustering = readIsolateTypeFromCsv(indiv_clustering,
mode = mode,
return_dict = True)
isolateClustering[type] = indiv_isolateClustering['Cluster']

# Join clusters with query clusters if required
if use_dense:
if query_db is not None:
if previous_query_clustering is not None:
prev_query_clustering = previous_query_clustering
else:
prev_query_clustering = os.path.join(query_db, os.path.basename(query_db) + suffix)

queryIsolateClustering = readIsolateTypeFromCsv(
prev_query_clustering,
mode = mode,
return_dict = True)
isolateClustering = joinClusterDicts(isolateClustering, queryIsolateClustering)

#******************************#
#* *#
#* Process dense or sparse *#
#* distances *#
#* *#
#******************************#

if (tree == "nj" or tree == "both") or rank_fit == None:
use_dense = True

# Either calculate or read distances
if recalculate_distances:
Expand Down Expand Up @@ -456,81 +539,6 @@ def generate_visualisations(query_db,
core_distMat = core_distMat[np.ix_(row_slice, row_slice)]
acc_distMat = acc_distMat[np.ix_(row_slice, row_slice)]

#**********************************#
#* *#
#* Process clustering information *#
#* *#
#**********************************#

# identify existing model and cluster files
if model_dir is not None:
model_prefix = model_dir
else:
model_prefix = ref_db
try:
model_file = os.path.join(model_prefix, os.path.basename(model_prefix))
model = loadClusterFit(model_file + '_fit.pkl',
model_file + '_fit.npz')
model.set_threads(threads)
except FileNotFoundError:
sys.stderr.write('Unable to locate previous model fit in ' + model_prefix + '\n')
sys.exit(1)

# Either use strain definitions, lineage assignments or external clustering
isolateClustering = {}
# Use external clustering if specified
if external_clustering:
mode = 'external'
cluster_file = external_clustering
if cluster_file.endswith('_lineages.csv'):
suffix = "_lineages.csv"
else:
suffix = "_clusters.csv"
else:
# Load previous clusters
if previous_clustering is not None:
cluster_file = previous_clustering
mode = "clusters"
suffix = "_clusters.csv"
if cluster_file.endswith('_lineages.csv'):
mode = "lineages"
suffix = "_lineages.csv"
else:
# Identify type of clustering based on model
mode = "clusters"
suffix = "_clusters.csv"
if model.type == "lineage":
mode = "lineages"
suffix = "_lineages.csv"
cluster_file = os.path.join(model_prefix, os.path.basename(model_prefix) + suffix)

isolateClustering = readIsolateTypeFromCsv(cluster_file,
mode = mode,
return_dict = True)

# Add individual refinement clusters if they exist
if model.indiv_fitted:
for type, indiv_suffix in zip(['Core','Accessory'],['_core_clusters.csv','_accessory_clusters.csv']):
indiv_clustering = os.path.join(model_prefix, os.path.basename(model_prefix) + indiv_suffix)
if os.path.isfile(indiv_clustering):
indiv_isolateClustering = readIsolateTypeFromCsv(indiv_clustering,
mode = mode,
return_dict = True)
isolateClustering[type] = indiv_isolateClustering['Cluster']

# Join clusters with query clusters if required
if use_dense:
if query_db is not None:
if previous_query_clustering is not None:
prev_query_clustering = previous_query_clustering
else:
prev_query_clustering = os.path.join(query_db, os.path.basename(query_db) + suffix)

queryIsolateClustering = readIsolateTypeFromCsv(
prev_query_clustering,
mode = mode,
return_dict = True)
isolateClustering = joinClusterDicts(isolateClustering, queryIsolateClustering)

#*******************#
#* *#
Expand Down

0 comments on commit a1d088d

Please sign in to comment.