Skip to content

Commit

Permalink
Add utility options for memory reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
prihoda committed May 8, 2020
1 parent 6bd53ec commit 7e62864
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 18 deletions.
8 changes: 6 additions & 2 deletions test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def test_usum_umap(tmpdir, mocker):
os.path.join(tmpdir, 'input.fa'),
os.path.join(tmpdir, 'distance.txt'),
maxdist=0.3,
termdist=0.5
termdist=0.5,
threads=None,
dbstep=None
)

assert os.path.exists(tmpdir)
Expand All @@ -72,7 +74,9 @@ def test_usum_tsne(tmpdir, mocker):
os.path.join(tmpdir, 'input.fa'),
os.path.join(tmpdir, 'distance.txt'),
maxdist=0.3,
termdist=0.5
termdist=0.5,
threads=None,
dbstep=None
)

assert os.path.exists(tmpdir)
Expand Down
90 changes: 74 additions & 16 deletions usum/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from umap.plot import _matplotlib_points, _themes, _select_font_color, _datashade_points
from subprocess import CalledProcessError

def main(argv=None):
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
Expand All @@ -30,9 +31,12 @@ def main(argv=None):

parser.add_argument("--limit", type=int, help="Use random number of records from each input file.")
parser.add_argument("--seed", default=1, type=int, help="Random seed for input subsampling and UMAP.")

parser.add_argument("--trim-length", type=int, help="Trim sequences longer than given length.")

parser.add_argument("--maxdist", type=float, help="USEARCH: Maximum distance which should be written (required if not using --resume).")
parser.add_argument("--termdist", type=float, default=1.0, help="USEARCH: Identity threshold for terminating the calculation. This should be set higher than maxdist.")
parser.add_argument("--threads", type=int, default=None, help="USEARCH: Number of threads to use. Use --threads 1 to reduce memory use.")
parser.add_argument("--dbstep", type=int, default=None, help="USEARCH: Index only every Nth word to reduce indexing memory.")

parser.add_argument("--umap-min-dist", type=float, default=0.1, help="UMAP: Effective minimum distance between embedded points, relative to spread.")
parser.add_argument("--umap-spread", type=float, default=1.0, help="UMAP: The effective scale of embedded points.")
Expand All @@ -43,6 +47,7 @@ def main(argv=None):
parser.add_argument("--height", type=int, default=800, help="Plot height in pixels.")

parser.add_argument("--tsne", action="store_true", default=False, help="Run t-SNE instead of UMAP.")
parser.add_argument("--stats", action="store_true", default=False, help="Create sequence stats plots.")

options = parser.parse_args()

Expand All @@ -57,29 +62,34 @@ def main(argv=None):
output=options.output,
maxdist=options.maxdist,
termdist=options.termdist,
threads=options.threads,
dbstep=options.dbstep,
labels=options.labels,
force=options.force,
resume=options.resume,
limit=options.limit,
random_state=options.seed,
trim_length=options.trim_length,
umap_min_dist=options.umap_min_dist,
umap_spread=options.umap_spread,
neighbors=options.neighbors,
method='tsne' if options.tsne else 'umap',
stats=options.stats,
theme=options.theme,
width=options.width,
height=options.height
)
except UsumError as e:
print('===')
print(str(e), file=sys.stderr)
sys.exit(2)

class UsumError(Exception):
    """Error raised by usum for problems that should be reported to the user.

    ``main`` catches this exception, prints its message to stderr and exits
    with status 2.
    """

