diff --git a/strainy.py b/strainy.py index 8e140c5..34ee921 100755 --- a/strainy.py +++ b/strainy.py @@ -6,10 +6,9 @@ import re import subprocess import argparse -import gfapy import logging import shutil - +import cProfile import gfapy @@ -52,23 +51,21 @@ def get_processor_name(): def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - requiredNamed = parser.add_argument_group('Required named arguments') requiredNamed.add_argument("-o", "--output", help="output directory",required=True) requiredNamed.add_argument("-g", "--gfa", help="input gfa to uncollapse",required=True) - requiredNamed.add_argument("-m", "--mode", help="type of reads", choices=["hifi", "nano"], required=True) - requiredNamed.add_argument("-q", "--fastq", - help="fastq file with reads to phase / assemble", + requiredNamed.add_argument("-m", "--mode", help="type of reads", choices=["hifi", "nano"], + required=True) + requiredNamed.add_argument("-q", "--fastq",help="fastq file with reads to phase / assemble", required=True) - parser.add_argument("-s", "--stage", help="stage to run: either phase, transform or e2e (phase + transform)", choices=["phase", "transform", "e2e"], default="e2e") parser.add_argument("--snp", help="path to vcf file with SNP calls to use", default=None) parser.add_argument("-t", "--threads", help="number of threads to use", type=int, default=4) parser.add_argument("-f", "--fasta", required=False, help=argparse.SUPPRESS) parser.add_argument("-b", "--bam", help="path to indexed alignment in bam format",required=False) - parser.add_argument("--link-simplify", required=False, action="store_true", default=False, dest="link_simplify", - help="Enable agressive graph simplification") + parser.add_argument("--link-simplify", required=False, action="store_true", default=False, + dest="link_simplify",help="Enable agressive graph simplification") parser.add_argument("--debug", required=False, action="store_true", default=False, help="Generate extra output for debugging") parser.add_argument("--unitig-split-length", @@ -76,9 +73,11 @@ def main(): required=False, type=int, default=50) - parser.add_argument("--only-split",help="Do not run stRainy, only split long gfa unitigs", default='False', required=False) + parser.add_argument("--only-split",help="Do not run stRainy, only split long gfa unitigs", default='False', + required=False) parser.add_argument("-d","--cluster-divergence",help="cluster divergence", type=float, default=0, required=False) - parser.add_argument("-a","--allele-frequency",help="Set allele frequency for internal caller only (pileup)", type=float, default=0.2, required=False) + parser.add_argument("-a","--allele-frequency",help="Set allele frequency for internal caller only (pileup)", + type=float, default=0.2, required=False) parser.add_argument("--min-unitig-length", help="The length (in kb) which the unitigs that are shorter will not be phased", required=False, @@ -129,7 +128,6 @@ def main(): elif args.stage == "transform": sys.exit(transform_main(args)) elif args.stage == "e2e": - import cProfile pr_phase = cProfile.Profile() pr_phase.enable() phase_main(args) diff --git a/strainy/color_bam.py b/strainy/color_bam.py index 26b93a6..b7ce779 100644 --- a/strainy/color_bam.py +++ b/strainy/color_bam.py @@ -12,14 +12,21 @@ from strainy.params import * -def write_bam(edge, I, AF): +def write_bam(edge, I, AF, cl_file=None,file=None): infile = pysam.AlignmentFile(StRainyArgs().bam, "rb") - outfile = pysam.AlignmentFile("%s/bam/coloredBAM_unitig_%s.bam" % (StRainyArgs().output_intermediate, edge), "wb", template=infile) - cl = pd.read_csv("%s/clusters/clusters_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, AF),keep_default_na=False) + if file==None: + outfile = pysam.AlignmentFile("%s/bam/coloredBAM_unitig_%s.bam" % (StRainyArgs().output_intermediate, edge), "wb", template=infile) + else: + outfile = pysam.AlignmentFile(file,"wb", template=infile) + if cl_file==None: + cl = pd.read_csv("%s/clusters/clusters_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, AF),keep_default_na=False) + else: + cl = pd.read_csv(cl_file,keep_default_na=False) iter = infile.fetch(edge,until_eof=True) cmap = plt.get_cmap("viridis") cl.loc[cl["Cluster"] == "NA", "Cluster"] = 0 - clusters=sorted(set(cl["Cluster"].astype(int))) + #clusters=sorted(set(cl["Cluster"].astype(int))) + clusters = set(cl["Cluster"]) cmap = cmap(np.linspace(0, 1, len(clusters))) colors={} i=0 @@ -35,7 +42,8 @@ def write_bam(edge, I, AF): for read in iter: try: - clN = int(cl_dict[str(read).split()[0]]) + #clN = int(cl_dict[str(read).split()[0]]) + clN = cl_dict[str(read).split()[0]] tag = colors[clN] read.set_tag("YC", tag, replace=False) outfile.write(read) @@ -44,9 +52,9 @@ def write_bam(edge, I, AF): outfile.close() -def color(edge): +def color(edge,cl_file=None,file=None): try: - write_bam(edge, I, StRainyArgs().AF) + write_bam(edge, I, StRainyArgs().AF,cl_file,file) except (FileNotFoundError): pass diff --git a/strainy/phase.py b/strainy/phase.py index 3048978..f9aef99 100644 --- a/strainy/phase.py +++ b/strainy/phase.py @@ -1,14 +1,13 @@ import multiprocessing -import pysam -import pickle import os +import pickle import sys import subprocess -import multiprocessing import logging import shutil import traceback import time +import pysam from strainy.clustering.cluster import cluster from strainy.color_bam import color @@ -24,13 +23,14 @@ def _thread_fun(i, shared_flye_consensus, args): init_global_args_storage(args) set_thread_logging(StRainyArgs().log_phase, "phase", multiprocessing.current_process().pid) - logger.info("\n\n\t == == Processing unitig " + str(StRainyArgs().edges_to_phase[i]) + " == == ") + logger.info("\n\n\t == == Processing unitig " + + str(StRainyArgs().edges_to_phase[i]) + " == == ") try: cluster(i, shared_flye_consensus) - except Exception as e: - logger.error("Worker thread exception! " + str(e) + "\n" + traceback.format_exc()) - raise e + except Exception as excpt: + logger.error("Worker thread exception! " + str(excpt) + "\n" + traceback.format_exc()) + raise excpt logger.debug("Thread worker function finished!") @@ -40,7 +40,8 @@ def phase(edges, args): empty_consensus_dict = {} default_manager = multiprocessing.Manager() - shared_flye_consensus = FlyeConsensus(StRainyArgs().bam, StRainyArgs().fa, 1, empty_consensus_dict, default_manager) + shared_flye_consensus = FlyeConsensus(StRainyArgs().bam, StRainyArgs().fa, 1, + empty_consensus_dict, default_manager) if StRainyArgs().threads == 1: for i in range(len(edges)): cluster(i, shared_flye_consensus) @@ -62,38 +63,52 @@ def phase(edges, args): return shared_flye_consensus.get_consensus_dict() -def color_bam(edges): +def color_bam(edges, transfrom_stage=False): logger.info("Creating phased bam") - for e in edges: - color(e) + if transfrom_stage == False: + for edge in edges: + color(edge) + out_bam_dir = os.path.join(StRainyArgs().output_intermediate, "bam") + final_aln = os.path.join(StRainyArgs().output, "alignment_phased.bam") + + else: + for edge in edges: + color(edge, cl_file="%s/clusters/clusters_%s_%s_%s_MERGED.csv" % + (StRainyArgs().output_intermediate, edge, I, StRainyArgs().AF), + file="%s/bam/merged/coloredBAM_unitig_%s_merged.bam" % (StRainyArgs().output_intermediate, edge)) + out_bam_dir = os.path.join(StRainyArgs().output_intermediate, "bam/merged") + final_aln = os.path.join(StRainyArgs().output, "alignment_phased_merged.bam") - out_bam_dir = os.path.join(StRainyArgs().output_intermediate, "bam") - final_aln = os.path.join(StRainyArgs().output, "alignment_phased.bam") files_to_be_merged = [] - for fname in subprocess.check_output(f'find {out_bam_dir} -name "*unitig*.bam"', shell = True, universal_newlines = True).split("\n"): + for fname in subprocess.check_output(f'find {out_bam_dir} -name "*unitig*.bam"', + shell = True, universal_newlines = True).split("\n"): if len(fname): files_to_be_merged.append(fname) - # Number of file to be merged could be > 4092, in which case samtools merge throws too many open files error + # Number of file to be merged could be > 4092, + # in which case samtools merge throws too many open files error for i, bam_file in enumerate(files_to_be_merged): # fetch the header and put it at the top of the file, for the first bam_file only if i == 0: - subprocess.check_output(f'samtools view -H {bam_file} > {out_bam_dir}/coloredSAM.sam', shell = True) + subprocess.check_output(f'samtools view -H {bam_file} > ' + f'{out_bam_dir}/coloredSAM.sam',shell = True) # convert bam to sam, append to the file - subprocess.check_output(f'samtools view {bam_file} >> {out_bam_dir}/coloredSAM.sam', shell = True) + subprocess.check_output(f'samtools view {bam_file} >> {out_bam_dir}/coloredSAM.sam', + shell = True) # convert the file to bam and sort - subprocess.check_output(f'samtools view -b {out_bam_dir}/coloredSAM.sam >> {out_bam_dir}/unsortedBAM.bam', shell = True) + subprocess.check_output(f'samtools view -b {out_bam_dir}/coloredSAM.sam >> ' + f'{out_bam_dir}/unsortedBAM.bam',shell = True) pysam.samtools.sort(f'{out_bam_dir}/unsortedBAM.bam', "-o", final_aln) pysam.samtools.index(final_aln) # remove unnecessary files os.remove(f'{out_bam_dir}/unsortedBAM.bam') os.remove(f'{out_bam_dir}/coloredSAM.sam') - for f in files_to_be_merged: - os.remove(f) + for file in files_to_be_merged: + os.remove(file) def phase_main(args): @@ -105,8 +120,7 @@ def phase_main(args): "%s/bam/clusters" % StRainyArgs().output_intermediate, "%s/flye_inputs" % StRainyArgs().output_intermediate, "%s/flye_outputs" % StRainyArgs().output_intermediate - ) - +) debug_dirs = ("%s/graphs/" % StRainyArgs().output_intermediate, "%s/adj_M/" % StRainyArgs().output_intermediate ) diff --git a/strainy/preprocessing.py b/strainy/preprocessing.py index 4f8a164..5b3e709 100644 --- a/strainy/preprocessing.py +++ b/strainy/preprocessing.py @@ -1,8 +1,8 @@ -import subprocess import os import logging import pysam import gfapy +import subprocess from strainy.params import StRainyArgs diff --git a/strainy/transform.py b/strainy/transform.py index fcaadf7..0f0b144 100644 --- a/strainy/transform.py +++ b/strainy/transform.py @@ -12,7 +12,7 @@ import time import traceback import csv - +from strainy.color_bam import color import strainy.clustering.build_adj_matrix as matrix import strainy.clustering.cluster_postprocess as postprocess import strainy.simplification.simplify_links as smpl @@ -24,6 +24,7 @@ from strainy.reports.strainy_stats import strain_stats_report from strainy.reports.call_variants import produce_strainy_vcf from strainy.preprocessing import gfa_to_fasta +from strainy.phase import color_bam logger = logging.getLogger() @@ -1035,8 +1036,12 @@ def transform_main(args): shutil.copyfile(out_clusters, strainy_utgs) phased_graph = gfapy.Gfa.from_file(out_clusters) #parsing again because gfapy can"t copy + + segs_unmerged=phased_graph.segment_names gfapy.GraphOperations.merge_linear_paths(phased_graph) clean_graph(phased_graph) + segs_merged = phased_graph.segment_names + out_merged = os.path.join(StRainyArgs().output_intermediate, "20_extended_haplotypes.gfa") gfapy.Gfa.to_file(phased_graph, out_merged) @@ -1066,5 +1071,26 @@ def transform_main(args): produce_strainy_vcf(StRainyArgs().fa, strain_utgs_fasta, StRainyArgs().threads, strain_utgs_aln, open(vcf_strain_variants, "w")) + logger.info("Update clusters and colored BAM") + merged_clusters={} + AF = StRainyArgs().AF + #I = StRainyArgs().I + for seg in [i for i in segs_unmerged if i not in segs_merged]: + seg_merged = [k for k in segs_merged if re.search(seg, k) != None][0] + merged_clusters[seg] = seg_merged + + for edge in StRainyArgs().edges: + try: + cl = pd.read_csv("%s/clusters/clusters_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, AF), + keep_default_na=False) + clusters = sorted(set(cl['Cluster'])) + for cluster in clusters: + seg=str(edge)+"_"+str(cluster) + if seg in merged_clusters.keys(): + cl.loc[cl['Cluster'] == cluster, 'Cluster'] = merged_clusters[seg] + cl.to_csv("%s/clusters/clusters_%s_%s_%s_MERGED.csv" % (StRainyArgs().output_intermediate, edge, I, AF)) + except(FileNotFoundError): pass + os.makedirs("%s/bam/merged/" % StRainyArgs().output_intermediate, exist_ok=True) + color_bam(StRainyArgs().edges, transfrom_stage=True) flye_consensus.print_cache_statistics() logger.info("### Done!")