From 7e628640618ba6f8f8a73edfbec0fbc2cff40a63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20P=C5=99=C3=ADhoda?= Date: Fri, 8 May 2020 15:29:59 +0200 Subject: [PATCH] Add utility options for memory reduction --- test/test_main.py | 8 +++-- usum/main.py | 90 ++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 80 insertions(+), 18 deletions(-) diff --git a/test/test_main.py b/test/test_main.py index 5d63fd2..9cab9a9 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -47,7 +47,9 @@ def test_usum_umap(tmpdir, mocker): os.path.join(tmpdir, 'input.fa'), os.path.join(tmpdir, 'distance.txt'), maxdist=0.3, - termdist=0.5 + termdist=0.5, + threads=None, + dbstep=None ) assert os.path.exists(tmpdir) @@ -72,7 +74,9 @@ def test_usum_tsne(tmpdir, mocker): os.path.join(tmpdir, 'input.fa'), os.path.join(tmpdir, 'distance.txt'), maxdist=0.3, - termdist=0.5 + termdist=0.5, + threads=None, + dbstep=None ) assert os.path.exists(tmpdir) diff --git a/usum/main.py b/usum/main.py index 2507797..6cc6a3f 100644 --- a/usum/main.py +++ b/usum/main.py @@ -19,6 +19,7 @@ from sklearn.manifold import TSNE from matplotlib import pyplot as plt from umap.plot import _matplotlib_points, _themes, _select_font_color, _datashade_points +from subprocess import CalledProcessError def main(argv=None): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -30,9 +31,12 @@ def main(argv=None): parser.add_argument("--limit", type=int, help="Use random number of records from each input file.") parser.add_argument("--seed", default=1, type=int, help="Random seed for input subsampling and UMAP.") - + parser.add_argument("--trim-length", type=int, help="Trim sequences longer than given length.") + parser.add_argument("--maxdist", type=float, help="USEARCH: Maximum distance which should be written (required if not using --resume).") parser.add_argument("--termdist", type=float, default=1.0, help="USEARCH: Identity threshold for terminating the calculation. This should be set higher than maxdist.") + parser.add_argument("--threads", type=int, default=None, help="USEARCH: Number of threads to use. Use --threads 1 to reduce memory use.") + parser.add_argument("--dbstep", type=int, default=None, help="USEARCH: Index only every Nth word to reduce indexing memory.") parser.add_argument("--umap-min-dist", type=float, default=0.1, help="UMAP: Effective minimum distance between embedded points, relative to spread.") parser.add_argument("--umap-spread", type=float, default=1.0, help="UMAP: The effective scale of embedded points.") @@ -43,6 +47,7 @@ def main(argv=None): parser.add_argument("--height", type=int, default=800, help="Plot height in pixels.") parser.add_argument("--tsne", action="store_true", default=False, help="Run t-SNE instead of UMAP.") + parser.add_argument("--stats", action="store_true", default=False, help="Create sequence stats plots.") options = parser.parse_args() @@ -57,20 +62,25 @@ def main(argv=None): output=options.output, maxdist=options.maxdist, termdist=options.termdist, + threads=options.threads, + dbstep=options.dbstep, labels=options.labels, force=options.force, resume=options.resume, limit=options.limit, random_state=options.seed, + trim_length=options.trim_length, umap_min_dist=options.umap_min_dist, umap_spread=options.umap_spread, neighbors=options.neighbors, method='tsne' if options.tsne else 'umap', + stats=options.stats, theme=options.theme, width=options.width, height=options.height ) except UsumError as e: + print('===') print(str(e), file=sys.stderr) sys.exit(2) @@ -78,8 +88,8 @@ class UsumError(Exception): pass def usum( - inputs, output, maxdist=None, termdist=1.0, - labels=None, force=False, resume=False, limit=None, random_state=1, + inputs, output, maxdist=None, termdist=1.0, threads=None, dbstep=None, + labels=None, force=False, resume=False, limit=None, random_state=1, trim_length=None, stats=False, umap_min_dist=0.1, umap_spread=1.0, method='umap', neighbors=15, theme='fire', width=800, height=800 ): """ @@ -88,21 +98,25 @@ def usum( :param output: output directory path. :param maxdist: USEARCH: Maximum distance which should be written. :param termdist: USEARCH: Identity threshold for terminating the calculation. This should be set higher than maxdist. + :param threads: USEARCH: Number of threads to use. Use threads=1 to reduce memory use. + :param dbstep: USEARCH: Index only every Nth word to reduce indexing memory. :param labels: Input file labels. If not provided, file names without extension will be used. :param force: Force overwrite output. :param resume: Resume using existing results from output directory. :param limit: Use random number of records from each input file. :param random_state: Random seed for input subsampling and UMAP. + :param trim_length: Trim sequences longer than given length + :param stats: Create sequence stats plots :param umap_min_dist: UMAP: Effective minimum distance between embedded points :param umap_spread: UMAP: The effective scale of embedded points - :param neighbors: UMAP: The size of local neighborhood. + :param neighbors: UMAP, t-SNE: The size of local neighborhood (perplexity). :param method: Embedding method (umap, tsne). :param theme: Plot color theme. :param width: Plot width in pixels. :param height: Plot height in pixels. :return: tuple with UMAP reducer and sequence DataFrame """ - if not force and not resume and (os.path.exists(output) and (os.listdir(output) or not os.path.isdir(output))): + if not force and not resume and (os.path.exists(output) and os.listdir(output)): raise UsumError(f'Output path exists and is not empty: {output}. \nRun with -f --force to overwrite or --resume to resume') if labels: @@ -142,21 +156,40 @@ def usum( # Remove sequence TSV file to avoid reusing incomplete result os.remove(index_path) - index = save_input_fasta(inputs, labels, fasta_path, limit=limit, random=(limit is not None), random_state=random_state) + index = save_input_fasta(inputs, labels, fasta_path, limit=limit, random=(limit is not None), random_state=random_state, trim_length=trim_length) print(f'Saved {len(index):,} records to: {fasta_path}') + if stats: + print(f'\n> Plotting sequence stats...') + seq_length_path = os.path.join(output, 'seq_length.png') + seq_length = pd.Series(len(r.seq) for r in SeqIO.parse(fasta_path, 'fasta')) + ax = seq_length.plot.hist(bins=20, figsize=(10, 4)) + ax.set_xlabel('Sequence length') + ax.axvline(seq_length.median(), ls=':', color='grey') + ax.figure.savefig(seq_length_path, bbox_inches='tight') + print(f'Saved sequence length histogram to: {seq_length_path}') + print(f'\n> Creating sparse {len(index):,} x {len(index):,} distance matrix with {maxdist} max distance...') if len(index) > 10000 and not limit: print('NOTE: This might take some time. Consider using --limit to compare just a random subset.') - run_usearch(fasta_path, distance_path, maxdist=maxdist, termdist=termdist) - + run_usearch(fasta_path, distance_path, maxdist=maxdist, termdist=termdist, threads=threads, dbstep=dbstep) + + dist_matrix = load_sparse_dist_matrix(distance_path) # with {dist_matrix.getnnx():,} non-null values print(f'Loaded {len(index):,} x {len(index):,} distance matrix ({int(dist_matrix.nbytes / 1024 / 1024)} MB)') + #if stats: + #seq_dist_path = os.path.join(output, 'seq_dist.png') + #seq_dist = pd.Series(dist_matrix.flatten()) + #ax = seq_dist.plot.hist(bins=20, figsize=(10, 4)) + #ax.set_xlabel('Pairwise distance (without self-distances)') + #ax.figure.savefig(seq_dist_path, bbox_inches='tight') + #print(f'Saved sequence distance histogram to: {seq_dist_path}') + if method == 'tsne': print(f'\n> Creating t-SNE embedding...') - reducer, embedding = fit_tsne(dist_matrix, random_state=random_state) + reducer, embedding = fit_tsne(dist_matrix, random_state=random_state, perplexity=neighbors) index['tsne1'] = embedding[:,0] index['tsne2'] = embedding[:,1] @@ -211,28 +244,43 @@ def iterate_fasta(path, limit=None, random=False, random_state=None): return records -def save_input_fasta(inputs, labels, fasta_path, random=True, limit=None, random_state=None): +def save_input_fasta(inputs, labels, fasta_path, random=True, limit=None, random_state=None, trim_length=None): i = 0 index = [] + trimmed = [] + num_trimmed = 0 with open(fasta_path, 'w') as f: for path, label in zip(inputs, labels): for record in iterate_fasta(path, limit=limit, random=random, random_state=random_state): + record_id = record.id index.append(OrderedDict( index=i, label=label, - seq_id=record.id + seq_id=record_id, + description=record.description )) record.id = str(i) record.description = '' record.name = record.id + if trim_length and len(record.seq) > trim_length: + if len(trimmed) < 10: + trimmed.append(f'"{record_id}" ({len(record.seq)} chars)') + record.seq = record.seq[:trim_length] + num_trimmed += 1 SeqIO.write(record, f, 'fasta') i += 1 + if num_trimmed: + trimmed_label = ", ".join(trimmed) + if num_trimmed > len(trimmed): + trimmed_label += '...' + print(f'Trimed {num_trimmed} sequences to {trim_length} chars: {trimmed_label}') + index = pd.DataFrame(index) return index -def run_usearch(fasta_path, distance_path, maxdist, termdist=1.0): +def run_usearch(fasta_path, distance_path, maxdist, termdist=1.0, threads=None, dbstep=None): if not shutil.which('usearch'): raise UsumError('Missing usearch dependency on PATH. \nInstall it from: https://drive5.com/usearch/download.html') @@ -243,9 +291,19 @@ def run_usearch(fasta_path, distance_path, maxdist, termdist=1.0): '-maxdist', str(maxdist), '-termdist', str(termdist) ] + if threads: + cmd += ['-threads', str(threads)] + if dbstep: + cmd += ['-dbstep', str(dbstep)] print(f'Running USEARCH command: {" ".join(cmd)}') - subprocess.check_output(cmd) - + try: + subprocess.check_output(cmd) + except CalledProcessError: + raise UsumError('Error calling USEARCH. ' + '\nTips to reduce memory use:' + '\n - Use --trim-length to trim long sequences' + '\n - Use --threads 1 to avoid memory duplication between threads (slow)' + '\n - If indexing fails, use --dbstep 2 or higher to index only each Nth word') def load_sparse_dist_matrix(distance_path): dist_matrix = pd.read_csv(distance_path, header=None, sep='\t') @@ -260,7 +318,6 @@ def load_sparse_dist_matrix(distance_path): def fit_umap(dist_matrix, random_state=None, neighbors=15, min_dist=0.1, spread=1.0): - print(dist_matrix) reducer = umap.UMAP( n_neighbors=neighbors, random_state=random_state, @@ -272,8 +329,9 @@ def fit_umap(dist_matrix, random_state=None, neighbors=15, min_dist=0.1, spread= return reducer, embedding -def fit_tsne(dist_matrix, random_state=None): +def fit_tsne(dist_matrix, random_state=None, perplexity=30): reducer = TSNE( + perplexity=perplexity, random_state=random_state, metric='precomputed' )