def usum(
inputs, output, maxdist=None, termdist=1.0,
labels=None, force=False, resume=False, limit=None, random_state=1,
inputs, output, maxdist=None, termdist=1.0, threads=None, dbstep=None,
labels=None, force=False, resume=False, limit=None, random_state=1, trim_length=None, stats=False,
umap_min_dist=0.1, umap_spread=1.0, method='umap', neighbors=15, theme='fire', width=800, height=800
):
"""
Expand All @@ -88,21 +98,25 @@ def usum(
:param output: output directory path.
:param maxdist: USEARCH: Maximum distance which should be written.
:param termdist: USEARCH: Identity threshold for terminating the calculation. This should be set higher than maxdist.
:param threads: USEARCH: Number of threads to use. Use threads=1 to reduce memory use.
:param dbstep: USEARCH: Index only every Nth word to reduce indexing memory.
:param labels: Input file labels. If not provided, file names without extension will be used.
:param force: Force overwrite output.
:param resume: Resume using existing results from output directory.
:param limit: Use random number of records from each input file.
:param random_state: Random seed for input subsampling and UMAP.
:param trim_length: Trim sequences longer than given length
:param stats: Create sequence stats plots
:param umap_min_dist: UMAP: Effective minimum distance between embedded points
:param umap_spread: UMAP: The effective scale of embedded points
:param neighbors: UMAP: The size of local neighborhood.
:param neighbors: UMAP, t-SNE: The size of local neighborhood (perplexity).
:param method: Embedding method (umap, tsne).
:param theme: Plot color theme.
:param width: Plot width in pixels.
:param height: Plot height in pixels.
:return: tuple with UMAP reducer and sequence DataFrame
"""
if not force and not resume and (os.path.exists(output) and (os.listdir(output) or not os.path.isdir(output))):
if not force and not resume and (os.path.exists(output) and os.listdir(output)):
raise UsumError(f'Output path exists and is not empty: {output}. \nRun with -f --force to overwrite or --resume to resume')

if labels:
Expand Down Expand Up @@ -142,21 +156,40 @@ def usum(
# Remove sequence TSV file to avoid reusing incomplete result
os.remove(index_path)

index = save_input_fasta(inputs, labels, fasta_path, limit=limit, random=(limit is not None), random_state=random_state)
index = save_input_fasta(inputs, labels, fasta_path, limit=limit, random=(limit is not None), random_state=random_state, trim_length=trim_length)
print(f'Saved {len(index):,} records to: {fasta_path}')

if stats:
print(f'\n> Plotting sequence stats...')
seq_length_path = os.path.join(output, 'seq_length.png')
seq_length = pd.Series(len(r.seq) for r in SeqIO.parse(fasta_path, 'fasta'))
ax = seq_length.plot.hist(bins=20, figsize=(10, 4))
ax.set_xlabel('Sequence length')
ax.axvline(seq_length.median(), ls=':', color='grey')
ax.figure.savefig(seq_length_path, bbox_inches='tight')
print(f'Saved sequence length histogram to: {seq_length_path}')

print(f'\n> Creating sparse {len(index):,} x {len(index):,} distance matrix with {maxdist} max distance...')
if len(index) > 10000 and not limit:
print('NOTE: This might take some time. Consider using --limit to compare just a random subset.')
run_usearch(fasta_path, distance_path, maxdist=maxdist, termdist=termdist)

run_usearch(fasta_path, distance_path, maxdist=maxdist, termdist=termdist, threads=threads, dbstep=dbstep)


dist_matrix = load_sparse_dist_matrix(distance_path)
# with {dist_matrix.getnnx():,} non-null values
print(f'Loaded {len(index):,} x {len(index):,} distance matrix ({int(dist_matrix.nbytes / 1024 / 1024)} MB)')

#if stats:
#seq_dist_path = os.path.join(output, 'seq_dist.png')
#seq_dist = pd.Series(dist_matrix.flatten())
#ax = seq_dist.plot.hist(bins=20, figsize=(10, 4))
#ax.set_xlabel('Pairwise distance (without self-distances)')
#ax.figure.savefig(seq_dist_path, bbox_inches='tight')
#print(f'Saved sequence distance histogram to: {seq_dist_path}')

if method == 'tsne':
print(f'\n> Creating t-SNE embedding...')
reducer, embedding = fit_tsne(dist_matrix, random_state=random_state)
reducer, embedding = fit_tsne(dist_matrix, random_state=random_state, perplexity=neighbors)

index['tsne1'] = embedding[:,0]
index['tsne2'] = embedding[:,1]
Expand Down Expand Up @@ -211,28 +244,43 @@ def iterate_fasta(path, limit=None, random=False, random_state=None):

return records

def save_input_fasta(inputs, labels, fasta_path, random=True, limit=None, random_state=None, trim_length=None):
    """
    Merge records from all input files into one FASTA file with sequential numeric IDs.

    Each record obtained from ``iterate_fasta`` is renamed to its running index
    (0, 1, 2, ... across all inputs) before being written, and its original ID and
    description are kept in the returned index table.

    :param inputs: Input file paths, paired positionally with ``labels``.
    :param labels: Label stored in the index for every record of the matching input file.
    :param fasta_path: Path of the merged FASTA file to write (overwritten if it exists).
    :param random: Passed through to ``iterate_fasta``.
    :param limit: Passed through to ``iterate_fasta``.
    :param random_state: Passed through to ``iterate_fasta``.
    :param trim_length: Trim sequences longer than given length. Falsy values (None, 0) disable trimming.
    :return: pandas DataFrame with one row per written record and columns
        ``index``, ``label``, ``seq_id`` and ``description`` (empty DataFrame if no records were read).
    """
    # Only the first few trimmed records are listed in the summary message.
    max_reported = 10

    index = []
    trimmed = []
    num_trimmed = 0
    with open(fasta_path, 'w') as f:
        for path, label in zip(inputs, labels):
            for record in iterate_fasta(path, limit=limit, random=random, random_state=random_state):
                # One index row is appended per record, so the row count is the running record number.
                i = len(index)
                # Remember the original ID before it is replaced by the numeric one.
                record_id = record.id
                index.append(OrderedDict(
                    index=i,
                    label=label,
                    seq_id=record_id,
                    description=record.description
                ))
                record.id = str(i)
                record.description = ''
                record.name = record.id
                if trim_length and len(record.seq) > trim_length:
                    if len(trimmed) < max_reported:
                        trimmed.append(f'"{record_id}" ({len(record.seq)} chars)')
                    record.seq = record.seq[:trim_length]
                    num_trimmed += 1
                SeqIO.write(record, f, 'fasta')

    if num_trimmed:
        trimmed_label = ", ".join(trimmed)
        if num_trimmed > len(trimmed):
            # More records were trimmed than are listed explicitly.
            trimmed_label += '...'
        print(f'Trimmed {num_trimmed} sequences to {trim_length} chars: {trimmed_label}')

    return pd.DataFrame(index)


def run_usearch(fasta_path, distance_path, maxdist, termdist=1.0):
def run_usearch(fasta_path, distance_path, maxdist, termdist=1.0, threads=None, dbstep=None):
if not shutil.which('usearch'):
raise UsumError('Missing usearch dependency on PATH. \nInstall it from: https://drive5.com/usearch/download.html')

Expand All @@ -243,9 +291,19 @@ def run_usearch(fasta_path, distance_path, maxdist, termdist=1.0):
'-maxdist', str(maxdist),
'-termdist', str(termdist)
]
if threads:
cmd += ['-threads', str(threads)]
if dbstep:
cmd += ['-dbstep', str(dbstep)]
print(f'Running USEARCH command: {" ".join(cmd)}')
subprocess.check_output(cmd)

