Skip to content

Commit

Permalink
Resolve conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
nickjcroucher committed Feb 29, 2024
2 parents 2920b00 + 5652532 commit 82649c5
Show file tree
Hide file tree
Showing 14 changed files with 393 additions and 171 deletions.
24 changes: 19 additions & 5 deletions PopPUNK/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def get_options():
type=float)

# model refinement
refinementGroup = parser.add_argument_group('Refine model options')
refinementGroup = parser.add_argument_group('Network analysis and model refinement options')
refinementGroup.add_argument('--pos-shift', help='Maximum amount to move the boundary right past between-strain mean',
type=float, default = 0)
refinementGroup.add_argument('--neg-shift', help='Maximum amount to move the boundary left past within-strain mean]',
Expand All @@ -156,6 +156,9 @@ def get_options():
refinementGroup.add_argument('--score-idx',
help='Index of score to use [default = 0]',
type=int, default = 0, choices=[0, 1, 2])
refinementGroup.add_argument('--summary-sample',
help='Number of sequences used to estimate graph properties [default = all]',
type=int, default = None)
refinementGroup.add_argument('--betweenness-sample',
help='Number of sequences used to estimate betweeness with a GPU [default = 100]',
type = int, default = betweenness_sample_default)
Expand Down Expand Up @@ -264,6 +267,7 @@ def main():

from .plot import writeClusterCsv
from .plot import plot_scatter
from .plot import plot_database_evaluations

from .qc import prune_distance_matrix, qcDistMat, sketchlibAssemblyQC, remove_qc_fail

Expand Down Expand Up @@ -387,8 +391,9 @@ def main():
# Plot results
if not args.no_plot:
plot_scatter(distMat,
f"{args.output}/{os.path.basename(args.output)}_distanceDistribution",
args.output,
args.output + " distances")
plot_database_evaluations(args.output)

#******************************#
#* *#
Expand Down Expand Up @@ -427,15 +432,13 @@ def main():
fail_unconditionally[line.rstrip] = ["removed"]

# assembly qc
sys.stderr.write("Running sequence QC\n")
pass_assembly_qc, fail_assembly_qc = \
sketchlibAssemblyQC(args.ref_db,
refList,
qc_dict)
sys.stderr.write(f"{len(fail_assembly_qc)} samples failed\n")

# QC pairwise distances to identify long distances indicative of anomalous sequences in the collection
sys.stderr.write("Running distance QC\n")
pass_dist_qc, fail_dist_qc = \
qcDistMat(distMat,
refList,
Expand All @@ -452,13 +455,20 @@ def main():
raise RuntimeError('Type isolate ' + qc_dict['type_isolate'] + \
' not found in isolates after QC; check '
'name of type isolate and QC options\n')


sys.stderr.write(f"{len(passed)} samples passed QC\n")
if len(passed) < len(refList):
remove_qc_fail(qc_dict, refList, passed,
[fail_unconditionally, fail_assembly_qc, fail_dist_qc],
args.ref_db, distMat, output,
args.strand_preserved, args.threads)

# Plot results
if not args.no_plot:
plot_scatter(distMat,
output,
output + " distances")
plot_database_evaluations(output)

#******************************#
#* *#
Expand Down Expand Up @@ -545,6 +555,7 @@ def main():
args.score_idx,
args.no_local,
args.betweenness_sample,
args.summary_sample,
args.gpu_graph)
model = new_model
elif args.fit_model == "threshold":
Expand Down Expand Up @@ -613,6 +624,7 @@ def main():
model.within_label,
distMat = distMat,
weights_type = weights_type,
sample_size = args.summary_sample,
betweenness_sample = args.betweenness_sample,
use_gpu = args.gpu_graph)
else:
Expand All @@ -628,6 +640,7 @@ def main():
refList,
assignments[rank],
weights = weights,
sample_size = args.summary_sample,
betweenness_sample = args.betweenness_sample,
use_gpu = args.gpu_graph,
summarise = False
Expand Down Expand Up @@ -685,6 +698,7 @@ def main():
queryList,
indivAssignments,
model.within_label,
sample_size = args.summary_sample,
betweenness_sample = args.betweenness_sample,
use_gpu = args.gpu_graph)
isolateClustering[dist_type] = \
Expand Down
29 changes: 27 additions & 2 deletions PopPUNK/bgmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,36 @@ def fit2dMultiGaussian(X, dpgmm_max_K = 2):
return dpgmm


