generated from openproblems-bio/task_template
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Create UCE component files * Add UCE dataset preprocessing * Generate UCE indexes * Evaluate UCE model and output results * Add UCE to benchmark workflow * Fix UCE model path in benchmark workflow * Copy UCE files to working directory for Nextflow * Exclude UCE in local benchmark scripts Requires more memory than allowed by the local labels config * Style UCE script
- Loading branch information
Showing
6 changed files
with
263 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Viash component configuration for the UCE (Universal Cell Embedding) method
__merge__: ../../api/base_method.yaml

name: uce
label: UCE
summary: UCE offers a unified biological latent space that can represent any cell
description: |
  Universal Cell Embedding (UCE) is a single-cell foundation model that offers a
  unified biological latent space that can represent any cell, regardless of
  tissue or species
references:
  doi:
    - 10.1101/2023.11.28.568918
links:
  documentation: https://github.com/snap-stanford/UCE/blob/main/README.md
  repository: https://github.com/snap-stanford/UCE

info:
  method_types: [embedding]
  # UCE consumes raw counts (script.py reads layers/counts)
  preferred_normalization: counts

arguments:
  # Model weights are not bundled; users must supply them separately
  - name: --model
    type: file
    description: Path to the directory containing UCE model files or a .zip/.tar.gz archive
    required: true

resources:
  - type: python_script
    path: script.py
  # Shared helper for reading parts of an AnnData file without loading it all
  - path: /src/utils/read_anndata_partial.py

engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    setup:
      - type: python
        pypi:
          - accelerate==0.24.0
      # Clone the UCE code into the image; script.py chdirs into (or copies
      # from) this directory because UCE uses hardcoded relative paths
      - type: docker
        run: "git clone https://github.com/snap-stanford/UCE.git"
runners:
  - type: executable
  - type: nextflow
    directives:
      label: [midtime, midmem, midcpu, gpu]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
import os
import pickle
import sys
import tarfile
import tempfile
import zipfile
from argparse import Namespace

import anndata as ad
import numpy as np
import pandas as pd
import torch
from accelerate import Accelerator

# Code has hardcoded paths that only work correctly inside the UCE directory
if os.path.isdir("UCE"):
    # For executable we can work inside the UCE directory
    os.chdir("UCE")
else:
    # For Nextflow we need to copy files to the Nextflow working directory
    # (the UCE clone is expected at /workspace/UCE inside the Docker image)
    print(">>> Copying UCE files to local directory...", flush=True)
    import shutil

    shutil.copytree("/workspace/UCE", ".", dirs_exist_ok=True)

# Append current directory to import UCE functions
sys.path.append(".")
from data_proc.data_utils import (
    adata_path_to_prot_chrom_starts,
    get_spec_chrom_csv,
    get_species_to_pe,
    process_raw_anndata,
)
from evaluate import run_eval
|
||
## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
par = {
    "input": "resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad",
    "output": "output.h5ad",
}
meta = {
    "name": "uce",
    # "resources_dir" is read unconditionally below, so the placeholder must
    # provide it or running the script outside viash raises a KeyError.
    # Points at the directory containing read_anndata_partial.py.
    "resources_dir": "src/utils",
}
## VIASH END

print(">>> Reading input...", flush=True)
sys.path.append(meta["resources_dir"])
from read_anndata_partial import read_anndata

# Load only the slots needed downstream (raw counts, obs, var, uns) to limit
# memory usage
adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")

# Map the dataset organism to the species label UCE uses; only human and
# mouse are supported here
if adata.uns["dataset_organism"] == "homo_sapiens":
    species = "human"
elif adata.uns["dataset_organism"] == "mus_musculus":
    species = "mouse"
else:
    raise ValueError(f"Species '{adata.uns['dataset_organism']}' not yet implemented")
|
||
print("\n>>> Creating working directory...", flush=True)
work_dir = tempfile.TemporaryDirectory()
print(f"Working directory: '{work_dir.name}'", flush=True)