try:
subprocess.check_output(cmd)
except CalledProcessError:
raise UsumError('Error calling USEARCH. '
'\nTips to reduce memory use:'
'\n - Use --trim-length to trim long sequences'
'\n - Use --threads 1 to avoid memory duplication between threads (slow)'
'\n - If indexing fails, use --dbstep 2 or higher to index only each Nth word')

def load_sparse_dist_matrix(distance_path):
dist_matrix = pd.read_csv(distance_path, header=None, sep='\t')
Expand All @@ -260,7 +318,6 @@ def load_sparse_dist_matrix(distance_path):


def fit_umap(dist_matrix, random_state=None, neighbors=15, min_dist=0.1, spread=1.0):
print(dist_matrix)
reducer = umap.UMAP(
n_neighbors=neighbors,
random_state=random_state,
Expand All @@ -272,8 +329,9 @@ def fit_umap(dist_matrix, random_state=None, neighbors=15, min_dist=0.1, spread=
return reducer, embedding


def fit_tsne(dist_matrix, random_state=None):
def fit_tsne(dist_matrix, random_state=None, perplexity=30):
reducer = TSNE(
perplexity=perplexity,
random_state=random_state,
metric='precomputed'
)
Expand Down

0 comments on commit 7e62864

Please sign in to comment.