def findBetweenLabel_bgmm(means, assignments, rank = 0):
    """Identify between-strain links

    Finds the mixture component with the largest number of points
    assigned to it.

    Args:
        means (numpy.array)
            K x 2 array of mixture component means. Only the number of
            rows (K) is used, to enumerate the component labels 0..K-1.
        assignments (numpy.array)
            Sample cluster assignments
        rank (int)
            Which label to find, ordered by decreasing number of assigned
            points. 0-indexed.

            (default = 0)

    Returns:
        between_label (int)
            The cluster label with the most points assigned to it
            (or the rank-th most populated label if rank > 0)

    Raises:
        IndexError
            If rank is not smaller than the number of components
    """
    # Coerce once so the element-wise comparison below also works when a
    # plain list is passed (list == int would otherwise just be False)
    assignments = np.asarray(assignments)

    # Number of points assigned to each mixture component
    counts = [np.count_nonzero(assignments == component)
              for component in range(len(means))]

    # sorted() is stable, also with reverse=True, so components with equal
    # counts stay in order of increasing label
    ranked_labels = sorted(range(len(counts)),
                           key=counts.__getitem__,
                           reverse=True)
    return ranked_labels[rank]

def findWithinLabel(means, assignments, rank = 0):
"""Identify within-strain links
Finds the component with mean closest to the origin and also akes sure
Finds the component with mean closest to the origin and also makes sure
some samples are assigned to it (in the case of small weighted
components with a Dirichlet prior some components are unused)
Expand All @@ -59,7 +85,6 @@ def findWithinLabel(means, assignments, rank = 0):
Sample cluster assignments
rank (int)
Which label to find, ordered by distance from origin. 0-indexed.
(default = 0)
Returns:
Expand Down
55 changes: 30 additions & 25 deletions PopPUNK/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,13 @@ def get_options():
# main code
def main():

# Import value
from .__main__ import betweenness_sample_default

# Import functions
from .network import load_network_file
from .network import sparse_mat_to_network
from .network import print_network_summary
from .utils import check_and_set_gpu
from .utils import setGtThreads

Expand Down Expand Up @@ -103,6 +107,32 @@ def main():
use_rc = ref_db['sketches'].attrs['use_rc'] == 1
print("Uses canonical k-mers:\t" + str(use_rc))

# Select network file name
if args.network_file is None:
if use_gpu:
network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.csv.gz')
else:
network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.gt')
else:
network_file = args.network_file

# Open network file
if network_file.endswith('.gt'):
G = load_network_file(network_file, use_gpu = False)
elif network_file.endswith('.csv.gz'):
if use_gpu:
G = load_network_file(network_file, use_gpu = True)
else:
sys.stderr.write('Unable to load necessary GPU libraries\n')
sys.exit(1)
elif network_file.endswith('.npz'):
sparse_mat = sparse.load_npz(network_file)
G = sparse_mat_to_network(sparse_mat, sample_names, use_gpu = use_gpu)
else:
sys.stderr.write('Unrecognised suffix: expected ".gt", ".csv.gz" or ".npz"\n')
sys.exit(1)
print_network_summary(G, betweenness_sample = betweenness_sample_default, use_gpu = args.use_gpu)

# Print sample information
if not args.simple:
sample_names = list(ref_db['sketches'].keys())
Expand All @@ -115,31 +145,6 @@ def main():
sample_sequence_length[sample_name] = ref_db['sketches/' + sample_name].attrs['length']
sample_missing_bases[sample_name] = ref_db['sketches/' + sample_name].attrs['missing_bases']

# Select network file name
if args.network_file is None:
if use_gpu:
network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.csv.gz')
else:
network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.gt')
else:
network_file = args.network_file

# Open network file
if network_file.endswith('.gt'):
G = load_network_file(network_file, use_gpu = False)
elif network_file.endswith('.csv.gz'):
if use_gpu:
G = load_network_file(network_file, use_gpu = True)
else:
sys.stderr.write('Unable to load necessary GPU libraries\n')
sys.exit(1)
elif network_file.endswith('.npz'):
sparse_mat = sparse.load_npz(network_file)
G = sparse_mat_to_network(sparse_mat, sample_names, use_gpu = use_gpu)
else:
sys.stderr.write('Unrecognised suffix: expected ".gt", ".csv.gz" or ".npz"\n')
sys.exit(1)

# Analyse network
if use_gpu:
component_assignments_df = cugraph.components.connectivity.connected_components(G)
Expand Down
Loading

0 comments on commit 82649c5

Please sign in to comment.