print("\n>>> Getting model files...", flush=True)
if os.path.isdir(par["model"]):
    # Model files are already unpacked; use the directory in place
    model_temp = None
    model_dir = par["model"]
else:
    # Model is an archive; extract it to a temporary directory that is
    # cleaned up at the end of the script
    model_temp = tempfile.TemporaryDirectory()
    model_dir = model_temp.name

    if zipfile.is_zipfile(par["model"]):
        print("Extracting UCE model from .zip...", flush=True)
        with zipfile.ZipFile(par["model"], "r") as zip_file:
            zip_file.extractall(model_dir)
    elif tarfile.is_tarfile(par["model"]) and par["model"].endswith(".tar.gz"):
        print("Extracting model from .tar.gz...", flush=True)
        with tarfile.open(par["model"], "r:gz") as tar_file:
            tar_file.extractall(model_dir)
        # The .tar.gz is expected to contain a single top-level directory
        # that holds the model files
        model_dir = os.path.join(model_dir, os.listdir(model_dir)[0])
    else:
        # Was a placeholder-less f-string with garbled wording; plain,
        # correctly punctuated message instead
        raise ValueError(
            "The 'model' argument should be a directory, a .zip file or a .tar.gz file"
        )

print(f"Model directory: '{model_dir}'", flush=True)
|
||
print("Extracting protein embeddings...", flush=True)
# Extract the bundled protein embeddings to './model_files' (relative to the
# UCE directory) — presumably the location the UCE code expects; the
# resulting subdirectory is passed to get_species_to_pe() below
with tarfile.open(
    os.path.join(model_dir, "protein_embeddings.tar.gz"), "r:gz"
) as tar_file:
    tar_file.extractall("./model_files")
protein_embeddings_dir = os.path.join("./model_files", "protein_embeddings")
print(f"Protein embeddings directory: '{protein_embeddings_dir}'", flush=True)

# The following sections implement methods in the UCE.evaluate.AnndataProcessor
# class due to the object not being compatible with the Open Problems setup
model_args = {
    # Working directory for intermediate/output files (trailing slash kept —
    # presumably required by UCE's path concatenation; confirm before removing)
    "dir": work_dir.name + "/",
    # Passed to process_raw_anndata(skip=...) below
    "skip": True,
    "filter": False,  # Turn this off to get embedding for all cells
    # Dataset name; used as the key/prefix for generated files
    "name": "input",
    # Inputs shipped with the model files
    "offset_pkl_path": os.path.join(model_dir, "species_offsets.pkl"),
    "spec_chrom_csv_path": os.path.join(model_dir, "species_chrom.csv"),
    # Index files generated by the "Generating indexes" step below
    "pe_idx_path": os.path.join(work_dir.name, "input_pe_row_idxs.pt"),
    "chroms_path": os.path.join(work_dir.name, "input_chroms.pkl"),
    "starts_path": os.path.join(work_dir.name, "input_starts.pkl"),
}
|
||
# AnndataProcessor.preprocess_anndata()
print("\n>>> Preprocessing data...", flush=True)
# Set var names to gene symbols
adata.var_names = adata.var["feature_name"]
adata.write_h5ad(os.path.join(model_args["dir"], "input.h5ad"))

# process_raw_anndata() expects a Series describing the dataset. Build it from
# a dict up front: assigning attributes on an empty Series (row.path = ...)
# sets Python instance attributes rather than Series elements and triggers a
# pandas UserWarning; an empty pd.Series() also warns about its default dtype.
row = pd.Series(
    {
        "path": "input.h5ad",
        "covar_col": np.nan,
        "species": species,
    }
)

processed_adata, num_cells, num_genes = process_raw_anndata(
    row=row,
    h5_folder_path=model_args["dir"],
    npz_folder_path=model_args["dir"],
    scp="",
    skip=model_args["skip"],
    additional_filter=model_args["filter"],
    root=model_args["dir"],
)
|
||
# AnndataProcessor.generate_idxs()
print("\n>>> Generating indexes...", flush=True)
# Per-species protein embeddings extracted earlier
species_to_pe = get_species_to_pe(protein_embeddings_dir)
with open(model_args["offset_pkl_path"], "rb") as f:
    species_to_offsets = pickle.load(f)
