diff --git a/neusomatic/python/call.py b/neusomatic/python/call.py index 2d9a3ea..e5ef8f5 100755 --- a/neusomatic/python/call.py +++ b/neusomatic/python/call.py @@ -11,6 +11,7 @@ import shutil import pickle import multiprocessing +import copy import pysam import numpy as np @@ -69,6 +70,11 @@ def call_variants(net, call_loader, out_dir, model_tag, use_cuda): for data in loader_: (matrices, labels, var_pos_s, var_len_s, non_transformed_matrices), (paths) = data + + paths_ = copy.deepcopy(paths) + del paths + paths = paths_ + matrices = Variable(matrices) iii += 1 j += len(paths[0]) diff --git a/neusomatic/python/filter_candidates.py b/neusomatic/python/filter_candidates.py index 0354b31..39a1fdf 100755 --- a/neusomatic/python/filter_candidates.py +++ b/neusomatic/python/filter_candidates.py @@ -3,6 +3,7 @@ # filter_candidates.py # filter raw candidates extracted by 'scan_alignments.py' using min_af and other cut-offs #------------------------------------------------------------------------- +import os import argparse import traceback import logging @@ -25,6 +26,16 @@ def filter_candidates(candidate_record): thread_logger.info( "---------------------Filter Candidates---------------------") + if dbsnp: + if not dbsnp.endswith("vcf.gz"): + thread_logger.error("Aborting!") + raise Exception( + "The dbSNP file should be a tabix indexed file with .vcf.gz format") + if not os.path.exists(dbsnp + ".tbi"): + thread_logger.error("Aborting!") + raise Exception( + "The dbSNP file should be a tabix indexed file with .vcf.gz format. No {}.tbi file exists.".format(dbsnp)) + records = {} with open(candidates_vcf) as v_f: for line in skip_empty(v_f): @@ -257,47 +268,25 @@ def filter_candidates(candidate_record): final_records.append([chrom, pos - 1, ref, alt, line]) final_records = sorted(final_records, key=lambda x: x[0:2]) if dbsnp: - filtered_bed = get_tmp_file() - intervals = [] - for x in enumerate(final_records): - intervals.append([x[1][0], int(x[1][1]), int( - x[1][1]) + 1, x[1][2], x[1][3], str(x[0])]) - write_tsv_file(filtered_bed, intervals) - filtered_bed = bedtools_sort( - filtered_bed, run_logger=thread_logger) - - dbsnp_tmp = get_tmp_file() - vcf_2_bed(dbsnp, dbsnp_tmp) - bedtools_sort(dbsnp_tmp, output_fn=dbsnp, run_logger=thread_logger) - non_in_dbsnp_1 = bedtools_window( - filtered_bed, dbsnp, args=" -w 0 -v", run_logger=thread_logger) - non_in_dbsnp_2 = bedtools_window( - filtered_bed, dbsnp, args=" -w 0", run_logger=thread_logger) - - tmp_ = get_tmp_file() - with open(non_in_dbsnp_2) as i_f, open(tmp_, "w") as o_f: - for line in skip_empty(i_f): - x = line.strip().split() - if x[1]!=x[7] or x[3]!=x[9] or x[4]!=x[10]: - o_f.write(line) - non_in_dbsnp_2 = tmp_ - - non_in_dbsnp_ids = [] - with open(non_in_dbsnp_1) as i_f: - for line in skip_empty(i_f): - x = line.strip().split("\t") - non_in_dbsnp_ids.append(int(x[5])) - with open(non_in_dbsnp_2) as i_f: - for line in skip_empty(i_f): - x = line.strip().split("\t") - non_in_dbsnp_ids.append(int(x[5])) - final_records = list(map(lambda x: x[1], filter( - lambda x: x[0] in non_in_dbsnp_ids, enumerate(final_records)))) + dbsnp_tb = pysam.TabixFile(dbsnp) with open(filtered_candidates_vcf, "w") as o_f: o_f.write("{}\n".format(VCF_HEADER)) o_f.write( "#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE\n") for record in final_records: + if dbsnp: + chrom, pos, ref, alt = record[0:4] + var_id = "-".join(map(str,[chrom, pos, ref, alt])) + region = "{}:{}-{}".format(chrom, pos, pos + 1) + dbsnp_vars = [] + for x in 
dbsnp_tb.fetch(region=region):
+                    chrom_, pos_, _, ref_, alts_ = x.strip().split("\t")[
+                        0:5]
+                    for alt_ in alts_.split(","):
+                        dbsnp_var_id = "-".join(map(str, [chrom_, int(pos_) - 1, ref_, alt_]))
+                        dbsnp_vars.append(dbsnp_var_id)
+                if var_id in dbsnp_vars:
+                    continue
             o_f.write(record[-1] + "\n")
     return filtered_candidates_vcf
@@ -320,7 +309,7 @@ def filter_candidates(candidate_record):
     parser.add_argument('--reference', type=str,
                         help='reference fasta filename', required=True)
     parser.add_argument('--dbsnp_to_filter', type=str,
-                        help='dbsnp vcf (will be used to filter candidate variants)', default=None)
+                        help='dbsnp vcf.gz (will be used to filter candidate variants)', default=None)
     parser.add_argument('--good_ao', type=float,
                         help='good alternate count (ignores maf)', default=10)
     parser.add_argument('--min_ao', type=float,
diff --git a/neusomatic/python/generate_dataset.py b/neusomatic/python/generate_dataset.py
index 713e864..5a90d81 100755
--- a/neusomatic/python/generate_dataset.py
+++ b/neusomatic/python/generate_dataset.py
@@ -17,7 +17,7 @@ import numpy as np
 import pysam

-from scipy.misc import imresize
+from PIL import Image
 from split_bed import split_region
 from utils import concatenate_vcfs, get_chromosomes_order, run_bedtools_cmd, vcf_2_bed, bedtools_sort, bedtools_window, bedtools_intersect, bedtools_slop, get_tmp_file, skip_empty
@@ -513,39 +513,43 @@ def prepare_info_matrices_tabix(ref_file, tumor_count_bed, normal_count_bed, rec
     else:
         col_pos_map = {i: int(round(v / float(ncols) * matrix_width))
                        for i, v in col_pos_map.items()}
-        tumor_count_matrix = imresize(
-            tumor_matrix, (5, matrix_width)).astype(int)
-        bq_tumor_count_matrix = imresize(
-            bq_tumor_matrix, (5, matrix_width)).astype(int)
-        mq_tumor_count_matrix = imresize(
-            mq_tumor_matrix, (5, matrix_width)).astype(int)
-        st_tumor_count_matrix = imresize(
-            st_tumor_matrix, (5, matrix_width)).astype(int)
-        lsc_tumor_count_matrix = imresize(
-            lsc_tumor_matrix, (5, matrix_width)).astype(int)
-        rsc_tumor_count_matrix = imresize(
-            rsc_tumor_matrix, (5, matrix_width)).astype(int)
+        tumor_count_matrix = np.array(Image.fromarray(
+            tumor_matrix).resize((matrix_width, 5), 2)).astype(int)
+        bq_tumor_count_matrix = np.array(Image.fromarray(
+            bq_tumor_matrix).resize((matrix_width, 5), 2)).astype(int)
+        mq_tumor_count_matrix = np.array(Image.fromarray(
+            mq_tumor_matrix).resize((matrix_width, 5), 2)).astype(int)
+        st_tumor_count_matrix = np.array(Image.fromarray(
+            st_tumor_matrix).resize((matrix_width, 5), 2)).astype(int)
+        lsc_tumor_count_matrix = np.array(Image.fromarray(
+            lsc_tumor_matrix).resize((matrix_width, 5), 2)).astype(int)
+        rsc_tumor_count_matrix = np.array(Image.fromarray(
+            rsc_tumor_matrix).resize((matrix_width, 5), 2)).astype(int)
+
         tag_tumor_count_matrices = []
         for iii in range(len(tag_tumor_matrices)):
             tag_tumor_count_matrices.append(
-                imresize(tag_tumor_matrices[iii], (5, matrix_width)).astype(int))
-        normal_count_matrix = imresize(
-            normal_matrix, (5, matrix_width)).astype(int)
-        bq_normal_count_matrix = imresize(
-            bq_normal_matrix, (5, matrix_width)).astype(int)
-        mq_normal_count_matrix = imresize(
-            mq_normal_matrix, (5, matrix_width)).astype(int)
-        st_normal_count_matrix = imresize(
-            st_normal_matrix, (5, matrix_width)).astype(int)
-        lsc_normal_count_matrix = imresize(
-            lsc_normal_matrix, (5, matrix_width)).astype(int)
-        rsc_normal_count_matrix = imresize(
-            rsc_normal_matrix, (5, matrix_width)).astype(int)
+                np.array(Image.fromarray(tag_tumor_matrices[iii]).resize((matrix_width, 5), 2)).astype(int))
+
+        normal_count_matrix = np.array(Image.fromarray(
+            normal_matrix).resize((matrix_width, 5), 2)).astype(int)
+        bq_normal_count_matrix = np.array(Image.fromarray(
+            bq_normal_matrix).resize((matrix_width, 5), 2)).astype(int)
+        mq_normal_count_matrix = np.array(Image.fromarray(
+            mq_normal_matrix).resize((matrix_width, 5), 2)).astype(int)
+        st_normal_count_matrix = np.array(Image.fromarray(
+            st_normal_matrix).resize((matrix_width, 5), 2)).astype(int)
+        lsc_normal_count_matrix = np.array(Image.fromarray(
+            lsc_normal_matrix).resize((matrix_width, 5), 2)).astype(int)
+        rsc_normal_count_matrix = np.array(Image.fromarray(
+            rsc_normal_matrix).resize((matrix_width, 5), 2)).astype(int)
+
         tag_normal_count_matrices = []
         for iii in range(len(tag_normal_matrices)):
             tag_normal_count_matrices.append(
-                imresize(tag_normal_matrices[iii], (5, matrix_width)).astype(int))
-        ref_count_matrix = imresize(ref_matrix, (5, matrix_width)).astype(int)
+                np.array(Image.fromarray(tag_normal_matrices[iii]).resize((matrix_width, 5), 2)).astype(int))
+        ref_count_matrix = np.array(Image.fromarray(
+            ref_matrix).resize((matrix_width, 5), 2)).astype(int)

     if int(pos) + rcenter[0] not in col_pos_map:
         center = min(col_pos_map.values()) + rcenter[0] - 1 + rcenter[1]
@@ -873,6 +877,7 @@ def find_records(input_record):
     records = []
     i = 0
     anns = {}
+    fasta_file = pysam.Fastafile(ref_file)
     if ensemble_bed:
         with open(not_in_ensemble_bed) as ni_f:
             for line in skip_empty(ni_f):
@@ -1033,8 +1038,16 @@ def find_records(input_record):
         with open(split_truth_vcf_file, 'r') as vcf_reader:
             for line in skip_empty(vcf_reader):
                 record = line.strip().split()
+                pos = int(record[1])
+                if len(record[3]) != len(record[4]) and min(len(record[3]), len(record[4])) > 0 and record[3][0] != record[4][0]:
+                    if pos > 1:
+                        l_base = fasta_file.fetch(
+                            record[0], pos - 2, pos - 1).upper()
+                        record[3] = l_base + record[3]
+                        record[4] = l_base + record[4]
+                        pos -= 1
                 truth_records.append(
-                    [record[0], int(record[1]), record[3], record[4], str(i)])
+                    [record[0], pos, record[3], record[4], str(i)])
                 i += 1

     truth_bed = get_tmp_file()
@@ -1071,7 +1084,6 @@ def find_records(input_record):
     record_center = {}

     chroms_order = get_chromosomes_order(reference=ref_file)
-    fasta_file = pysam.Fastafile(ref_file)

     good_records = {"INS": [], "DEL": [], "SNP": []}
     vtype = {}
@@ -1299,7 +1311,6 @@ def find_records(input_record):
     records_r = [records[x] for k, w in good_records.items() for x in w]

     N_none = len(none_records_ids)
-    thread_logger.info("N_none: {} ".format(N_none))
     none_records = list(map(lambda x: records[x], none_records_ids))
     none_records = sorted(none_records, key=lambda x: [x[0], int(x[1])])
diff --git a/neusomatic/python/long_read_indelrealign.py b/neusomatic/python/long_read_indelrealign.py
index d56f220..4a74eda 100755
--- a/neusomatic/python/long_read_indelrealign.py
+++ b/neusomatic/python/long_read_indelrealign.py
@@ -279,7 +279,7 @@ def cigartuple_to_string(cigartuples):
     return "".join(map(lambda x: "%d%s" % (x[1], _CIGAR_OPS[x[0]]), cigartuples))


-def prepare_fasta(work, region, input_bam, ref_fasta_file, include_ref, split_i):
+def prepare_fasta(work, region, input_bam, ref_fasta_file, include_ref, split_i, ds, filter_duplicate):
     logger = logging.getLogger(prepare_fasta.__name__)
     in_fasta_file = os.path.join(
         work, region.__str__() + "_split_{}".format(split_i) + "_0.fasta")
@@ -290,7 +290,7 @@ def prepare_fasta(work, region, input_bam, ref_fasta_file, include_ref, split_i)
     with open(info_file, "w") as info_txt:
         if include_ref:
             ref_seq = ref_fasta.fetch(
- region.chrom, region.start, region.end + 1) + region.chrom, region.start, region.end + 1).upper() in_fasta.write(">0\n") in_fasta.write("%s\n" % ref_seq.upper()) cnt = 1 @@ -298,6 +298,8 @@ def prepare_fasta(work, region, input_bam, ref_fasta_file, include_ref, split_i) for record in samfile.fetch(region.chrom, region.start, region.end + 1): if record.is_unmapped: continue + if filter_duplicate and record.is_duplicate: + continue if record.is_supplementary and "SA" in dict(record.tags): sas = dict(record.tags)["SA"].split(";") sas = list(filter(None, sas)) @@ -310,6 +312,8 @@ def prepare_fasta(work, region, input_bam, ref_fasta_file, include_ref, split_i) full_length=True))))) if not record.cigartuples: continue + if np.random.rand() > ds: + continue sc_start = (record.cigartuples[0][0] == CIGAR_SOFTCLIP) * record.cigartuples[0][1] sc_end = (record.cigartuples[-1][0] == @@ -343,7 +347,7 @@ def prepare_fasta(work, region, input_bam, ref_fasta_file, include_ref, split_i) (start_idx - sc_start):(end_idx - sc_start)] non_ins = np.nonzero(positions_ >= 0) refseq = ref_fasta.fetch(region.chrom, positions_[non_ins][0], - positions_[non_ins][-1] + 1) + positions_[non_ins][-1] + 1).upper() q_seq = record.seq[start_idx:end_idx + 1] non_ins_positions = positions_[non_ins] mn, mx = min(non_ins_positions), max( @@ -373,13 +377,15 @@ def prepare_fasta(work, region, input_bam, ref_fasta_file, include_ref, split_i) def split_bam_to_chunks(work, region, input_bam, chunk_size=200, - chunk_scale=1.5): + chunk_scale=1.5, do_split=False, filter_duplicate=False): logger = logging.getLogger(split_bam_to_chunks.__name__) records = [] with pysam.Samfile(input_bam, "rb") as samfile: for record in samfile.fetch(region.chrom, region.start, region.end + 1): if record.is_unmapped: continue + if filter_duplicate and record.is_duplicate: + continue if record.is_supplementary and "SA" in dict(record.tags): sas = dict(record.tags)["SA"].split(";") sas = list(filter(None, sas)) @@ -413,11 +419,13 @@ def split_bam_to_chunks(work, region, input_bam, chunk_size=200, if len(records) < chunk_size * chunk_scale: bams = [input_bam] lens = [len(records)] - else: + ds = [1] + elif do_split: n_splits = int(max(6, len(records) // chunk_size)) new_chunk_size = len(records) // n_splits bams = [] lens = [] + ds = [] n_split = (len(records) // new_chunk_size) + 1 if 0 < (len(records) - ((n_split - 1) * new_chunk_size) + new_chunk_size) \ < new_chunk_size * chunk_scale: @@ -441,7 +449,12 @@ def split_bam_to_chunks(work, region, input_bam, chunk_size=200, bams.append(split_input_bam) lens.append(i_end - i_start + 1) - return bams, lens + ds.append(1) + else: + bams = [input_bam] + lens = [chunk_size * chunk_scale] + ds = [chunk_size * chunk_scale / float(len(records))] + return bams, lens, ds def read_info(info_file): @@ -633,8 +646,8 @@ def find_realign_dict(realign_bed_file, chrom): realign_bed = get_tmp_file() with open(realign_bed_file) as i_f, open(realign_bed, "w") as o_f: for line in skip_empty(i_f): - x=line.strip().split() - if x[0]==chrom: + x = line.strip().split() + if x[0] == chrom: o_f.write(line) realign_dict = {} @@ -825,7 +838,7 @@ def run_msa(in_fasta_file, match_score, mismatch_penalty, gap_open_penalty, gap_ return out_fasta_file -def do_realign(region, info_file, thr_realign=0.0135, max_N=1000): +def do_realign(region, info_file, max_realign_dp, thr_realign=0.0135): logger = logging.getLogger(do_realign.__name__) sum_nm_snp = 0 sum_nm_indel = 0 @@ -837,7 +850,7 @@ def do_realign(region, info_file, thr_realign=0.0135, 
max_N=1000): sum_nm_indel += int(x[-1]) c += 1 eps = 0.0001 - if (c < max_N) and ( + if (c < max_realign_dp) and ( (sum_nm_snp + sum_nm_indel ) / float(c + eps) / float(region.span() + eps) > thr_realign): @@ -845,7 +858,10 @@ def do_realign(region, info_file, thr_realign=0.0135, max_N=1000): return False -def find_var(out_fasta_file, snp_min_af, del_min_af, ins_min_af, scale_maf): +def find_var(out_fasta_file, snp_min_af, del_min_af, ins_min_af, scale_maf, simplify): + # Find variants from MSA: + # In each column the AF is calculated + # The low AF vars in each column are discarded and the variant is extracted logger = logging.getLogger(find_var.__name__) records = SeqIO.to_dict(SeqIO.parse(out_fasta_file, "fasta")) if set(map(int, records.keys())) ^ set(range(len(records))): @@ -895,7 +911,60 @@ def find_var(out_fasta_file, snp_min_af, del_min_af, ins_min_af, scale_maf): map(lambda x: NUM_to_NUC[x], filter(lambda x: x > 0, ref_seq))) alt_seq_ = "".join( map(lambda x: NUM_to_NUC[x], filter(lambda x: x > 0, alt_seq))) - return ref_seq_, alt_seq_, afs + if not simplify: + variants = [[0, ref_seq_, alt_seq_, afs]] + else: + variants = [] + bias = 0 + current_ref = [] + current_alt = [] + current_af = [] + current_bias = 0 + is_ins = False + is_del = False + done = False + for i, (r, a) in enumerate(zip(list(ref_seq) + [0], list(alt_seq) + [0])): + if i in i_afs: + af = afs[i_afs.index(i)] + else: + af = 0 + if r == a: + done = True + else: + if r == 0 and a != 0: + if not is_ins: + done = True + elif r != 0 and a == 0: + if not is_del: + done = True + else: + done = True + if done: + if current_alt: + rr = "".join(map(lambda x: NUM_to_NUC[ + x], filter(lambda x: x > 0, current_ref))) + aa = "".join(map(lambda x: NUM_to_NUC[ + x], filter(lambda x: x > 0, current_alt))) + variants.append( + [current_bias, rr, aa, np.array(current_af)]) + done = False + current_ref = [] + current_alt = [] + current_af = [] + current_bias = bias + is_ins = False + is_del = False + done = False + if not done: + current_ref.append(r) + current_alt.append(a) + current_af.append(af) + is_ins = r == 0 and a != 0 + is_del = r != 0 and a == 0 + if r != 0: + bias += 1 + + return variants def TrimREFALT(ref, alt, pos): @@ -922,24 +991,27 @@ def run_realignment(input_record): work, ref_fasta_file, target_region, pad, chunk_size, chunk_scale, \ snp_min_af, del_min_af, ins_min_af, len_chr, input_bam, \ match_score, mismatch_penalty, gap_open_penalty, gap_ext_penalty, \ - msa_binary, get_var = input_record + max_realign_dp, \ + filter_duplicate, \ + msa_binary, get_var, do_split = input_record + ref_fasta = pysam.Fastafile(ref_fasta_file) thread_logger = logging.getLogger( "{} ({})".format(run_realignment.__name__, multiprocessing.current_process().name)) try: region = Region(target_region, pad, len_chr) - + not_realigned_region = None original_tempdir = tempfile.tempdir bed_tempdir = os.path.join( work, "bed_tmpdir_{}".format(region.__str__())) if not os.path.exists(bed_tempdir): os.mkdir(bed_tempdir) tempfile.tempdir = bed_tempdir - variant = [] + variants = [] all_entries = [] - input_bam_splits, lens_splits = split_bam_to_chunks( - work, region, input_bam, chunk_size, chunk_scale) + input_bam_splits, lens_splits, ds_splits = split_bam_to_chunks( + work, region, input_bam, chunk_size, chunk_scale, do_split or not get_var, filter_duplicate) new_seqs = [] new_ref_seq = "" skipped = 0 @@ -950,14 +1022,16 @@ def run_realignment(input_record): afss = [] for i, i_bam in enumerate(input_bam_splits): in_fasta_file, info_file = 
prepare_fasta(
-                work, region, i_bam, ref_fasta_file, True, i)
-            if do_realign(region, info_file):
+                work, region, i_bam, ref_fasta_file, True, i, ds_splits[i], filter_duplicate)
+            if do_realign(region, info_file, max_realign_dp):
                 out_fasta_file_0 = run_msa(
                     in_fasta_file, match_score, mismatch_penalty, gap_open_penalty,
                     gap_ext_penalty, msa_binary)
                 if get_var:
-                    ref_seq_, alt_seq_, afs = find_var(
-                        out_fasta_file_0, snp_min_af, del_min_af, ins_min_af, scale_maf)
+                    var = find_var(
+                        out_fasta_file_0, snp_min_af, del_min_af, ins_min_af, scale_maf, False)
+                    assert(len(var) == 1)
+                    _, ref_seq_, alt_seq_, afs = var[0]
                     afss.append(afs)
                     new_ref_seq = ref_seq_
                     new_seqs.append(alt_seq_)
@@ -969,37 +1043,65 @@ def run_realignment(input_record):
                 all_entries.extend(entries)
             else:
                 skipped += 1
-        if get_var and new_seqs:
-            for i in range(skipped):
+        if get_var:
+            if new_seqs:
+                for i in range(skipped):
+                    new_seqs = [new_ref_seq] + new_seqs
                 new_seqs = [new_ref_seq] + new_seqs
-            new_seqs = [new_ref_seq] + new_seqs
-            consensus_fasta = os.path.join(
-                work, region.__str__() + "_consensus.fasta")
-            with open(consensus_fasta, "w") as output_handle:
-                for i, seq in enumerate(new_seqs):
-                    record = SeqRecord(
-                        Seq(seq, DNAAlphabet.letters), id=str(i), description="")
-                    SeqIO.write(record, output_handle, "fasta")
-            consensus_fasta_aligned = run_msa(
-                consensus_fasta, match_score, mismatch_penalty, gap_open_penalty,
-                gap_ext_penalty, msa_binary)
-            ref_seq, alt_seq, afs = find_var(
-                consensus_fasta_aligned, snp_min_af, del_min_af, ins_min_af, 1)
-            if ref_seq != alt_seq:
-                ref, alt, pos = TrimREFALT(
-                    ref_seq, alt_seq, int(region.start) + 1)
-                a = int(np.ceil(np.max(afs) * len(afss)))
-                af = sum(sorted(map(lambda x:
-                                    np.max(x) if x.shape[0] > 0 else 0,
-                                    afss))[-a:]) / float(len(afss))
-                dp = sum(lens_splits)
-                ao = int(af * dp)
-                ro = dp - ao
-                variant = [region.chrom, pos, ref, alt, dp, ro, ao]
+                consensus_fasta = os.path.join(
+                    work, region.__str__() + "_consensus.fasta")
+                with open(consensus_fasta, "w") as output_handle:
+                    for i, seq in enumerate(new_seqs):
+                        record = SeqRecord(
+                            Seq(seq, DNAAlphabet.letters), id=str(i), description="")
+                        SeqIO.write(record, output_handle, "fasta")
+                consensus_fasta_aligned = run_msa(
+                    consensus_fasta, match_score, mismatch_penalty, gap_open_penalty,
+                    gap_ext_penalty, msa_binary)
+                vars_ = find_var(
+                    consensus_fasta_aligned, snp_min_af, del_min_af, ins_min_af, 1, True)
+                for var in vars_:
+                    pos_, ref_seq, alt_seq, afs = var
+                    if ref_seq != alt_seq:
+                        ref, alt, pos = ref_seq, alt_seq, int(
+                            region.start) + 1 + pos_
+                        if pos > 1:
+                            num_add_before = min(40, pos - 1)
+                            before = ref_fasta.fetch(
+                                region.chrom, pos - num_add_before, pos - 1).upper()
+                            pos -= num_add_before - 1
+                            ref = before + ref
+                            alt = before + alt
+                        ref, alt, pos = TrimREFALT(
+                            ref, alt, pos)
+                        a = int(np.ceil(np.max(afs) * len(afss)))
+                        af = sum(sorted(map(lambda x:
+                                            np.max(x) if x.shape[0] > 0 else 0,
+                                            afss))[-a:]) / float(len(afss))
+                        dp = int(sum(lens_splits))
+                        ao = int(af * dp)
+                        ro = dp - ao
+                        if ref == "" and pos > 1:
+                            pos -= 1
+                            r_ = ref_fasta.fetch(
+                                region.chrom, pos - 1, pos).upper()
+                            ref = r_ + ref
+                            alt = r_ + alt
+                        if alt == "" and pos > 1:
+                            pos -= 1
+                            r_ = ref_fasta.fetch(
+                                region.chrom, pos - 1, pos).upper()
+                            ref = r_ + ref
+                            alt = r_ + alt
+                        variants.append(
+                            [region.chrom, pos, ref, alt, dp, ro, ao])
+            else:
+                if skipped > 0:
+                    not_realigned_region = target_region
         shutil.rmtree(bed_tempdir)
         tempfile.tempdir = original_tempdir
-        return all_entries, 
variant + return all_entries, variants, not_realigned_region except Exception as ex: thread_logger.error(traceback.format_exc()) thread_logger.error(ex) @@ -1018,7 +1121,7 @@ def set_chrom(self, chrom): def get_seq(self, start, end=[]): if not end: end = start + 1 - return self.fasta_pysam.fetch(self.chrom, start, end) + return self.fasta_pysam.fetch(self.chrom, start, end).upper() def extend_regions_hp(region_bed_file, extended_region_bed_file, ref_fasta_file, @@ -1128,17 +1231,17 @@ def extend_regions_repeat(region_bed_file, extended_region_bed_file, ref_fasta_f ref_seq = ref_fasta.fetch( chrom, new_start, new_end + 1).upper() if cnt_s == 0: - while check_rep(ref_seq, "left", 3): - new_start -= 3 + while check_rep(ref_seq, "left", 2): + new_start -= 2 ref_seq = ref_fasta.fetch( chrom, new_start, new_end + 1).upper() - cnt_s += 3 + cnt_s += 2 if cnt_s == 0: - while check_rep(ref_seq, "left", 4): - new_start -= 4 + while check_rep(ref_seq, "left", 3): + new_start -= 3 ref_seq = ref_fasta.fetch( chrom, new_start, new_end + 1).upper() - cnt_s += 4 + cnt_s += 3 if cnt_s == 0: while check_rep(ref_seq, "left", 4): new_start -= 4 @@ -1194,10 +1297,14 @@ def extend_regions_repeat(region_bed_file, extended_region_bed_file, ref_fasta_f run_logger=logger) -def long_read_indelrealign(work, input_bam, output_bam, output_vcf, region_bed_file, +def long_read_indelrealign(work, input_bam, output_bam, output_vcf, output_not_realigned_bed, + region_bed_file, ref_fasta_file, num_threads, pad, chunk_size, chunk_scale, snp_min_af, del_min_af, ins_min_af, match_score, mismatch_penalty, gap_open_penalty, gap_ext_penalty, + max_realign_dp, + do_split, + filter_duplicate, msa_binary): logger = logging.getLogger(long_read_indelrealign.__name__) @@ -1257,7 +1364,8 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, region_bed_f chunk_scale, snp_min_af, del_min_af, ins_min_af, chrom_lengths[target_region[0]], input_bam, match_score, mismatch_penalty, gap_open_penalty, gap_ext_penalty, - msa_binary, get_var)) + max_realign_dp, filter_duplicate, + msa_binary, get_var, do_split)) shuffle(map_args) try: @@ -1276,7 +1384,10 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, region_bed_f realign_entries = list(map(lambda x: x[0], realign_output)) realign_variants = list(map(lambda x: x[1], realign_output)) + realign_variants = [v for var in realign_variants for v in var] realign_variants = list(filter(None, realign_variants)) + not_realigned_regions = list(map(lambda x: x[2], realign_output)) + not_realigned_regions = list(filter(None, not_realigned_regions)) if get_var: with open(output_vcf, "w") as o_f: @@ -1292,6 +1403,9 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, region_bed_f "GT:DP:RO:AO", "0/1:{}:{}:{}".format( dp, ro, ao), ]) o_f.write(line + "\n") + with open(output_not_realigned_bed, "w") as o_f: + for x in not_realigned_regions: + o_f.write("\t".join(map(str, x)) + "\n") original_tempdir = tempfile.tempdir bed_tempdir = os.path.join(work, "bed_tmpdir") @@ -1327,6 +1441,8 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, region_bed_f parser.add_argument('--input_bam', type=str, help='input bam') parser.add_argument('--output_vcf', type=str, help='output_vcf (needed for variant prediction)', default=None) + parser.add_argument('--output_not_realigned_bed', type=str, + help='output_not_realigned_bed', required=True) parser.add_argument('--output_bam', type=str, help='output_bam (needed for getting the realigned bam)', 
default=None)
     parser.add_argument('--region_bed', type=str,
@@ -1357,6 +1473,14 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, region_bed_f
                         help='penalty for opening a gap', default=8)
     parser.add_argument('--gap_ext_penalty', type=int,
                         help='penalty for extending a gap', default=6)
+    parser.add_argument('--max_realign_dp', type=int,
+                        help='max coverage for realign region', default=1000)
+    parser.add_argument('--do_split',
+                        help='Split bam for high coverage regions (in variant-calling mode).',
+                        action="store_true")
+    parser.add_argument('--filter_duplicate',
+                        help='filter duplicate reads in analysis',
+                        action="store_true")
     parser.add_argument('--msa_binary', type=str,
                         help='MSA binary', default="../bin/msa")
     args = parser.parse_args()
@@ -1364,12 +1488,17 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, region_bed_f

     try:
         processor = long_read_indelrealign(args.work, args.input_bam, args.output_bam,
-                                           args.output_vcf, args.region_bed, args.reference,
+                                           args.output_vcf, args.output_not_realigned_bed,
+                                           args.region_bed, args.reference,
                                            args.num_threads, args.pad, args.chunk_size,
                                            args.chunk_scale, args.snp_min_af, args.del_min_af,
                                            args.ins_min_af, args.match_score, args.mismatch_penalty,
                                            args.gap_open_penalty,
-                                           args.gap_ext_penalty, args.msa_binary)
+                                           args.gap_ext_penalty,
+                                           args.max_realign_dp,
+                                           args.do_split,
+                                           args.filter_duplicate,
+                                           args.msa_binary)
     except Exception as e:
         logger.error(traceback.format_exc())
         logger.error("Aborting!")
diff --git a/neusomatic/python/merge_tsvs.py b/neusomatic/python/merge_tsvs.py
index d91c501..96ca72c 100755
--- a/neusomatic/python/merge_tsvs.py
+++ b/neusomatic/python/merge_tsvs.py
@@ -47,7 +47,12 @@ def merge_tsvs(input_tsvs, out,

     totla_L = 0
     for tsv in input_tsvs:
-        totla_L += len(pickle.load(open(tsv + ".idx", "rb"))) - 1
+        if os.path.exists(tsv + ".idx"):
+            totla_L += len(pickle.load(open(tsv + ".idx", "rb"))) - 1
+        else:
+            with open(tsv, "r") as i_f:
+                for line in i_f:
+                    totla_L += 1
     totla_L = max(0, totla_L)
     candidates_per_tsv = max(candidates_per_tsv, np.ceil(
         totla_L / float(max_num_tsvs)) + 1)
diff --git a/neusomatic/python/postprocess.py b/neusomatic/python/postprocess.py
index 8e4d528..34ad405 100755
--- a/neusomatic/python/postprocess.py
+++ b/neusomatic/python/postprocess.py
@@ -21,7 +21,7 @@ from extract_postprocess_targets import extract_postprocess_targets
 from merge_post_vcfs import merge_post_vcfs
 from resolve_variants import resolve_variants
-from utils import concatenate_files, get_chromosomes_order, bedtools_window, skip_empty
+from utils import concatenate_files, get_chromosomes_order, bedtools_window, skip_empty, run_bedtools_cmd
 from long_read_indelrealign import long_read_indelrealign
 from resolve_scores import resolve_scores
 from _version import __version__
@@ -167,7 +167,8 @@ def postprocess(work, reference, pred_vcf_file, output_vcf, candidates_vcf, ense
                 postprocess_max_dist, long_read,
                 lr_pad, lr_chunk_size, lr_chunk_scale,
                 lr_snp_min_af, lr_ins_min_af, lr_del_min_af,
                 lr_match_score, lr_mismatch_penalty,
-                lr_gap_open_penalty, lr_gap_ext_penalty,
+                lr_gap_open_penalty, lr_gap_ext_penalty, lr_max_realign_dp, lr_do_split,
+                filter_duplicate,
                 pass_threshold, lowqual_threshold,
                 msa_binary, num_threads):
     logger = logging.getLogger(postprocess.__name__)
@@ -203,7 +204,10 @@ def postprocess(work, reference, pred_vcf_file, output_vcf, candidates_vcf, ense
         logger.info("Resolve targets")
         if not long_read:
             resolve_variants(tumor_bam, resolved_vcf,
-                             reference, 
target_vcf, target_bed, num_threads) + reference, target_vcf, target_bed, filter_duplicate, + num_threads) + all_no_resolve = concatenate_files( + [no_resolve, ensembled_preds], os.path.join(work, "no_resolve.vcf")) else: work_lr_indel_realign = os.path.join(work, "work_lr_indel_realign") if os.path.exists(work_lr_indel_realign): @@ -211,16 +215,26 @@ def postprocess(work, reference, pred_vcf_file, output_vcf, candidates_vcf, ense os.mkdir(work_lr_indel_realign) ra_resolved_vcf = os.path.join( work, "candidates_preds.ra_resolved.vcf") - long_read_indelrealign(work_lr_indel_realign, tumor_bam, None, ra_resolved_vcf, target_bed, + not_resolved_bed = os.path.join(work, "candidates_preds.not_ra_resolved.bed") + long_read_indelrealign(work_lr_indel_realign, tumor_bam, None, ra_resolved_vcf, + not_resolved_bed, target_bed, reference, num_threads, lr_pad, lr_chunk_size, lr_chunk_scale, lr_snp_min_af, lr_del_min_af, lr_ins_min_af, lr_match_score, lr_mismatch_penalty, lr_gap_open_penalty, - lr_gap_ext_penalty, msa_binary) + lr_gap_ext_penalty, lr_max_realign_dp, lr_do_split, + filter_duplicate, + msa_binary) resolve_scores(tumor_bam, ra_resolved_vcf, target_vcf, resolved_vcf) - all_no_resolve = concatenate_files( - [no_resolve, ensembled_preds], os.path.join(work, "no_resolve.vcf")) + not_resolved_vcf = os.path.join(work, "candidates_preds.not_ra_resolved.vcf") + cmd = "bedtools intersect -a {} -b {} -u".format( + target_vcf, not_resolved_bed) + run_bedtools_cmd(cmd, output_fn=not_resolved_vcf, run_logger=logger) + + + all_no_resolve = concatenate_files( + [no_resolve, ensembled_preds, not_resolved_vcf], os.path.join(work, "no_resolve.vcf")) logger.info("Merge vcfs") merged_vcf = os.path.join(work, "merged_preds.vcf") @@ -285,11 +299,19 @@ def postprocess(work, reference, pred_vcf_file, output_vcf, candidates_vcf, ense help='long_read indel realign: penalty for opening a gap', default=8) parser.add_argument('--lr_gap_ext_penalty', type=int, help='long_read indel realign: penalty for extending a gap', default=6) + parser.add_argument('--lr_max_realign_dp', type=int, + help='long read max coverage for realign region', default=1000) + parser.add_argument('--lr_do_split', + help='long read split bam for high coverage regions (in variant-calling mode).', + action="store_true") parser.add_argument('--pass_threshold', type=float, help='SCORE for PASS (PASS for score => pass_threshold)', default=0.7) parser.add_argument('--lowqual_threshold', type=float, help='SCORE for LowQual (PASS for lowqual_threshold <= score < pass_threshold)', default=0.4) + parser.add_argument('--filter_duplicate', + help='filter duplicate reads in analysis', + action="store_true") parser.add_argument('--msa_binary', type=str, help='MSA binary', default="../bin/msa") parser.add_argument('--num_threads', type=int, @@ -308,7 +330,10 @@ def postprocess(work, reference, pred_vcf_file, output_vcf, candidates_vcf, ense args.lr_snp_min_af, args.lr_ins_min_af, args.lr_del_min_af, args.lr_match_score, args.lr_mismatch_penalty, args.lr_gap_open_penalty, - args.lr_gap_ext_penalty, args.pass_threshold, args.lowqual_threshold, + args.lr_gap_ext_penalty, args.lr_max_realign_dp, + args.lr_do_split, + args.filter_duplicate, + args.pass_threshold, args.lowqual_threshold, args.msa_binary, args.num_threads) except Exception as e: diff --git a/neusomatic/python/preprocess.py b/neusomatic/python/preprocess.py index 133ae62..6901870 100755 --- a/neusomatic/python/preprocess.py +++ b/neusomatic/python/preprocess.py @@ -22,21 +22,6 @@ from utils import 
concatenate_vcfs, run_bedtools_cmd, bedtools_sort, bedtools_merge, bedtools_intersect, bedtools_slop, get_tmp_file, skip_empty, vcf_2_bed -def split_dbsnp(record): - restart, dbsnp, region_bed, dbsnp_region_vcf = record - thread_logger = logging.getLogger( - "{} ({})".format(split_dbsnp.__name__, multiprocessing.current_process().name)) - try: - if restart or not os.path.exists(dbsnp_region_vcf): - bedtools_intersect( - dbsnp, region_bed, args=" -u", output_fn=dbsnp_region_vcf, run_logger=thread_logger) - return dbsnp_region_vcf - except Exception as ex: - thread_logger.error(traceback.format_exc()) - thread_logger.error(ex) - return None - - def process_split_region(tn, work, region, reference, mode, alignment_bam, dbsnp, scan_window_size, scan_maf, min_mapq, filtered_candidates_vcf, min_dp, max_dp, @@ -44,7 +29,7 @@ def process_split_region(tn, work, region, reference, mode, alignment_bam, dbsnp good_ao, min_ao, snp_min_af, snp_min_bq, snp_min_ao, ins_min_af, del_min_af, del_merge_min_af, ins_merge_min_af, merge_r, - scan_alignments_binary, restart, num_threads, calc_qual, regions=[], dbsnp_regions=[]): + scan_alignments_binary, restart, num_threads, calc_qual, regions=[]): logger = logging.getLogger(process_split_region.__name__) logger.info("Scan bam.") @@ -55,36 +40,12 @@ def process_split_region(tn, work, region, reference, mode, alignment_bam, dbsnp if filtered_candidates_vcf: logger.info("Filter candidates.") if restart or not os.path.exists(filtered_candidates_vcf): - if dbsnp and not dbsnp_regions: - map_args = [] - for raw_vcf, count_bed, split_region_bed in scan_outputs: - dbsnp_region_vcf = os.path.join(os.path.dirname( - os.path.realpath(raw_vcf)), "dbsnp_region.vcf") - map_args.append( - (restart, dbsnp, split_region_bed, dbsnp_region_vcf)) - pool = multiprocessing.Pool(num_threads) - try: - dbsnp_regions = pool.map_async(split_dbsnp, map_args).get() - pool.close() - except Exception as inst: - logger.error(inst) - pool.close() - traceback.print_exc() - raise Exception - - for o in dbsnp_regions: - if o is None: - raise Exception("split_dbsnp failed!") - pool = multiprocessing.Pool(num_threads) map_args = [] for i, (raw_vcf, count_bed, split_region_bed) in enumerate(scan_outputs): filtered_vcf = os.path.join(os.path.dirname( os.path.realpath(raw_vcf)), "filtered_candidates.vcf") - dbsnp_region_vcf = None - if dbsnp: - dbsnp_region_vcf = dbsnp_regions[i] - map_args.append((raw_vcf, filtered_vcf, reference, dbsnp_region_vcf, min_dp, max_dp, good_ao, + map_args.append((raw_vcf, filtered_vcf, reference, dbsnp, min_dp, max_dp, good_ao, min_ao, snp_min_af, snp_min_bq, snp_min_ao, ins_min_af, del_min_af, del_merge_min_af, ins_merge_min_af, merge_r)) try: @@ -110,15 +71,9 @@ def process_split_region(tn, work, region, reference, mode, alignment_bam, dbsnp filtered_vcf = os.path.join(os.path.dirname( os.path.realpath(raw_vcf)), "filtered_candidates.vcf") filtered_candidates_vcfs.append(filtered_vcf) - if dbsnp and not dbsnp_regions: - dbsnp_regions = [] - for raw_vcf, _, _ in scan_outputs: - dbsnp_region_vcf = os.path.join(os.path.dirname( - os.path.realpath(raw_vcf)), "dbsnp_region.vcf") - dbsnp_regions.append(dbsnp_region_vcf) else: filtered_candidates_vcfs = None - return list(map(lambda x: x[1], scan_outputs)), list(map(lambda x: x[2], scan_outputs)), filtered_candidates_vcfs, dbsnp_regions + return list(map(lambda x: x[1], scan_outputs)), list(map(lambda x: x[2], scan_outputs)), filtered_candidates_vcfs def generate_dataset_region(work, truth_vcf, mode, filtered_candidates_vcf, 
region, tumor_count_bed, normal_count_bed, reference, @@ -266,6 +221,16 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, raise Exception( "No normal .bai index file {}".format(normal_bam + ".bai")) + if dbsnp: + if dbsnp[-6:] != "vcf.gz": + logger.error("Aborting!") + raise Exception( + "The dbSNP file should be a tabix indexed file with .vcf.gz format") + if not os.path.exists(dbsnp + ".tbi"): + logger.error("Aborting!") + raise Exception( + "The dbSNP file should be a tabix indexed file with .vcf.gz format. No {}.tbi file exists.".format(dbsnp)) + ensemble_bed = None if ensemble_tsv: ensemble_bed = os.path.join(work, "ensemble.bed") @@ -275,7 +240,6 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, merge_d_for_short_read = 100 candidates_split_regions = [] - dbsnp_regions_q = [] ensemble_beds = [] if not long_read and first_do_without_qual: logger.info("Scan tumor bam (first without quality scores).") @@ -294,8 +258,8 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, ins_min_af, del_min_af, del_merge_min_af, ins_merge_min_af, merge_r, scan_alignments_binary, restart, num_threads, - calc_qual=False, dbsnp_regions=[]) - tumor_counts_without_q, split_regions, filtered_candidates_vcfs_without_q, dbsnp_regions_q = tumor_outputs_without_q + calc_qual=False) + tumor_counts_without_q, split_regions, filtered_candidates_vcfs_without_q = tumor_outputs_without_q if ensemble_tsv: ensemble_beds = get_ensemble_beds( @@ -320,9 +284,8 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, ins_merge_min_af, merge_r, scan_alignments_binary, restart, num_threads, calc_qual=True, - regions=candidates_split_regions, - dbsnp_regions=dbsnp_regions_q) - tumor_counts, split_regions, filtered_candidates_vcfs, _ = tumor_outputs + regions=candidates_split_regions) + tumor_counts, split_regions, filtered_candidates_vcfs = tumor_outputs if ensemble_tsv and not ensemble_beds: ensemble_beds = get_ensemble_beds( @@ -339,17 +302,16 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, if restart or not os.path.exists(work_normal): os.mkdir(work_normal) logger.info("Scan normal bam (and extracting quality scores).") - normal_counts, _, _, _ = process_split_region("normal", work_normal, region_bed, reference, mode, normal_bam, - None, scan_window_size, 0.2, min_mapq, - None, min_dp, max_dp, - filter_duplicate, - good_ao, min_ao, snp_min_af, snp_min_bq, snp_min_ao, - ins_min_af, del_min_af, del_merge_min_af, - ins_merge_min_af, merge_r, - scan_alignments_binary, restart, num_threads, - calc_qual=True, - regions=candidates_split_regions, - dbsnp_regions=[]) + normal_counts, _, _ = process_split_region("normal", work_normal, region_bed, reference, mode, normal_bam, + None, scan_window_size, 0.2, min_mapq, + None, min_dp, max_dp, + filter_duplicate, + good_ao, min_ao, snp_min_af, snp_min_bq, snp_min_ao, + ins_min_af, del_min_af, del_merge_min_af, + ins_merge_min_af, merge_r, + scan_alignments_binary, restart, num_threads, + calc_qual=True, + regions=candidates_split_regions) work_dataset = os.path.join(work, "dataset") if restart or not os.path.exists(work_dataset): @@ -393,7 +355,7 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, parser.add_argument('--work', type=str, help='work directory', required=True) parser.add_argument('--dbsnp_to_filter', type=str, - help='dbsnp vcf (will be used to filter candidate variants)', default=None) + 
help='dbsnp vcf.gz (will be used to filter candidate variants)', default=None) parser.add_argument('--scan_window_size', type=int, help='window size to scan the variants', default=2000) parser.add_argument('--scan_maf', type=float, diff --git a/neusomatic/python/resolve_scores.py b/neusomatic/python/resolve_scores.py index b07853f..65ff505 100755 --- a/neusomatic/python/resolve_scores.py +++ b/neusomatic/python/resolve_scores.py @@ -24,7 +24,7 @@ def resolve_scores(input_bam, ra_vcf, target_vcf, output_vcf): final_intervals = read_tsv_file(tmp_) for x in final_intervals: - x[5] = "0.5" + x[5] = str(np.round(-10*np.log10(0.25),4)) tmp_ = bedtools_window( ra_vcf, target_vcf, args=" -w 5", run_logger=logger) diff --git a/neusomatic/python/resolve_variants.py b/neusomatic/python/resolve_variants.py index 5c4d666..0a672c5 100755 --- a/neusomatic/python/resolve_variants.py +++ b/neusomatic/python/resolve_variants.py @@ -71,7 +71,7 @@ def extract_ins(record): def find_resolved_variants(input_record): - chrom, start, end, variants, input_bam, reference = input_record + chrom, start, end, variants, input_bam, filter_duplicate, reference = input_record thread_logger = logging.getLogger( "{} ({})".format(find_resolved_variants.__name__, multiprocessing.current_process().name)) try: @@ -83,7 +83,7 @@ def find_resolved_variants(input_record): scores = list(map(lambda x: x[5], variants)) if len(set(vartypes)) > 1: out_variants.extend( - list(map(lambda x: [x[0], int(x[1]), x[3], x[4], x[10], x[5]], variants))) + list(map(lambda x: [x[0], int(x[1]), x[3], x[4], x[9].split(":")[0], x[5]], variants))) else: vartype = vartypes[0] score = max(scores) @@ -91,8 +91,9 @@ def find_resolved_variants(input_record): dels = [] with pysam.AlignmentFile(input_bam) as samfile: for record in samfile.fetch(chrom, start, end): - if record.cigarstring and "D" in record.cigarstring: - dels.extend(extract_del(record)) + if not record.is_duplicate or not filter_duplicate: + if record.cigarstring and "D" in record.cigarstring: + dels.extend(extract_del(record)) dels = list(filter(lambda x: ( start <= x[1] <= end) or start <= x[2] <= end, dels)) if dels: @@ -113,20 +114,20 @@ def find_resolved_variants(input_record): f_o.write( "\t".join(map(str, x + [".", "."])) + "\n") new_bed = bedtools_sort(new_bed, run_logger=thread_logger) - new_bed = bedtools_merge( new_bed, args=" -c 1 -o count", run_logger=thread_logger) vs = read_tsv_file(new_bed, fields=range(4)) vs = list(map(lambda x: [x[0], int(x[1]), ref.fetch(x[0], int( - x[1]) - 1, int(x[2])), ref.fetch(x[0], int(x[1]) - 1, int(x[1])), "0/1", score], vs)) + x[1]) - 1, int(x[2])).upper(), ref.fetch(x[0], int(x[1]) - 1, int(x[1])).upper(), "0/1", score], vs)) out_variants.extend(vs) elif vartype == "INS": intervals = [] inss = [] with pysam.AlignmentFile(input_bam) as samfile: for record in samfile.fetch(chrom, start, end): - if record.cigarstring and "I" in record.cigarstring: - inss.extend(extract_ins(record)) + if not record.is_duplicate or not filter_duplicate: + if record.cigarstring and "I" in record.cigarstring: + inss.extend(extract_ins(record)) inss = list(filter(lambda x: ( start <= x[1] <= end) or start <= x[2] <= end, inss)) if inss: @@ -152,7 +153,7 @@ def find_resolved_variants(input_record): new_bed = bedtools_sort(new_bed, run_logger=thread_logger) vs = read_tsv_file(new_bed, fields=range(4)) vs = list(map(lambda x: [x[0], int(x[1]), ref.fetch(x[0], int( - x[1]) - 1, int(x[1])), ref.fetch(x[0], int(x[1]) - 1, int(x[1])) + x[3], "0/1", score], vs)) + x[1]) - 1, 
int(x[1])).upper(), ref.fetch(x[0], int(x[1]) - 1, int(x[1])).upper() + x[3], "0/1", score], vs)) out_variants.extend(vs) return out_variants except Exception as ex: @@ -162,7 +163,7 @@ def find_resolved_variants(input_record): def resolve_variants(input_bam, resolved_vcf, reference, target_vcf_file, - target_bed_file, num_threads): + target_bed_file, filter_duplicate, num_threads): logger = logging.getLogger(resolve_variants.__name__) logger.info("-------Resolve variants (e.g. exact INDEL sequences)-------") @@ -188,7 +189,7 @@ def resolve_variants(input_bam, resolved_vcf, reference, target_vcf_file, chrom, start, end, id_ = tb[0:4] id_ = int(id_) map_args.append([chrom, start, end, variants[id_], - input_bam, reference]) + input_bam, filter_duplicate, reference]) pool = multiprocessing.Pool(num_threads) try: @@ -241,6 +242,9 @@ def resolve_variants(input_bam, resolved_vcf, reference, target_vcf_file, help='resolve target bed', required=True) parser.add_argument('--reference', type=str, help='reference fasta filename', required=True) + parser.add_argument('--filter_duplicate', + help='filter duplicate reads in analysis', + action="store_true") parser.add_argument('--num_threads', type=int, help='number of threads', default=1) args = parser.parse_args() @@ -248,7 +252,8 @@ def resolve_variants(input_bam, resolved_vcf, reference, target_vcf_file, try: resolve_variants(args.input_bam, args.resolved_vcf, args.reference, args.target_vcf, - args.target_bed, args.num_threads) + args.target_bed, args.filter_duplicate, + args.num_threads) except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") diff --git a/neusomatic/python/train.py b/neusomatic/python/train.py index 60c7510..0aa72e3 100755 --- a/neusomatic/python/train.py +++ b/neusomatic/python/train.py @@ -8,6 +8,7 @@ import argparse import datetime import logging +import copy import numpy as np import torch @@ -87,6 +88,10 @@ def test(net, epoch, validation_loader, use_cuda): for data in validation_loader: (matrices, labels, _, var_len_s, _), (paths) = data + paths_ = copy.deepcopy(paths) + del paths + paths = paths_ + matrices = Variable(matrices) if use_cuda: matrices = matrices.cuda() @@ -382,7 +387,7 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo net.train() len_train_set = sum(none_counts) + sum(var_counts) - logger.info("Number of candidater per epoch: {}".format(len_train_set)) + logger.info("Number of candidates per epoch: {}".format(len_train_set)) print_freq = max(1, int(len_train_set / float(batch_size) / 4.0)) curr_epoch = prev_epochs torch.save({"state_dict": net.state_dict(),
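# ---------------------------------------------------------------------------
# Reviewer notes (illustrative sketch, not part of the patch above).
#
# 1) The new tabix-based dbSNP filter in filter_candidates.py keys candidates
#    as "chrom-pos-ref-alt" with 0-based positions (final_records stores
#    pos - 1), while VCF/tabix rows are 1-based, hence the int(pos_) - 1 shift
#    when building the dbSNP key. A minimal sketch of the intended lookup;
#    the helper name in_dbsnp is hypothetical, only pysam.TabixFile and its
#    fetch() are real API.

import pysam


def in_dbsnp(dbsnp_tb, chrom, pos_0based, ref, alt):
    # Membership test mirroring the patched loop in filter_candidates.py.
    var_id = "-".join(map(str, [chrom, pos_0based, ref, alt]))
    region = "{}:{}-{}".format(chrom, pos_0based, pos_0based + 1)
    # TabixFile.fetch raises ValueError if chrom is absent from the index.
    for row in dbsnp_tb.fetch(region=region):
        chrom_, pos_, _, ref_, alts_ = row.strip().split("\t")[0:5]
        for alt_ in alts_.split(","):
            # Shift the 1-based VCF position onto the 0-based candidate key.
            if var_id == "-".join(map(str, [chrom_, int(pos_) - 1, ref_, alt_])):
                return True
    return False

# Usage sketch, assuming a bgzipped and tabix-indexed dbsnp.vcf.gz:
#     tb = pysam.TabixFile("dbsnp.vcf.gz")
#     in_dbsnp(tb, "chr1", 12344, "A", "G")  # chr1:12345 A>G in 1-based terms
#
# 2) On the generate_dataset.py change: resize((matrix_width, 5), 2) passes 2
#    as the resample filter, i.e. PIL.Image.BILINEAR, which matches the
#    default interpolation of the removed scipy.misc.imresize. Unlike
#    imresize, PIL's resize does not rescale array values to the 0..255 uint8
#    range, so outputs can differ if the input matrices are not already
#    scaled to that range.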