Skip to content

Commit

Permalink
Add UCE method (#7)
Browse files Browse the repository at this point in the history
* Create UCE component files

* Add UCE dataset preprocessing

* Generate UCE indexes

* Evaluate UCE model and output results

* Add UCE to benchmark workflow

* Fix UCE model path in benchmark workflow

* Copy UCE files to working directory for Nextflow

* Exclude UCE in local benchmark scripts

Requires more memory than allowed by the local labels config

* Style UCE script
  • Loading branch information
lazappi authored Nov 8, 2024
1 parent 2158882 commit 7e0b381
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 1 deletion.
1 change: 1 addition & 0 deletions scripts/run_benchmark/run_full_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ input_states: resources/datasets/**/state.yaml
rename_keys: 'input_dataset:output_dataset;input_solution:output_solution'
output_state: "state.yaml"
publish_dir: "$publish_dir"
settings: '{"methods_exclude": ["uce"]}'
HERE

# run the benchmark
Expand Down
1 change: 1 addition & 0 deletions scripts/run_benchmark/run_test_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ input_states: resources_test/task_batch_integration/**/state.yaml
rename_keys: 'input_dataset:output_dataset;input_solution:output_solution'
output_state: "state.yaml"
publish_dir: "$publish_dir"
settings: '{"methods_exclude": ["uce"]}'
HERE

nextflow run . \
Expand Down
45 changes: 45 additions & 0 deletions src/methods/uce/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Viash component config for the UCE batch-integration method.
# Inherits the common method interface from the shared base config.
__merge__: ../../api/base_method.yaml

name: uce
label: UCE
summary: UCE offers a unified biological latent space that can represent any cell
description: |
  Universal Cell Embedding (UCE) is a single-cell foundation model that offers a
  unified biological latent space that can represent any cell, regardless of
  tissue or species
references:
  doi:
    - 10.1101/2023.11.28.568918
links:
  documentation: https://github.com/snap-stanford/UCE/blob/main/README.md
  repository: https://github.com/snap-stanford/UCE

info:
  # UCE produces a joint embedding (rather than corrected features/graph)
  method_types: [embedding]
  # The model operates on raw counts, not normalized expression
  preferred_normalization: counts

arguments:
  - name: --model
    type: file
    description: Path to the directory containing UCE model files or a .zip/.tar.gz archive
    required: true

resources:
  - type: python_script
    path: script.py
  # Shared helper for partially reading AnnData files
  - path: /src/utils/read_anndata_partial.py

engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    setup:
      - type: python
        pypi:
          - accelerate==0.24.0
      # UCE is not packaged on PyPI; clone the repository into the image instead
      - type: docker
        run: "git clone https://github.com/snap-stanford/UCE.git"
runners:
  - type: executable
  - type: nextflow
    directives:
      label: [midtime, midmem, midcpu, gpu]
211 changes: 211 additions & 0 deletions src/methods/uce/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import os
import pickle
import sys
import tarfile
import tempfile
import zipfile
from argparse import Namespace

import anndata as ad
import numpy as np
import pandas as pd
import torch
from accelerate import Accelerator

# Code has hardcoded paths that only work correctly inside the UCE directory
if os.path.isdir("UCE"):
    # For executable we can work inside the UCE directory
    os.chdir("UCE")
else:
    # For Nextflow we need to copy files to the Nextflow working directory
    # (the Docker image clones UCE into /workspace/UCE at build time)
    print(">>> Copying UCE files to local directory...", flush=True)
    import shutil

    shutil.copytree("/workspace/UCE", ".", dirs_exist_ok=True)

# Append current directory to import UCE functions
sys.path.append(".")
from data_proc.data_utils import (
    adata_path_to_prot_chrom_starts,
    get_spec_chrom_csv,
    get_species_to_pe,
    process_raw_anndata,
)
from evaluate import run_eval

## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
par = {
    "input": "resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad",
    "output": "output.h5ad",
}
meta = {"name": "uce"}
## VIASH END

print(">>> Reading input...", flush=True)
# read_anndata() comes from the helper copied alongside this script as a resource
sys.path.append(meta["resources_dir"])
from read_anndata_partial import read_anndata

# Load only what is needed: raw counts (as X), obs, var and uns
adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")

# Map the dataset organism to the species names used by the UCE model files;
# only human and mouse are supported here
if adata.uns["dataset_organism"] == "homo_sapiens":
    species = "human"
elif adata.uns["dataset_organism"] == "mus_musculus":
    species = "mouse"
else:
    raise ValueError(f"Species '{adata.uns['dataset_organism']}' not yet implemented")

print("\n>>> Creating working directory...", flush=True)
# Temporary directory for all intermediate UCE files; cleaned up at script end
work_dir = tempfile.TemporaryDirectory()
print(f"Working directory: '{work_dir.name}'", flush=True)

print("\n>>> Getting model files...", flush=True)
# Resolve the model files: --model may be a ready-made directory or a
# .zip / .tar.gz archive that we extract into a temporary directory
if os.path.isdir(par["model"]):
    model_temp = None
    model_dir = par["model"]
else:
    model_temp = tempfile.TemporaryDirectory()
    model_dir = model_temp.name

    if zipfile.is_zipfile(par["model"]):
        print("Extracting UCE model from .zip...", flush=True)
        # NOTE(review): extractall() on an untrusted archive can write outside
        # model_dir (path traversal); acceptable here for trusted model caches
        with zipfile.ZipFile(par["model"], "r") as zip_file:
            zip_file.extractall(model_dir)
    elif tarfile.is_tarfile(par["model"]) and par["model"].endswith(".tar.gz"):
        print("Extracting model from .tar.gz...", flush=True)
        with tarfile.open(par["model"], "r:gz") as tar_file:
            tar_file.extractall(model_dir)
        # .tar.gz archives are assumed to contain a single top-level directory
        model_dir = os.path.join(model_dir, os.listdir(model_dir)[0])
    else:
        # Fixed: removed pointless f-string prefix and garbled wording
        # ("a directory a .zip file") in the error message
        raise ValueError(
            "The 'model' argument should be a directory, a .zip file or a .tar.gz file"
        )

print(f"Model directory: '{model_dir}'", flush=True)

print("Extracting protein embeddings...", flush=True)
# Protein embeddings ship inside the model directory as a nested archive
with tarfile.open(
    os.path.join(model_dir, "protein_embeddings.tar.gz"), "r:gz"
) as tar_file:
    tar_file.extractall("./model_files")
protein_embeddings_dir = os.path.join("./model_files", "protein_embeddings")
print(f"Protein embeddings directory: '{protein_embeddings_dir}'", flush=True)

# The following sections implement methods in the UCE.evaluate.AnndataProcessor
# class due to the object not being compatible with the Open Problems setup
model_args = {
    # Trailing slash — presumably UCE concatenates this with file names directly
    "dir": work_dir.name + "/",
    "skip": True,
    "filter": False,  # Turn this off to get embedding for all cells
    "name": "input",  # Dataset name; determines intermediate file prefixes
    "offset_pkl_path": os.path.join(model_dir, "species_offsets.pkl"),
    "spec_chrom_csv_path": os.path.join(model_dir, "species_chrom.csv"),
    "pe_idx_path": os.path.join(work_dir.name, "input_pe_row_idxs.pt"),
    "chroms_path": os.path.join(work_dir.name, "input_chroms.pkl"),
    "starts_path": os.path.join(work_dir.name, "input_starts.pkl"),
}

# AnndataProcessor.preprocess_anndata()
print("\n>>> Preprocessing data...", flush=True)
# Set var names to gene symbols
adata.var_names = adata.var["feature_name"]
adata.write_h5ad(os.path.join(model_args["dir"], "input.h5ad"))

# Build the per-dataset record process_raw_anndata() expects.
# NOTE(review): attributes are set on an empty Series (attribute access works,
# but they are not stored as Series items) — confirm this matches UCE's usage
row = pd.Series()
row.path = "input.h5ad"
row.covar_col = np.nan
row.species = species

# Filters/processes the written h5ad and returns the processed AnnData
# along with its cell and gene counts (used later for shapes_dict)
processed_adata, num_cells, num_genes = process_raw_anndata(
    row=row,
    h5_folder_path=model_args["dir"],
    npz_folder_path=model_args["dir"],
    scp="",
    skip=model_args["skip"],
    additional_filter=model_args["filter"],
    root=model_args["dir"],
)

# AnndataProcessor.generate_idxs()
print("\n>>> Generating indexes...", flush=True)
# Per the UCE helper names: map each gene to its protein-embedding row,
# chromosome and genomic start, then save the lookups where run_eval() expects
# them (paths configured in model_args above)
species_to_pe = get_species_to_pe(protein_embeddings_dir)
with open(model_args["offset_pkl_path"], "rb") as f:
    species_to_offsets = pickle.load(f)
gene_to_chrom_pos = get_spec_chrom_csv(model_args["spec_chrom_csv_path"])
spec_pe_genes = list(species_to_pe[species].keys())
offset = species_to_offsets[species]
pe_row_idxs, dataset_chroms, dataset_pos = adata_path_to_prot_chrom_starts(
    processed_adata, species, spec_pe_genes, gene_to_chrom_pos, offset
)
# Each lookup is keyed by the dataset name ("input")
torch.save({model_args["name"]: pe_row_idxs}, model_args["pe_idx_path"])
with open(model_args["chroms_path"], "wb+") as f:
    pickle.dump({model_args["name"]: dataset_chroms}, f)
with open(model_args["starts_path"], "wb+") as f:
    pickle.dump({model_args["name"]: dataset_pos}, f)

# AnndataProcessor.run_evaluation()
print("\n>>> Evaluating model...", flush=True)
# Model hyperparameters for the UCE checkpoint.
# NOTE(review): presumably these mirror the defaults of UCE's evaluation CLI
# for the 33-layer model — confirm against UCE's eval_single_anndata.py
model_parameters = Namespace(
    token_dim=5120,
    d_hid=5120,
    nlayers=33,  # Small model = 4, full model = 33
    output_dim=1280,
    multi_gpu=False,
    token_file=os.path.join(model_dir, "all_tokens.torch"),
    dir=model_args["dir"],
    pad_length=1536,
    sample_size=1024,
    cls_token_idx=3,
    CHROM_TOKEN_OFFSET=143574,
    chrom_token_right_idx=2,
    chrom_token_left_idx=1,
    pad_token_idx=0,
)

# Select the checkpoint file and batch size matching the chosen model depth
if model_parameters.nlayers == 4:
    model_parameters.model_loc = os.path.join(model_dir, "4layer_model.torch")
    model_parameters.batch_size = 100
else:
    model_parameters.model_loc = os.path.join(model_dir, "33l_8ep_1024t_1280.torch")
    model_parameters.batch_size = 25

# run_eval() uses accelerate for (optional multi-)GPU execution and writes
# its output h5ad into model_args["dir"]
accelerator = Accelerator(project_dir=model_args["dir"])
accelerator.wait_for_everyone()
shapes_dict = {model_args["name"]: (num_cells, num_genes)}
run_eval(
    adata=processed_adata,
    name=model_args["name"],
    pe_idx_path=model_args["pe_idx_path"],
    chroms_path=model_args["chroms_path"],
    starts_path=model_args["starts_path"],
    shapes_dict=shapes_dict,
    accelerator=accelerator,
    args=model_parameters,
)

print("\n>>> Storing output...", flush=True)
# run_eval() saved the embedded data as '<name>_uce_adata.h5ad' in the work dir
embedded_adata = ad.read_h5ad(os.path.join(model_args["dir"], "input_uce_adata.h5ad"))
# Build a minimal output AnnData: obs/var carry only the index (no columns),
# the UCE embedding is stored under the standard 'X_emb' key
output = ad.AnnData(
    obs=adata.obs[[]],
    var=adata.var[[]],
    obsm={
        "X_emb": embedded_adata.obsm["X_uce"],
    },
    uns={
        "dataset_id": adata.uns["dataset_id"],
        "normalization_id": adata.uns["normalization_id"],
        "method_id": meta["name"],
    },
)
print(output)

print("\n>>> Writing output AnnData to file...", flush=True)
output.write_h5ad(par["output"], compression="gzip")

print("\n>>> Cleaning up temporary directories...", flush=True)
work_dir.cleanup()
# model_temp is only set when the model was extracted from an archive
if model_temp is not None:
    model_temp.cleanup()

print("\n>>> Done!", flush=True)
1 change: 1 addition & 0 deletions src/workflows/run_benchmark/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ dependencies:
- name: methods/scanvi
- name: methods/scimilarity
- name: methods/scvi
- name: methods/uce
# metrics
- name: metrics/asw_batch
- name: metrics/asw_label
Expand Down
5 changes: 4 additions & 1 deletion src/workflows/run_benchmark/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ methods = [
scimilarity.run(
args: [model: file("s3://openproblems-work/cache/scimilarity-model_v1.1.tar.gz")]
),
scvi
scvi,
uce.run(
args: [model: file("s3://openproblems-work/cache/uce-model-v5.zip")]
),
]

// construct list of metrics
Expand Down

0 comments on commit 7e0b381

Please sign in to comment.