From 7e0b381cc1db15e8fd2d45523e84f58b91078a5b Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Fri, 8 Nov 2024 17:35:08 +0100 Subject: [PATCH] Add UCE method (#7) * Create UCE component files * Add UCE dataset preprocessing * Generate UCE indexes * Evaluate UCE model and output results * Add UCE to benchmark workflow * Fix UCE model path in benchmark workflow * Copy UCE files to working directory for Nextflow * Exclude UCE in local benchmark scripts Requires more memory than allowed by the local labels config * Style UCE script --- scripts/run_benchmark/run_full_local.sh | 1 + scripts/run_benchmark/run_test_local.sh | 1 + src/methods/uce/config.vsh.yaml | 45 +++++ src/methods/uce/script.py | 211 ++++++++++++++++++++ src/workflows/run_benchmark/config.vsh.yaml | 1 + src/workflows/run_benchmark/main.nf | 5 +- 6 files changed, 263 insertions(+), 1 deletion(-) create mode 100644 src/methods/uce/config.vsh.yaml create mode 100644 src/methods/uce/script.py diff --git a/scripts/run_benchmark/run_full_local.sh b/scripts/run_benchmark/run_full_local.sh index 5c83ddb..d823d79 100755 --- a/scripts/run_benchmark/run_full_local.sh +++ b/scripts/run_benchmark/run_full_local.sh @@ -26,6 +26,7 @@ input_states: resources/datasets/**/state.yaml rename_keys: 'input_dataset:output_dataset;input_solution:output_solution' output_state: "state.yaml" publish_dir: "$publish_dir" +settings: '{"methods_exclude": ["uce"]}' HERE # run the benchmark diff --git a/scripts/run_benchmark/run_test_local.sh b/scripts/run_benchmark/run_test_local.sh index d0fba74..2b72eee 100755 --- a/scripts/run_benchmark/run_test_local.sh +++ b/scripts/run_benchmark/run_test_local.sh @@ -21,6 +21,7 @@ input_states: resources_test/task_batch_integration/**/state.yaml rename_keys: 'input_dataset:output_dataset;input_solution:output_solution' output_state: "state.yaml" publish_dir: "$publish_dir" +settings: '{"methods_exclude": ["uce"]}' HERE nextflow run . 
\ diff --git a/src/methods/uce/config.vsh.yaml b/src/methods/uce/config.vsh.yaml new file mode 100644 index 0000000..2301ef5 --- /dev/null +++ b/src/methods/uce/config.vsh.yaml @@ -0,0 +1,45 @@ +__merge__: ../../api/base_method.yaml + +name: uce +label: UCE +summary: UCE offers a unified biological latent space that can represent any cell +description: | + Universal Cell Embedding (UCE) is a single-cell foundation model that offers a + unified biological latent space that can represent any cell, regardless of + tissue or species +references: + doi: + - 10.1101/2023.11.28.568918 +links: + documentation: https://github.com/snap-stanford/UCE/blob/main/README.md + repository: https://github.com/snap-stanford/UCE + +info: + method_types: [embedding] + preferred_normalization: counts + +arguments: + - name: --model + type: file + description: Path to the directory containing UCE model files or a .zip/.tar.gz archive + required: true + +resources: + - type: python_script + path: script.py + - path: /src/utils/read_anndata_partial.py + +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1.0.0 + setup: + - type: python + pypi: + - accelerate==0.24.0 + - type: docker + run: "git clone https://github.com/snap-stanford/UCE.git" +runners: + - type: executable + - type: nextflow + directives: + label: [midtime, midmem, midcpu, gpu] diff --git a/src/methods/uce/script.py b/src/methods/uce/script.py new file mode 100644 index 0000000..24108a9 --- /dev/null +++ b/src/methods/uce/script.py @@ -0,0 +1,211 @@ +import os +import pickle +import sys +import tarfile +import tempfile +import zipfile +from argparse import Namespace + +import anndata as ad +import numpy as np +import pandas as pd +import torch +from accelerate import Accelerator + +# Code has hardcoded paths that only work correctly inside the UCE directory +if os.path.isdir("UCE"): + # For executable we can work inside the UCE directory + os.chdir("UCE") +else: + # For Nextflow we need to copy files to the 
Nextflow working directory + print(">>> Copying UCE files to local directory...", flush=True) + import shutil + + shutil.copytree("/workspace/UCE", ".", dirs_exist_ok=True) + +# Append current directory to import UCE functions +sys.path.append(".") +from data_proc.data_utils import ( + adata_path_to_prot_chrom_starts, + get_spec_chrom_csv, + get_species_to_pe, + process_raw_anndata, +) +from evaluate import run_eval + +## VIASH START +# Note: this section is auto-generated by viash at runtime. To edit it, make changes +# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`. +par = { + "input": "resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad", + "output": "output.h5ad", +} +meta = {"name": "uce"} +## VIASH END + +print(">>> Reading input...", flush=True) +sys.path.append(meta["resources_dir"]) +from read_anndata_partial import read_anndata + +adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns") + +if adata.uns["dataset_organism"] == "homo_sapiens": + species = "human" +elif adata.uns["dataset_organism"] == "mus_musculus": + species = "mouse" +else: + raise ValueError(f"Species '{adata.uns['dataset_organism']}' not yet implemented") + +print("\n>>> Creating working directory...", flush=True) +work_dir = tempfile.TemporaryDirectory() +print(f"Working directory: '{work_dir.name}'", flush=True) + +print("\n>>> Getting model files...", flush=True) +if os.path.isdir(par["model"]): + model_temp = None + model_dir = par["model"] +else: + model_temp = tempfile.TemporaryDirectory() + model_dir = model_temp.name + + if zipfile.is_zipfile(par["model"]): + print("Extracting UCE model from .zip...", flush=True) + with zipfile.ZipFile(par["model"], "r") as zip_file: + zip_file.extractall(model_dir) + elif tarfile.is_tarfile(par["model"]) and par["model"].endswith(".tar.gz"): + print("Extracting model from .tar.gz...", flush=True) + with tarfile.open(par["model"], "r:gz") as tar_file: + 
tar_file.extractall(model_dir) + model_dir = os.path.join(model_dir, os.listdir(model_dir)[0]) + else: + raise ValueError( + "The 'model' argument should be a directory, a .zip file, or a .tar.gz file" + ) + +print(f"Model directory: '{model_dir}'", flush=True) + +print("Extracting protein embeddings...", flush=True) +with tarfile.open( + os.path.join(model_dir, "protein_embeddings.tar.gz"), "r:gz" +) as tar_file: + tar_file.extractall("./model_files") +protein_embeddings_dir = os.path.join("./model_files", "protein_embeddings") +print(f"Protein embeddings directory: '{protein_embeddings_dir}'", flush=True) + +# The following sections implement methods in the UCE.evaluate.AnndataProcessor +# class due to the object not being compatible with the Open Problems setup +model_args = { + "dir": work_dir.name + "/", + "skip": True, + "filter": False, # Turn this off to get embedding for all cells + "name": "input", + "offset_pkl_path": os.path.join(model_dir, "species_offsets.pkl"), + "spec_chrom_csv_path": os.path.join(model_dir, "species_chrom.csv"), + "pe_idx_path": os.path.join(work_dir.name, "input_pe_row_idxs.pt"), + "chroms_path": os.path.join(work_dir.name, "input_chroms.pkl"), + "starts_path": os.path.join(work_dir.name, "input_starts.pkl"), +} + +# AnndataProcessor.preprocess_anndata() +print("\n>>> Preprocessing data...", flush=True) +# Set var names to gene symbols +adata.var_names = adata.var["feature_name"] +adata.write_h5ad(os.path.join(model_args["dir"], "input.h5ad")) + +row = pd.Series() +row.path = "input.h5ad" +row.covar_col = np.nan +row.species = species + +processed_adata, num_cells, num_genes = process_raw_anndata( + row=row, + h5_folder_path=model_args["dir"], + npz_folder_path=model_args["dir"], + scp="", + skip=model_args["skip"], + additional_filter=model_args["filter"], + root=model_args["dir"], +) + +# AnndataProcessor.generate_idxs() +print("\n>>> Generating indexes...", flush=True) +species_to_pe = get_species_to_pe(protein_embeddings_dir)
+with open(model_args["offset_pkl_path"], "rb") as f: + species_to_offsets = pickle.load(f) +gene_to_chrom_pos = get_spec_chrom_csv(model_args["spec_chrom_csv_path"]) +spec_pe_genes = list(species_to_pe[species].keys()) +offset = species_to_offsets[species] +pe_row_idxs, dataset_chroms, dataset_pos = adata_path_to_prot_chrom_starts( + processed_adata, species, spec_pe_genes, gene_to_chrom_pos, offset +) +torch.save({model_args["name"]: pe_row_idxs}, model_args["pe_idx_path"]) +with open(model_args["chroms_path"], "wb+") as f: + pickle.dump({model_args["name"]: dataset_chroms}, f) +with open(model_args["starts_path"], "wb+") as f: + pickle.dump({model_args["name"]: dataset_pos}, f) + +# AnndataProcessor.run_evaluation() +print("\n>>> Evaluating model...", flush=True) +model_parameters = Namespace( + token_dim=5120, + d_hid=5120, + nlayers=33, # Small model = 4, full model = 33 + output_dim=1280, + multi_gpu=False, + token_file=os.path.join(model_dir, "all_tokens.torch"), + dir=model_args["dir"], + pad_length=1536, + sample_size=1024, + cls_token_idx=3, + CHROM_TOKEN_OFFSET=143574, + chrom_token_right_idx=2, + chrom_token_left_idx=1, + pad_token_idx=0, +) + +if model_parameters.nlayers == 4: + model_parameters.model_loc = os.path.join(model_dir, "4layer_model.torch") + model_parameters.batch_size = 100 +else: + model_parameters.model_loc = os.path.join(model_dir, "33l_8ep_1024t_1280.torch") + model_parameters.batch_size = 25 + +accelerator = Accelerator(project_dir=model_args["dir"]) +accelerator.wait_for_everyone() +shapes_dict = {model_args["name"]: (num_cells, num_genes)} +run_eval( + adata=processed_adata, + name=model_args["name"], + pe_idx_path=model_args["pe_idx_path"], + chroms_path=model_args["chroms_path"], + starts_path=model_args["starts_path"], + shapes_dict=shapes_dict, + accelerator=accelerator, + args=model_parameters, +) + +print("\n>>> Storing output...", flush=True) +embedded_adata = ad.read_h5ad(os.path.join(model_args["dir"], 
"input_uce_adata.h5ad")) +output = ad.AnnData( + obs=adata.obs[[]], + var=adata.var[[]], + obsm={ + "X_emb": embedded_adata.obsm["X_uce"], + }, + uns={ + "dataset_id": adata.uns["dataset_id"], + "normalization_id": adata.uns["normalization_id"], + "method_id": meta["name"], + }, +) +print(output) + +print("\n>>> Writing output AnnData to file...", flush=True) +output.write_h5ad(par["output"], compression="gzip") + +print("\n>>> Cleaning up temporary directories...", flush=True) +work_dir.cleanup() +if model_temp is not None: + model_temp.cleanup() + +print("\n>>> Done!", flush=True) diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml index d3cc2b5..a4df670 100644 --- a/src/workflows/run_benchmark/config.vsh.yaml +++ b/src/workflows/run_benchmark/config.vsh.yaml @@ -96,6 +96,7 @@ dependencies: - name: methods/scanvi - name: methods/scimilarity - name: methods/scvi + - name: methods/uce # metrics - name: metrics/asw_batch - name: metrics/asw_label diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf index 89564bd..b208df9 100644 --- a/src/workflows/run_benchmark/main.nf +++ b/src/workflows/run_benchmark/main.nf @@ -32,7 +32,10 @@ methods = [ scimilarity.run( args: [model: file("s3://openproblems-work/cache/scimilarity-model_v1.1.tar.gz")] ), - scvi + scvi, + uce.run( + args: [model: file("s3://openproblems-work/cache/uce-model-v5.zip")] + ), ] // construct list of metrics