# Chromosome position table shipped with the model files
gene_to_chrom_pos = get_spec_chrom_csv(model_args["spec_chrom_csv_path"])
spec_pe_genes = list(species_to_pe[species].keys())
offset = species_to_offsets[species]
# Compute dataset-specific indexes: protein-embedding row indices, chromosome
# assignments and start positions for the genes present in the dataset
pe_row_idxs, dataset_chroms, dataset_pos = adata_path_to_prot_chrom_starts(
    processed_adata, species, spec_pe_genes, gene_to_chrom_pos, offset
)
# Persist the indexes (keyed by dataset name) to the paths that run_eval()
# reads via model_args
torch.save({model_args["name"]: pe_row_idxs}, model_args["pe_idx_path"])
with open(model_args["chroms_path"], "wb+") as f:
    pickle.dump({model_args["name"]: dataset_chroms}, f)
with open(model_args["starts_path"], "wb+") as f:
    pickle.dump({model_args["name"]: dataset_pos}, f)
|
||
# AnndataProcessor.run_evaluation()
print("\n>>> Evaluating model...", flush=True)
# Hardcoded hyperparameters for the UCE model. NOTE(review): these constants
# presumably mirror the defaults in UCE's own evaluation script — confirm they
# stay in sync with the pinned UCE revision.
model_parameters = Namespace(
    token_dim=5120,
    d_hid=5120,
    nlayers=33,  # Small model = 4, full model = 33
    output_dim=1280,
    multi_gpu=False,
    token_file=os.path.join(model_dir, "all_tokens.torch"),
    dir=model_args["dir"],
    pad_length=1536,
    sample_size=1024,
    cls_token_idx=3,
    CHROM_TOKEN_OFFSET=143574,
    chrom_token_right_idx=2,
    chrom_token_left_idx=1,
    pad_token_idx=0,
)

# Pick the weights file and batch size matching the selected model depth
if model_parameters.nlayers == 4:
    model_parameters.model_loc = os.path.join(model_dir, "4layer_model.torch")
    model_parameters.batch_size = 100
else:
    model_parameters.model_loc = os.path.join(model_dir, "33l_8ep_1024t_1280.torch")
    model_parameters.batch_size = 25

# run_eval() writes its embedded AnnData into model_args["dir"], which is read
# back in the output step below
accelerator = Accelerator(project_dir=model_args["dir"])
accelerator.wait_for_everyone()
shapes_dict = {model_args["name"]: (num_cells, num_genes)}
run_eval(
    adata=processed_adata,
    name=model_args["name"],
    pe_idx_path=model_args["pe_idx_path"],
    chroms_path=model_args["chroms_path"],
    starts_path=model_args["starts_path"],
    shapes_dict=shapes_dict,
    accelerator=accelerator,
    args=model_parameters,
)
|
||
print("\n>>> Storing output...", flush=True)
# run_eval() wrote the embedded dataset into the working directory; load it
# and pull out the cell embedding matrix
uce_result_path = os.path.join(model_args["dir"], "input_uce_adata.h5ad")
uce_embedding = ad.read_h5ad(uce_result_path).obsm["X_uce"]

# Assemble the task output: obs/var carry only their indexes, the embedding
# goes in obsm["X_emb"], and required identifiers go in uns
output_uns = {
    "dataset_id": adata.uns["dataset_id"],
    "normalization_id": adata.uns["normalization_id"],
    "method_id": meta["name"],
}
output = ad.AnnData(
    obs=adata.obs[[]],
    var=adata.var[[]],
    obsm={"X_emb": uce_embedding},
    uns=output_uns,
)
print(output)

print("\n>>> Writing output AnnData to file...", flush=True)
output.write_h5ad(par["output"], compression="gzip")

print("\n>>> Cleaning up temporary directories...", flush=True)
work_dir.cleanup()
if model_temp is not None:
    model_temp.cleanup()

print("\n>>> Done!", flush=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters