diff --git a/.gitignore b/.gitignore index 72c2207e5..21ea229dd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # /archive/ +example_notebooks/* # .DS_Store diff --git a/.vscode/settings.json b/.vscode/settings.json index 7d5509522..b41b8ad69 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -56,6 +56,7 @@ "search.followSymlinks": false, "terminal.integrated.fontSize": 14, "terminal.integrated.scrollback": 100000, + "python.terminal.activateEnvironment": false, "workbench.colorTheme": "Catppuccin Mocha", "workbench.iconTheme": "vscode-icons", // Passing --no-cov to pytestArgs is required to respect breakpoints diff --git a/src/pyrovelocity/models/__init__.py b/src/pyrovelocity/models/__init__.py index 4dff5d0ed..20fdb7542 100644 --- a/src/pyrovelocity/models/__init__.py +++ b/src/pyrovelocity/models/__init__.py @@ -7,13 +7,14 @@ from pyrovelocity.models._deterministic_simulation import ( solve_transcription_splicing_model_analytical, ) -from pyrovelocity.models._transcription_dynamics import mrna_dynamics +from pyrovelocity.models._transcription_dynamics import mrna_dynamics, atac_mrna_dynamics from pyrovelocity.models._velocity import PyroVelocity __all__ = [ deterministic_transcription_splicing_probabilistic_model, mrna_dynamics, + atac_mrna_dynamics, PyroVelocity, solve_transcription_splicing_model, solve_transcription_splicing_model_analytical, diff --git a/src/pyrovelocity/models/_trainer.py b/src/pyrovelocity/models/_trainer.py index fe3505cf2..5bd377b2c 100644 --- a/src/pyrovelocity/models/_trainer.py +++ b/src/pyrovelocity/models/_trainer.py @@ -271,8 +271,8 @@ def train_faster( if scipy.sparse.issparse(self.adata.layers["raw_spliced"]) else self.adata.layers["raw_spliced"], dtype=torch.float32, - ).to(device) - + ).to(device) + epsilon = 1e-6 log_u_library_size = np.log( @@ -335,60 +335,127 @@ def train_faster( losses = [] patience = patient_init - for step in range(max_epochs): - if cell_state is None: - elbos = ( - svi.step( - u, - s, - u_library.reshape(-1, 1), - s_library.reshape(-1, 1), - u_library_mean.reshape(-1, 1), - s_library_mean.reshape(-1, 1), - u_library_scale.reshape(-1, 1), - s_library_scale.reshape(-1, 1), - None, - None, + + if not self.adata.uns['atac']: + + for step in range(max_epochs): + if cell_state is None: + elbos = ( + svi.step( + u, + s, + u_library.reshape(-1, 1), + s_library.reshape(-1, 1), + u_library_mean.reshape(-1, 1), + s_library_mean.reshape(-1, 1), + u_library_scale.reshape(-1, 1), + s_library_scale.reshape(-1, 1), + None, + None, + ) + / normalizer ) - / normalizer - ) - else: - elbos = ( - svi.step( - u, - s, - u_library.reshape(-1, 1), - s_library.reshape(-1, 1), - u_library_mean.reshape(-1, 1), - s_library_mean.reshape(-1, 1), - u_library_scale.reshape(-1, 1), - s_library_scale.reshape(-1, 1), - None, - cell_state.reshape(-1, 1), + else: + elbos = ( + svi.step( + u, + s, + u_library.reshape(-1, 1), + s_library.reshape(-1, 1), + u_library_mean.reshape(-1, 1), + s_library_mean.reshape(-1, 1), + u_library_scale.reshape(-1, 1), + s_library_scale.reshape(-1, 1), + None, + cell_state.reshape(-1, 1), + ) + / normalizer + ) + if (step == 0) or ( + ((step + 1) % log_every == 0) and ((step + 1) < max_epochs) + ): + mlflow.log_metric("-ELBO", -elbos, step=step + 1) + logger.info( + f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}" + ) + if step > log_every: + if (losses[-1] - elbos) < losses[-1] * patient_improve: + patience -= 1 + else: + patience = patient_init + if patience <= 0: + break + losses.append(elbos) + + else: + + c = torch.tensor( + np.array( + self.adata.layers["atac"].toarray(), dtype="float32" + ) + if scipy.sparse.issparse(self.adata.layers["atac"]) + else self.adata.layers["atac"], + dtype=torch.float32, + ).to(device) + + + for step in range(max_epochs): + if cell_state is None: + elbos = ( + svi.step( + c, + u, + s, + u_library.reshape(-1, 1), + s_library.reshape(-1, 1), + u_library_mean.reshape(-1, 1), + s_library_mean.reshape(-1, 1), + u_library_scale.reshape(-1, 1), + s_library_scale.reshape(-1, 1), + None, + None, + ) + / normalizer ) - / normalizer - ) - if (step == 0) or ( - ((step + 1) % log_every == 0) and ((step + 1) < max_epochs) - ): - mlflow.log_metric("-ELBO", -elbos, step=step + 1) - logger.info( - f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}" - ) - if step > log_every: - if (losses[-1] - elbos) < losses[-1] * patient_improve: - patience -= 1 else: - patience = patient_init - if patience <= 0: - break - losses.append(elbos) + elbos = ( + svi.step( + c, + u, + s, + u_library.reshape(-1, 1), + s_library.reshape(-1, 1), + u_library_mean.reshape(-1, 1), + s_library_mean.reshape(-1, 1), + u_library_scale.reshape(-1, 1), + s_library_scale.reshape(-1, 1), + None, + cell_state.reshape(-1, 1), + ) + / normalizer + ) + if (step == 0) or ( + ((step + 1) % log_every == 0) and ((step + 1) < max_epochs) + ): + mlflow.log_metric("-ELBO", -elbos, step=step + 1) + logger.info( + f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}" + ) + if step > log_every: + if (losses[-1] - elbos) < losses[-1] * patient_improve: + patience -= 1 + else: + patience = patient_init + if patience <= 0: + break + losses.append(elbos) + mlflow.log_metric("-ELBO", -elbos, step=step + 1) mlflow.log_metric("real_epochs", step + 1) logger.info( f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}" ) - return losses + return losses def train_faster_with_batch( self, diff --git a/src/pyrovelocity/models/_transcription_dynamics.py b/src/pyrovelocity/models/_transcription_dynamics.py index 03c531828..aba2a317f 100644 --- a/src/pyrovelocity/models/_transcription_dynamics.py +++ b/src/pyrovelocity/models/_transcription_dynamics.py @@ -60,6 +60,289 @@ def mrna_dynamics( return ut, st +@beartype +def atac_mrna_dynamics( + tau: Tensor, + c0: Tensor, + u0: Tensor, + s0: Tensor, + k_c: Tensor, + alpha_c: Tensor, + alpha: Tensor, + beta: Tensor, + gamma: Tensor, +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Computes the ATAC and mRNA dynamics given temporal coordinate, parameter values, and + initial conditions. + + `st_gamma_equals_beta` for the case where the gamma parameter is equal + to the beta parameter is taken from Equation 2.12 of + + Args: + tau (Tensor): Time points starting at last change in RNA transcription rate. + c0 (Tensor): Initial value of c. + u0 (Tensor): Initial value of u. + s0 (Tensor): Initial value of s. + k_c (Tensor): Chromatin state. + alpha_c (Tensor): Rate of chromatin opening/closing. + alpha (Tensor): Alpha parameter. + beta (Tensor): Beta parameter. + gamma (Tensor): Gamma parameter. + + Returns: + Tuple[Tensor, Tensor]: Tuple containing the final values of c, u and s. + + Examples: + >>> import torch + >>> tau = torch.tensor(2.0) + >>> c0 = torch.tensor(1.0) + >>> u0 = torch.tensor(1.0) + >>> s0 = torch.tensor(0.5) + >>> alpha_c = torch.tensor(0.45) + >>> alpha = torch.tensor(0.5) + >>> beta = torch.tensor(0.4) + >>> gamma = torch.tensor(0.3) + >>> k_c = torch.tensor(1.0) + >>> atac_mrna_dynamics(tau_c, tau, c0, u0, s0, k_c, alpha_c, alpha, beta, gamma) + >>> import torch + >>> input = [torch.tensor([[ 0., 0.], + [ 5., 5.], + [35., 35.], + [15., 12.]]), torch.tensor([[0.0000, 0.0000], + [0.7769, 0.9502], + [0.7769, 0.9502], + [0.9985, 1.0000]]), torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [5.4451, 2.7188]]), torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [9.1791, 4.7921]]), torch.tensor([[0., 0.], + [1., 1.], + [1., 1.], + [0., 1.]]), torch.tensor([0.1000, 0.2000]), torch.tensor([[0.0000, 0.0000], + [0.5000, 0.3000], + [0.5000, 0.3000], + [0.5000, 0.0000]]), torch.tensor([0.0900, 0.1100]), torch.tensor([0.0500, 0.0600])] + >>> tau_vec = input[0] + >>> c0_vec = input[1] + >>> u0_vec = input[2] + >>> s0_vec = input[3] + >>> k_c_vec = input[4] + >>> alpha_c = input[5] + >>> alpha_vec = input[6] + >>> beta = input[7] + >>> gamma = input[8] + >>> atac_mrna_dynamics( + tau_vec, c0_vec, u0_vec, s0_vec, k_c_vec, alpha_c, alpha_vec, beta, gamma + ) + (tensor([[0.0000, 0.0000], + [0.8647, 0.9817], + [0.9933, 1.0000], + [0.2228, 1.0000]]), tensor([[0.0000, 0.0000], + [1.6662, 1.1191], + [5.1763, 2.6659], + [3.2144, 0.7263]]), tensor([[0.0000, 0.0000], + [2.2120, 1.2959], + [8.2623, 4.3877], + [4.3359, 2.3326]])) + """ + + A = torch.exp(-alpha_c * tau) + B = torch.exp(-beta * tau) + C = torch.exp(-gamma * tau) + + ct = c0 * A + k_c * (1 - A) + ut = ( + u0 * B + + alpha * k_c / beta * (1 - B) + + (k_c - c0) * alpha / (beta - alpha_c) * (B - A) + ) + st = s0 * C + alpha * k_c / gamma * (1 - C) + +beta / (gamma - beta) * ( + (alpha * k_c) / beta - u0 - (k_c - c0) * alpha / (beta - alpha_c) + ) * (C - B) + +beta / (gamma - alpha_c) * (k_c - c0) * alpha / (beta - alpha_c) * (C - A) + + return ct, ut, st + +@beartype +def get_initial_states( + t0_state: Tensor, + k_c_state: Tensor, + alpha_c: Tensor, + alpha_state: Tensor, + beta: Tensor, + gamma: Tensor, + state: Tensor +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Computes initial conditions of chromatin and mRNA in each cell. + + Args: + t0_state (Tensor): The switch times of each gene (1 for each state). + k_c_state (Tensor): The chromatin state in each state. + alpha_c (Tensor): The chromatin opening and closing rate. + alpha_state (Tensor): The transcription rate of each gene in each state. + beta (Tensor): The splicing rate of each gene. + gamma (Tensor): The degradation rate of each gene. + state (Tensor): The state of each cell. + + Returns: + Tuple[Tensor, Tensor, Tensor]: Tuple containing the initial conditions of + c, u and s for each cell. + + Examples: + >>> import torch + >>> alpha_c = torch.tensor((0.1, 0.2)) + >>> beta = torch.tensor((0.09, 0.11)) + >>> gamma = torch.tensor((0.05, 0.06)) + >>> state = torch.tensor([[0, 0],[2, 2],[2, 2],[3, 3]]) + >>> k_c_state = torch.tensor([[0., 1., 1., 0., 0.], [0., 1., 1., 1., 0.]]) + >>> alpha_state = torch.tensor([[0.0000, 0.0000, 0.5000, 0.5000, 0.0000],[0.0000, 0.0000, 0.3000, 0.0000, 0.0000]]) + >>> t0_state = torch.tensor([[ 0., 10., 25., 75., 102.],[ 0., 10., 25., 78., 95.]]) + >>> get_initial_states( + t0_state, k_c_state, alpha_c, alpha_state, beta, gamma, state + ) + (torch.tensor([[0.0000, 0.0000], + [0.7769, 0.9502], + [0.7769, 0.9502], + [0.9985, 1.0000]]), + torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [5.4451, 2.7188]]), + torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [9.1791, 4.7921]])) + """ + + n_genes = t0_state.shape[0] + c0_state_list = [torch.zeros(n_genes),torch.zeros(n_genes)] + u0_state_list = [torch.zeros(n_genes),torch.zeros(n_genes)] + s0_state_list = [torch.zeros(n_genes),torch.zeros(n_genes)] + dt_state = t0_state - torch.stack([torch.zeros((2)), torch.zeros((2)), + t0_state[:,1], t0_state[:,2], t0_state[:,3]], dim = 1) # genes, states + for i in range(1, 4): + c0_i, u0_i, s0_i = atac_mrna_dynamics( + dt_state[:, i+1], c0_state_list[-1], u0_state_list[-1], s0_state_list[-1], k_c_state[:, i], + alpha_c, alpha_state[:, i], beta, gamma + ) + c0_state_list += [c0_i] + u0_state_list += [u0_i] + s0_state_list += [s0_i] + + c0_state = torch.stack(c0_state_list, dim = 1) + u0_state = torch.stack(u0_state_list, dim = 1) + s0_state = torch.stack(s0_state_list, dim = 1) + + c0_vec = c0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + u0_vec = u0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + s0_vec = s0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + + return c0_vec, u0_vec, s0_vec + +@beartype +def get_cell_parameters( + t: Tensor, + t0_1: Tensor, + dt_1: Tensor, + dt_2: Tensor, + dt_3: Tensor, + alpha: Tensor, + alpha_off: Tensor, + k: Tensor, +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + Gets the ODE parameters for each cell, by first assign each gene in each cell to a state + based on state switch times of a gene and then computes the transcription rate, chromatin state + and time since last state switch(tau) for each gene in each cell. + + Args: + t (Tensor): The time of each cell. + t0_1 (Tensor): Start time for chromatin opening. + dt_1 (Tensor): Time gap since chromatin opening for transcription start for each gene. + dt_2 (Tensor): Time gap since transcription start for chromatin closing for each gene. + dt_3 (Tensor): Time gap since transcription start for transcription stopping for each gene. + alpha (Tensor): The transcription rate of each gene in the on state. + alpha_off (Tensor): The transcription rate of each gene in the off state. + k (Tensor): The activation state of each gene in each state. + + Returns: + Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: Tuple containing the state of each cell (state), + the switch time of each state (t0_state), the chromatin opening state (k_c_state), the transcription rate in each cell + (alpha_state) and cell-specific parameters for the chromatin state (k_c_vec) transcription rate (alpha_vec) and + time (tau_vec) since last state switch. + + Examples: + >>> import torch + + >>> n_cells = 4 + >>> t = torch.arange(0, 120, 30).reshape(n_cells, 1) + >>> t0_1 = torch.tensor((10.0, 10.0)) + >>> dt_1 = torch.tensor((15.0, 15.0)) + >>> dt_2 = torch.tensor((77.0, 53.0)) + >>> dt_3 = torch.tensor((50.0, 70.0)) + >>> alpha = torch.tensor((0.5, 0.3)) + >>> alpha_off = torch.tensor(0.0) + >>> k = torch.tensor((1.0, 1.0),(1.0,1.0)) + >>> get_cell_parameters( + t, t0_1, dt_1, dt_2, dt_3, alpha, alpha_off,k + ) + (tensor([[0, 0], + [2, 2], + [2, 2], + [3, 3]]),tensor([[0., 1., 1., 0., 0.], + [0., 1., 1., 1., 0.]]), tensor([[0.0000, 0.0000, 0.5000, 0.5000, 0.0000], + [0.0000, 0.0000, 0.3000, 0.0000, 0.0000]]), tensor([[ 0., 10., 25., 75., 102.], + [ 0., 10., 25., 78., 95.]]), tensor([[0., 0.], + [1., 1.], + [1., 1.], + [0., 1.]]), tensor([[0.0000, 0.0000], + [0.5000, 0.3000], + [0.5000, 0.3000], + [0.5000, 0.0000]]), tensor([[ 0., 0.], + [ 5., 5.], + [35., 35.], + [15., 12.]])) + """ + + # Assign each gene in each cell to a state: + t0_2 = t0_1 + dt_1 + boolean = dt_2 >= dt_3 # True means chromatin starts closing, before transcription stops. + t0_3 = torch.where(boolean, t0_2 + dt_3, t0_2 + dt_2) + t0_4 = torch.where(~boolean, t0_2 + dt_3, t0_2 + dt_2) + state = ((t0_1 <= t).int() + (t0_2 <= t).int() + (t0_3 <= t).int() + (t0_4 <= t).int()) # cells, genes + n_genes = state.shape[1] + state = state * (1-1*k) + + t0_state = torch.stack([torch.zeros_like(t0_1), t0_1, t0_2, t0_3, t0_4], dim=1) # genes, states + t0_vec = t0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + tau_vec = t - t0_vec # cells, genes + + alpha_state = torch.stack([ + torch.ones_like(t0_1) * alpha_off, + torch.ones_like(t0_1) * alpha_off, + torch.ones_like(t0_1) * alpha, + torch.where(boolean, torch.ones_like(t0_1) * alpha, torch.ones_like(t0_1) * alpha_off), + torch.ones_like(t0_1) * alpha_off + ], dim=1) # genes, states + + k_c_state = torch.stack([ + torch.zeros_like(t0_1), + torch.ones_like(t0_1), + torch.ones_like(t0_1), + torch.where(boolean, torch.zeros_like(t0_1), torch.ones_like(t0_1)), + torch.zeros_like(t0_1) + ], dim=1) # genes, states + + alpha_vec = alpha_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + k_c_vec = k_c_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + + return state, k_c_state, alpha_state, t0_state, k_c_vec, alpha_vec, tau_vec + @beartype def inv(x: Tensor) -> Tensor: diff --git a/src/pyrovelocity/models/_velocity.py b/src/pyrovelocity/models/_velocity.py index 95e09310d..201aa0a41 100644 --- a/src/pyrovelocity/models/_velocity.py +++ b/src/pyrovelocity/models/_velocity.py @@ -2,10 +2,7 @@ import pickle import sys from statistics import harmonic_mean -from typing import Dict -from typing import Optional -from typing import Sequence -from typing import Union +from typing import Dict, Optional, Sequence, Union import mlflow import numpy as np @@ -15,24 +12,25 @@ from beartype import beartype from numpy import ndarray from scvi.data import AnnDataManager -from scvi.data._constants import _SETUP_ARGS_KEY -from scvi.data._constants import _SETUP_METHOD_NAME -from scvi.data.fields import LayerField -from scvi.data.fields import NumericalObsField +from scvi.data._constants import _SETUP_ARGS_KEY, _SETUP_METHOD_NAME +from scvi.data.fields import LayerField, NumericalObsField from scvi.model._utils import parse_device_args from scvi.model.base import BaseModelClass -from scvi.model.base._utils import _initialize_model -from scvi.model.base._utils import _load_saved_files -from scvi.model.base._utils import _validate_var_names +from scvi.model.base._utils import ( + _initialize_model, + _load_saved_files, + _validate_var_names, +) from scvi.module.base import PyroBaseModuleClass -from pyrovelocity.analysis.analyze import compute_mean_vector_field -from pyrovelocity.analysis.analyze import compute_volcano_data -from pyrovelocity.analysis.analyze import vector_field_uncertainty +from pyrovelocity.analysis.analyze import ( + compute_mean_vector_field, + compute_volcano_data, + vector_field_uncertainty, +) from pyrovelocity.logging import configure_logging from pyrovelocity.models._trainer import VelocityTrainingMixin -from pyrovelocity.models._velocity_module import VelocityModule - +from pyrovelocity.models._velocity_module import VelocityModule, MultiVelocityModule __all__ = ["PyroVelocity"] @@ -101,6 +99,7 @@ class PyroVelocity(VelocityTrainingMixin, BaseModelClass): def __init__( self, adata: AnnData, + adata_atac: Optional[AnnData] = None, input_type: str = "raw", shared_time: bool = True, model_type: str = "auto", @@ -128,6 +127,7 @@ def __init__( Args: adata (AnnData): An AnnData object containing the gene expression data. + adata_atac (Optional[AnnData], optional): An AnnData object containing atac data. input_type (str, optional): Type of input data. Can be "raw", "knn", or "raw_cpm". Defaults to "raw". shared_time (bool, optional): Whether to use shared time. Defaults to True. model_type (str, optional): Type of model to use. Defaults to "auto". @@ -246,30 +246,56 @@ def __init__( # else: initial_values = {} logger.info(self.summary_stats) - self.module = VelocityModule( - self.summary_stats["n_cells"], - self.summary_stats["n_vars"], - model_type=model_type, - guide_type=guide_type, - likelihood=likelihood, - shared_time=shared_time, - t_scale_on=t_scale_on, - plate_size=plate_size, - latent_factor=latent_factor, - latent_factor_operation=latent_factor_operation, - latent_factor_size=latent_factor_size, - inducing_point_size=inducing_point_size, - include_prior=include_prior, - use_gpu=use_gpu, - num_aux_cells=num_aux_cells, - only_cell_times=only_cell_times, - decoder_on=decoder_on, - add_offset=add_offset, - correct_library_size=correct_library_size, - cell_specific_kinetics=cell_specific_kinetics, - kinetics_num=self.k, - **initial_values, - ) + if not adata_atac: + self.module = VelocityModule( + self.summary_stats["n_cells"], + self.summary_stats["n_vars"], + model_type=model_type, + guide_type=guide_type, + likelihood=likelihood, + shared_time=shared_time, + t_scale_on=t_scale_on, + plate_size=plate_size, + latent_factor=latent_factor, + latent_factor_operation=latent_factor_operation, + latent_factor_size=latent_factor_size, + inducing_point_size=inducing_point_size, + include_prior=include_prior, + use_gpu=use_gpu, + num_aux_cells=num_aux_cells, + only_cell_times=only_cell_times, + decoder_on=decoder_on, + add_offset=add_offset, + correct_library_size=correct_library_size, + cell_specific_kinetics=cell_specific_kinetics, + kinetics_num=self.k, + **initial_values, + ) + else: + self.module = MultiVelocityModule( + self.summary_stats["n_cells"], + self.summary_stats["n_vars"], + model_type=model_type, + guide_type=guide_type, + likelihood=likelihood, + shared_time=shared_time, + t_scale_on=t_scale_on, + plate_size=plate_size, + latent_factor=latent_factor, + latent_factor_operation=latent_factor_operation, + latent_factor_size=latent_factor_size, + inducing_point_size=inducing_point_size, + include_prior=include_prior, + use_gpu=use_gpu, + num_aux_cells=num_aux_cells, + only_cell_times=only_cell_times, + decoder_on=decoder_on, + add_offset=False, + correct_library_size=correct_library_size, + cell_specific_kinetics=cell_specific_kinetics, + kinetics_num=self.k, + **initial_values, + ) self.num_cells = self.module.num_cells self._model_summary_string = """ RNA velocity Pyro model with parameters: @@ -298,7 +324,7 @@ def enum_parallel_predict(self): return @classmethod - def setup_anndata(cls, adata: AnnData, *args, **kwargs): + def setup_anndata(cls, adata: AnnData, adata_atac = None, *args, **kwargs): """ Set up AnnData object for compatibility with the scvi-tools model training interface. @@ -332,9 +358,18 @@ def setup_anndata(cls, adata: AnnData, *args, **kwargs): NumericalObsField("s_lib_size_scale", "s_lib_size_scale"), NumericalObsField("ind_x", "ind_x"), ] + + if adata_atac: + adata.layers['atac'] = adata_atac.X + anndata_fields += [LayerField('atac', 'atac')] + adata.uns['atac'] = True + else: + adata.uns['atac'] = None + adata_manager = AnnDataManager( fields=anndata_fields, setup_method_args=setup_method_args ) + adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/src/pyrovelocity/models/_velocity_model.py b/src/pyrovelocity/models/_velocity_model.py index 63ca3fd8a..3aaad3492 100644 --- a/src/pyrovelocity/models/_velocity_model.py +++ b/src/pyrovelocity/models/_velocity_model.py @@ -1,27 +1,19 @@ -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Optional, Tuple, Union import pyro import torch from beartype import beartype -from jaxtyping import Float -from jaxtyping import jaxtyped +from jaxtyping import Float, jaxtyped from pyro import poutine -from pyro.distributions import Bernoulli -from pyro.distributions import LogNormal -from pyro.distributions import Normal -from pyro.distributions import Poisson -from pyro.nn import PyroModule -from pyro.nn import PyroSample +from pyro.distributions import Bernoulli, LogNormal, Normal, Poisson +from pyro.nn import PyroModule, PyroSample from pyro.primitives import plate from scvi.nn import Decoder -from torch.nn.functional import relu -from torch.nn.functional import softplus +from torch.nn.functional import relu, softplus +from torch import Tensor from pyrovelocity.logging import configure_logging -from pyrovelocity.models._transcription_dynamics import mrna_dynamics - +from pyrovelocity.models._transcription_dynamics import mrna_dynamics, atac_mrna_dynamics, get_initial_states, get_cell_parameters logger = configure_logging(__name__) @@ -36,6 +28,11 @@ Float[torch.Tensor, "samples num_cells num_genes"], ] +__all__ = [ + "LogNormalModel", + "VelocityModelAuto", + "MultiVelocityModelAuto", +] class LogNormalModel(PyroModule): """ @@ -154,6 +151,10 @@ def create_plates( gene_plate = pyro.plate("genes", self.num_genes, dim=-1) return cell_plate, gene_plate + @PyroSample + def alpha_c(self): + return self._pyrosample_helper(1.0) + @PyroSample def alpha(self): return self._pyrosample_helper(1.0) @@ -182,6 +183,18 @@ def u_inf(self): def s_inf(self): return self._pyrosample_helper(0.1) + @PyroSample + def dt_switching_c(self): + return self._pyrosample_helper(1.0) + + @PyroSample + def delay(self): + return self._pyrosample_helper(1.0) + + @PyroSample + def dt_switching_c(self): + return self._pyrosample_helper(1.0) + @PyroSample def dt_switching(self): return self._pyrosample_helper(1.0) @@ -305,8 +318,99 @@ def get_likelihood( u_dist = Poisson(ut) s_dist = Poisson(st) + return u_dist, s_dist + + @beartype + def get_likelihood_multiome( + self, + ct: torch.Tensor, + ut: torch.Tensor, + st: torch.Tensor, + sigma_c: torch.Tensor, + u_log_library: Optional[torch.Tensor] = None, + s_log_library: Optional[torch.Tensor] = None, + u_scale: Optional[torch.Tensor] = None, + s_scale: Optional[torch.Tensor] = None, + u_read_depth: Optional[torch.Tensor] = None, + s_read_depth: Optional[torch.Tensor] = None, + u_cell_size_coef: None = None, + ut_coef: None = None, + s_cell_size_coef: None = None, + st_coef: None = None, + ) -> Tuple[LogNormal, Poisson, Poisson]: + """ + Compute the likelihood of the given count data. + + Args: + ct (torch.Tensor): Tensor representing chromatin state. + ut (torch.Tensor): Tensor representing unspliced transcripts. + st (torch.Tensor): Tensor representing spliced transcripts. + sigma_c (torch.Tensor): Tensor representing standard deviation of chromatin state. + u_log_library (Optional[torch.Tensor], optional): Log library tensor for unspliced transcripts. Defaults to None. + s_log_library (Optional[torch.Tensor], optional): Log library tensor for spliced transcripts. Defaults to None. + u_scale (Optional[torch.Tensor], optional): Scale tensor for unspliced transcripts. Defaults to None. + s_scale (Optional[torch.Tensor], optional): Scale tensor for spliced transcripts. Defaults to None. + u_read_depth (Optional[torch.Tensor], optional): Read depth tensor for unspliced transcripts. Defaults to None. + s_read_depth (Optional[torch.Tensor], optional): Read depth tensor for spliced transcripts. Defaults to None. + u_cell_size_coef (Optional[Any], optional): Cell size coefficient for unspliced transcripts. Defaults to None. + ut_coef (Optional[Any], optional): Coefficient for unspliced transcripts. Defaults to None. + s_cell_size_coef (Optional[Any], optional): Cell size coefficient for spliced transcripts. Defaults to None. + st_coef (Optional[Any], optional): Coefficient for spliced transcripts. Defaults to None. + Returns: + Tuple[Poisson, Poisson]: A tuple of Poisson distributions for unspliced and spliced transcripts, respectively. + + Example: + >>> import torch + >>> from pyrovelocity.models._velocity_model import LogNormalModel + >>> num_cells = 10 + >>> num_genes = 20 + >>> likelihood = "Poisson" + >>> plate_size = 2 + >>> model = LogNormalModel(num_cells, num_genes, likelihood, plate_size) + >>> logger.info(model) + >>> ut = torch.rand(num_cells, num_genes) + >>> st = torch.rand(num_cells, num_genes) + >>> u_read_depth = torch.rand(num_cells, 1) + >>> s_read_depth = torch.rand(num_cells, 1) + >>> u_dist, s_dist = model.get_likelihood(ut, st, u_read_depth=u_read_depth, s_read_depth=s_read_depth) + >>> logger.info(f"u_dist: {u_dist}") + >>> logger.info(f"s_dist: {s_dist}") + >>> assert isinstance(u_dist, torch.distributions.Poisson) + >>> assert isinstance(s_dist, torch.distributions.Poisson) + """ + if self.likelihood != "Poisson": + likelihood_not_implemented_msg = ( + "In the future, the likelihood will be referred to via a " + "member of a sum type over supported distributions" + ) + raise NotImplementedError(likelihood_not_implemented_msg) + + if self.correct_library_size: + ut = relu(ut) + self.one * 1e-6 + st = relu(st) + self.one * 1e-6 + ut = pyro.deterministic("ut", ut, event_dim=0) + st = pyro.deterministic("st", st, event_dim=0) + ut = ut / torch.sum(ut, dim=-1, keepdim=True) + st = st / torch.sum(st, dim=-1, keepdim=True) + ut = pyro.deterministic("ut_norm", ut, event_dim=0) + st = pyro.deterministic("st_norm", st, event_dim=0) + ut = (ut + self.one * 1e-6) * u_read_depth + st = (st + self.one * 1e-6) * s_read_depth + else: + ut = relu(ut) + st = relu(st) + ut = pyro.deterministic("ut", ut, event_dim=0) + st = pyro.deterministic("st", st, event_dim=0) + ut = ut + self.one * 1e-6 + st = st + self.one * 1e-6 + + c_dist = LogNormal(ct, sigma_c) + u_dist = Poisson(ut) + s_dist = Poisson(st) + + return c_dist, u_dist, s_dist class VelocityModelAuto(LogNormalModel): """Automatically configured velocity model. @@ -698,3 +802,435 @@ def forward( u = pyro.sample("u", u_dist, obs=u_obs) s = pyro.sample("s", s_dist, obs=s_obs) return u, s + +class MultiVelocityModelAuto(LogNormalModel): + """Automatically configured MULTIOME velocity model. + + Args: + num_cells (int): _description_ + num_genes (int): _description_ + likelihood (str, optional): _description_. Defaults to "Poisson". + shared_time (bool, optional): _description_. Defaults to True. + t_scale_on (bool, optional): _description_. Defaults to False. + plate_size (int, optional): _description_. Defaults to 2. + latent_factor (str, optional): _description_. Defaults to "none". + latent_factor_size (int, optional): _description_. Defaults to 30. + latent_factor_operation (str, optional): _description_. Defaults to "selection". + include_prior (bool, optional): _description_. Defaults to False. + num_aux_cells (int, optional): _description_. Defaults to 100. + only_cell_times (bool, optional): _description_. Defaults to False. + decoder_on (bool, optional): _description_. Defaults to False. + add_offset (bool, optional): _description_. Defaults to False. + correct_library_size (Union[bool, str], optional): _description_. Defaults to True. + guide_type (str, optional): _description_. Defaults to "velocity". + cell_specific_kinetics (Optional[star], optional): _description_. Defaults to None. + kinetics_num (Optional[int], optional): _description_. Defaults to None. + + Examples: + >>> import torch + >>> from pyrovelocity.models._velocity_model import VelocityModelAuto + >>> model = VelocityModelAuto( + ... 3, + ... 4, + ... "Poisson", + ... True, + ... False, + ... 2, + ... "none", + ... latent_factor_operation="selection", + ... latent_factor_size=10, + ... include_prior=False, + ... num_aux_cells=0, + ... only_cell_times=True, + ... decoder_on=False, + ... add_offset=False, + ... correct_library_size=True, + ... guide_type="auto_t0_constraint", + ... cell_specific_kinetics=None, + ... **{} + ... ) + >>> logger.info(model) + """ + + @beartype + def __init__( + self, + num_cells: int, + num_genes: int, + likelihood: str = "Poisson", + shared_time: bool = True, + t_scale_on: bool = False, + plate_size: int = 2, + latent_factor: str = "none", + latent_factor_size: int = 30, + latent_factor_operation: str = "selection", + include_prior: bool = False, + num_aux_cells: int = 100, + only_cell_times: bool = False, + decoder_on: bool = False, + add_offset: bool = False, + correct_library_size: Union[bool, str] = True, + guide_type: str = "velocity", + cell_specific_kinetics: Optional[str] = None, + kinetics_num: Optional[int] = None, + **initial_values, + ) -> None: + assert num_cells > 0 and num_genes > 0 + super().__init__(num_cells, num_genes, likelihood, plate_size) + self.num_aux_cells = num_aux_cells + self.only_cell_times = only_cell_times + self.guide_type = guide_type + self.cell_specific_kinetics = cell_specific_kinetics + self.k = kinetics_num + + self.mask = initial_values.get( + "mask", torch.ones(self.num_cells, self.num_genes).bool() + ) + for key in initial_values: + self.register_buffer(f"{key}_init", initial_values[key]) + + self.shared_time = shared_time + self.t_scale_on = t_scale_on + self.add_offset = add_offset + self.plate_size = plate_size + + self.latent_factor = latent_factor + self.latent_factor_size = latent_factor_size + self.latent_factor_operation = latent_factor_operation + self.include_prior = include_prior + self.decoder_on = decoder_on + self.correct_library_size = correct_library_size + if self.decoder_on: + self.decoder = Decoder(1, self.num_genes, n_layers=2) + self.enumeration = "parallel" + + @beartype + def create_plates( + self, + c_obs: Optional[torch.Tensor] = None, + u_obs: Optional[torch.Tensor] = None, + s_obs: Optional[torch.Tensor] = None, + u_log_library: Optional[torch.Tensor] = None, + s_log_library: Optional[torch.Tensor] = None, + u_log_library_loc: Optional[torch.Tensor] = None, + s_log_library_loc: Optional[torch.Tensor] = None, + u_log_library_scale: Optional[torch.Tensor] = None, + s_log_library_scale: Optional[torch.Tensor] = None, + ind_x: Optional[torch.Tensor] = None, + cell_state: Optional[torch.Tensor] = None, + time_info: Optional[torch.Tensor] = None, + ) -> Tuple[plate, plate]: + # Call the parent class method + cell_plate, gene_plate = super().create_plates( + u_obs=u_obs, + s_obs=s_obs, + u_log_library=u_log_library, + s_log_library=s_log_library, + u_log_library_loc=u_log_library_loc, + s_log_library_loc=s_log_library_loc, + u_log_library_scale=u_log_library_scale, + s_log_library_scale=s_log_library_scale, + ind_x=ind_x, + cell_state=cell_state, + time_info=time_info, + ) + # You can add any additional logic here if needed + return cell_plate, gene_plate + + def sample_cell_gene_state(self, t, switching): + return ( + pyro.sample( + "cell_gene_state", + Bernoulli(logits=t - switching), + infer={"enumerate": self.enumeration}, + ) + == self.zero + ) + + @beartype + def __repr__(self) -> str: + return ( + f"\nVelocityModelAuto(\n" + f"\tnum_cells={self.num_cells}, \n" + f"\tnum_genes={self.num_genes}, \n" + f'\tlikelihood="{self.likelihood}", \n' + f"\tshared_time={self.shared_time}, \n" + f"\tt_scale_on={self.t_scale_on}, \n" + f"\tplate_size={self.plate_size}, \n" + f'\tlatent_factor="{self.latent_factor}", \n' + f"\tlatent_factor_size={self.latent_factor_size}, \n" + f'\tlatent_factor_operation="{self.latent_factor_operation}", \n' + f"\tinclude_prior={self.include_prior}, \n" + f"\tnum_aux_cells={self.num_aux_cells}, \n" + f"\tonly_cell_times={self.only_cell_times}, \n" + f"\tdecoder_on={self.decoder_on}, \n" + f"\tadd_offset={self.add_offset}, \n" + f"\tcorrect_library_size={self.correct_library_size}, \n" + f'\tguide_type="{self.guide_type}", \n' + f"\tcell_specific_kinetics={self.cell_specific_kinetics}, \n" + f"\tkinetics_num={self.k}\n" + f")\n" + ) + + @jaxtyped(typechecker=beartype) + def get_atac_rna( + self, + u_scale: RNAInputType, + s_scale: RNAInputType, + t: Tensor, # cells, 1 + t0_1: Tensor, + dt_1: Tensor, + dt_2: Tensor, + dt_3: Tensor, + alpha_c: Tensor, + alpha: Tensor, + alpha_off: Tensor, + beta: Tensor, + gamma: Tensor, + ) -> Tuple[RNAOutputType, RNAOutputType, RNAOutputType]: + """ + Computes the unspliced (u) and spliced (s) RNA expression levels and chromatin opening state (c) given + the model parameters. + + Args: + u_scale (torch.Tensor): Scaling factor for unspliced expression. + s_scale (torch.Tensor): Scaling factor for spliced expression. + t (Tensor): The time of each cell. + t0_1 (Tensor): Start time for chromatin opening. + dt_1 (Tensor): Time gap since chromatin opening for transcription start for each gene. + dt_2 (Tensor): Time gap since transcription start for chromatin closing for each gene. + dt_3 (Tensor): Time gap since transcription start for transcription stopping for each gene. + alpha_c (Tensor): The chromatin opening and closing rate. + alpha (Tensor): The transcription rate of each gene in the on state. + alpha_off (Tensor): The transcription rate of each gene in the off state. + beta (torch.Tensor): Splicing rate. + gamma (torch.Tensor): Degradation rate. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The chromatin state (c), unspliced (u) and + spliced (s) RNA expression levels. + + + Examples: + >>> from pyrovelocity.models._velocity_model import MultiVelocityModelAuto + >>> import torch + >>> n_cells = 4 + >>> u_scale = torch.tensor(1.0) + >>> s_scale = torch.tensor(1.0) + >>> t = torch.arange(0, 120, 30).reshape(n_cells,1) # cells, 1 + >>> t0_1 = torch.tensor((10.0, 10.0)) + >>> dt_1 = torch.tensor((15.0, 15.0)) + >>> dt_2 = torch.tensor((77.0, 53.0)) + >>> dt_3 = torch.tensor((50.0, 70.0)) + >>> alpha_c = torch.tensor((0.1, 0.2)) + >>> alpha = torch.tensor((0.5, 0.3)) + >>> alpha_off = torch.tensor(0.0) + >>> beta = torch.tensor((0.09, 0.11)) + >>> gamma = torch.tensor((0.05, 0.06)) + >>> mod = MultiVelocityModelAuto(num_cells = n_cells, num_genes = 2) + >>> output = MultiVelocityModelAuto.get_atac_rna( + mod, u_scale, s_scale, t, t0_1, dt_1, dt_2, dt_3, alpha_c, alpha, alpha_off, beta, gamma + ) + (tensor([[0.0000, 0.0000], + [0.8647, 0.9817], + [0.9933, 1.0000], + [0.2228, 1.0000]]), + tensor([[0.0000, 0.0000], + [1.6662, 1.1191], + [5.1763, 2.6659], + [3.2144, 0.7263]]), + tensor([[0.0000, 0.0000], + [2.2120, 1.2959], + [8.2623, 4.3877], + [4.3359, 2.3326]])) + """ + + k = self.sample_cell_gene_state(t, t0_1) + + state, k_c_state, alpha_state, t0_state, k_c_vec, alpha_vec, tau_vec = get_cell_parameters( + t, t0_1, dt_1, dt_2, dt_3, alpha, alpha_off,k, + ) + + c0_vec, u0_vec, s0_vec = get_initial_states( + t0_state, k_c_state, alpha_c, alpha_state, beta, gamma, state + ) + + ct, ut, st = atac_mrna_dynamics( + tau_vec, c0_vec, u0_vec, s0_vec, k_c_vec, alpha_c, alpha_vec, beta, gamma + ) + + ut = ut * u_scale / s_scale + return ct, ut, st + + @beartype + def forward( + self, + c_obs: torch.Tensor, + u_obs: torch.Tensor, + s_obs: torch.Tensor, + u_log_library: Optional[torch.Tensor] = None, + s_log_library: Optional[torch.Tensor] = None, + u_log_library_loc: Optional[torch.Tensor] = None, + s_log_library_loc: Optional[torch.Tensor] = None, + u_log_library_scale: Optional[torch.Tensor] = None, + s_log_library_scale: Optional[torch.Tensor] = None, + ind_x: Optional[torch.Tensor] = None, + cell_state: Optional[torch.Tensor] = None, + time_info: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Defines the forward model, which computes the chromatin state (c), unspliced (u) and spliced + (s) RNA expression levels given the observations and model parameters. + + Args: + u_obs (Optional[torch.Tensor], optional): Observed unspliced RNA expression. Default is None. + s_obs (Optional[torch.Tensor], optional): Observed spliced RNA expression. Default is None. + c_obs (Optional[torch.Tensor], optional): Observed chromatin state. Default is None. + u_log_library (Optional[torch.Tensor], optional): Log-transformed library size for unspliced RNA. Default is None. + s_log_library (Optional[torch.Tensor], optional): Log-transformed library size for spliced RNA. Default is None. + u_log_library_loc (Optional[torch.Tensor], optional): Mean of log-transformed library size for unspliced RNA. Default is None. + s_log_library_loc (Optional[torch.Tensor], optional): Mean of log-transformed library size for spliced RNA. Default is None. + u_log_library_scale (Optional[torch.Tensor], optional): Scale of log-transformed library size for unspliced RNA. Default is None. + s_log_library_scale (Optional[torch.Tensor], optional): Scale of log-transformed library size for spliced RNA. Default is None. + ind_x (Optional[torch.Tensor], optional): Indices for the cells. Default is None. + cell_state (Optional[torch.Tensor], optional): Cell state information. Default is None. + time_info (Optional[torch.Tensor], optional): Time information for the cells. Default is None. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The chromatin state (c), unspliced (u) and spliced (s) RNA expression levels. + + Examples: + >>> import torch + >>> from pyrovelocity.models._velocity_model import VelocityModelAuto + >>> u_obs=torch.tensor( + ... [[33., 1., 7., 1.], + ... [12., 30., 11., 3.], + ... [ 1., 1., 8., 5.]], + ... device="cpu", + >>> ) + >>> s_obs=torch.tensor( + ... [[32.0, 0.0, 6.0, 0.0], + ... [11.0, 29.0, 10.0, 2.0], + ... [0.0, 0.0, 7.0, 4.0]], + ... device="cpu", + >>> ) + >>> c_obs=torch.tensor( + ... [[1.0, 0.2, 0.4, 0.0], + ... [0.8, 0.2, 0.5, 0.3], + ... [0.0, 0.0, 0.1, 0.9]], + ... device="cpu", + >>> ) + >>> u_log_library=torch.tensor([[3.7377], [4.0254], [2.7081]], device="cpu") + >>> s_log_library=torch.tensor([[3.6376], [3.9512], [2.3979]], device="cpu") + >>> u_log_library_loc=torch.tensor([[3.4904], [3.4904], [3.4904]], device="cpu") + >>> s_log_library_loc=torch.tensor([[3.3289], [3.3289], [3.3289]], device="cpu") + >>> u_log_library_scale=torch.tensor([[0.6926], [0.6926], [0.6926]], device="cpu") + >>> s_log_library_scale=torch.tensor([[0.8214], [0.8214], [0.8214]], device="cpu") + >>> ind_x=torch.tensor([2, 0, 1], device="cpu") + >>> model = VelocityModelAuto(3,4) + >>> u, s = model.forward( + >>> u_obs, + >>> s_obs, + >>> u_log_library, + >>> s_log_library, + >>> u_log_library_loc, + >>> s_log_library_loc, + >>> u_log_library_scale, + >>> s_log_library_scale, + >>> ind_x, + >>> ) + >>> u, s + (tensor([[33., 1., 7., 1.], + [12., 30., 11., 3.], + [ 1., 1., 8., 5.]]), + tensor([[32., 0., 6., 0.], + [11., 29., 10., 2.], + [ 0., 0., 7., 4.]])) + """ + cell_plate, gene_plate = self.create_plates( + u_obs, + s_obs, + c_obs, + u_log_library, + s_log_library, + u_log_library_loc, + s_log_library_loc, + u_log_library_scale, + s_log_library_scale, + ind_x, + cell_state, + time_info, + ) + + with gene_plate, poutine.mask(mask=self.include_prior): + + alpha_c = pyro.sample("alpha_c", LogNormal(self.one, self.one)) + alpha = pyro.sample("alpha", LogNormal(self.one*20, self.one*10)) + gamma = pyro.sample("gamma", LogNormal(self.one*20, self.one*10)) + beta = pyro.sample("beta", LogNormal(self.one*20, self.one*10)) + alpha_off = self.zero + + t0_1 = pyro.sample("t0_1", Normal(self.zero, self.one*10)) + dt_1 = pyro.sample("dt_1", LogNormal(self.one*20, self.one*10)) + dt_2 = pyro.sample("dt_2", LogNormal(self.one*20, self.one*10)) + dt_3 = pyro.sample("dt_3", LogNormal(self.one*20, self.one*10)) + + u_scale = self.u_scale + s_scale = self.one + + with cell_plate: + t = pyro.sample( + "cell_time", + LogNormal(self.zero, self.one*50).mask(self.include_prior), + ) + + with cell_plate: + + u_cell_size_coef = ut_coef = s_cell_size_coef = st_coef = None + u_read_depth = pyro.sample( + "u_read_depth", LogNormal(u_log_library, u_log_library_scale) + ) + + s_read_depth = pyro.sample( + "s_read_depth", LogNormal(s_log_library, s_log_library_scale) + ) + + sigma_c = pyro.sample( + "sigma_c", LogNormal(0.2,0.2) + ) + + ct, ut, st = self.get_atac_rna( + u_scale, + s_scale, + t, # cells, 1 + t0_1, + dt_1, + dt_2, + dt_3, + alpha_c, + alpha, + alpha_off, + beta, + gamma) + + with gene_plate: + c_dist, u_dist, s_dist = self.get_likelihood_multiome( + ct, + ut, + st, + sigma_c, + u_log_library, + s_log_library, + u_scale, + s_scale, + u_read_depth=u_read_depth, + s_read_depth=s_read_depth, + u_cell_size_coef=u_cell_size_coef, + ut_coef=ut_coef, + s_cell_size_coef=s_cell_size_coef, + st_coef=st_coef, + ) + c = pyro.sample("c", c_dist, obs=c_obs) + u = pyro.sample("u", u_dist, obs=u_obs) + s = pyro.sample("s", s_dist, obs=s_obs) + return c, u, s diff --git a/src/pyrovelocity/models/_velocity_module.py b/src/pyrovelocity/models/_velocity_module.py index 16a75fb99..814fd1a0d 100644 --- a/src/pyrovelocity/models/_velocity_module.py +++ b/src/pyrovelocity/models/_velocity_module.py @@ -12,7 +12,7 @@ from scvi.module.base import PyroBaseModuleClass from pyrovelocity.logging import configure_logging -from pyrovelocity.models._velocity_model import VelocityModelAuto +from pyrovelocity.models._velocity_model import VelocityModelAuto, MultiVelocityModelAuto logger = configure_logging(__name__) @@ -242,3 +242,225 @@ def _get_fn_args_from_batch( cell_state, time_info, ), {} + +class MultiVelocityModule(PyroBaseModuleClass): + """ + VelocityModule is an scvi-tools pyro module that combines the VelocityModelAuto and pyro AutoGuideList classes. + + Args: + num_cells (int): Number of cells. + num_genes (int): Number of genes. + model_type (str, optional): Model type. Default is "auto". + guide_type (str, optional): Guide type. Default is "velocity_auto". + likelihood (str, optional): Likelihood type. Default is "Poisson". + shared_time (bool, optional): If True, a shared time parameter will be used. Default is True. + t_scale_on (bool, optional): If True, scale time parameter. Default is False. + plate_size (int, optional): Size of the plate set. Default is 2. + latent_factor (str, optional): Latent factor. Default is "none". + latent_factor_operation (str, optional): Latent factor operation mode. Default is "selection". + latent_factor_size (int, optional): Size of the latent factor. Default is 10. + inducing_point_size (int, optional): Inducing point size. Default is 0. + include_prior (bool, optional): If True, include prior in the model. Default is False. + use_gpu (str, optional): Accelerator type. Default is "auto". + num_aux_cells (int, optional): Number of auxiliary cells. Default is 0. + only_cell_times (bool, optional): If True, only model cell times. Default is True. + decoder_on (bool, optional): If True, use the decoder. Default is False. + add_offset (bool, optional): If True, add offset to the model. Default is True. + correct_library_size (Union[bool, str], optional): Library size correction method. Default is True. + cell_specific_kinetics (Optional[str], optional): Cell-specific kinetics method. Default is None. + kinetics_num (Optional[int], optional): Number of kinetics. Default is None. + **initial_values: Initial values for the model parameters. + + Examples: + >>> from scvi.module.base import PyroBaseModuleClass + >>> from pyrovelocity.models._velocity_module import VelocityModule + >>> num_cells = 10 + >>> num_genes = 20 + >>> velocity_module1 = VelocityModule( + ... num_cells, num_genes, model_type="auto", + ... guide_type="auto_t0_constraint", add_offset=False + ... ) + >>> type(velocity_module1.model) + + >>> type(velocity_module1.guide) + + >>> velocity_module2 = VelocityModule( + ... num_cells, num_genes, model_type="auto", + ... guide_type="auto", add_offset=True + ... ) + >>> type(velocity_module2.model) + + >>> type(velocity_module2.guide) + + """ + + def __init__( + self, + num_cells: int, + num_genes: int, + model_type: str = "auto", + guide_type: str = "velocity_auto", + likelihood: str = "Poisson", + shared_time: bool = True, + t_scale_on: bool = False, + plate_size: int = 2, + latent_factor: str = "none", + latent_factor_operation: str = "selection", + latent_factor_size: int = 10, + inducing_point_size: int = 0, + include_prior: bool = False, + use_gpu: str = "auto", + num_aux_cells: int = 0, + only_cell_times: bool = True, + decoder_on: bool = False, + add_offset: bool = True, + correct_library_size: Union[bool, str] = True, + cell_specific_kinetics: Optional[str] = None, + kinetics_num: Optional[int] = None, + **initial_values, + ) -> None: + super().__init__() + self.num_cells = num_cells + self.num_genes = num_genes + self.model_type = model_type + self.guide_type = guide_type + self._model = None + self.plate_size = plate_size + self.num_aux_cells = num_aux_cells + self.only_cell_times = only_cell_times + logger.info( + f"Model type: {self.model_type}, Guide type: {self.guide_type}" + ) + + self.cell_specific_kinetics = cell_specific_kinetics + + self._model = MultiVelocityModelAuto( + self.num_cells, + self.num_genes, + likelihood, + shared_time, + t_scale_on, + self.plate_size, + latent_factor, + latent_factor_operation=latent_factor_operation, + latent_factor_size=latent_factor_size, + include_prior=include_prior, + num_aux_cells=num_aux_cells, + only_cell_times=self.only_cell_times, + decoder_on=decoder_on, + add_offset=add_offset, + correct_library_size=correct_library_size, + guide_type=self.guide_type, + cell_specific_kinetics=self.cell_specific_kinetics, + **initial_values, + ) + + guide = AutoGuideList( + self._model, create_plates=self._model.create_plates + ) + guide.append( + AutoNormal( + poutine.block( + self._model, + expose=[ + "cell_time", + "u_read_depth", + "s_read_depth", + "kinetics_prob", + "kinetics_weights", + ], + ), + init_scale=0.1, + ) + ) + + if add_offset: + guide.append( + AutoLowRankMultivariateNormal( + poutine.block( + self._model, + expose=[ + "dt_switching", + "t0", + "u_scale", + "s_scale", + "u_offset", + "s_offset", + ], + ), + rank=10, + init_scale=0.1, + ) + ) + else: + guide.append( + AutoLowRankMultivariateNormal( + poutine.block( + self._model, + expose=[ + "dt_switching", + "t0", + "u_scale", + "s_scale", + ], + ), + rank=10, + init_scale=0.1, + ) + ) + self._guide = guide + + @property + def model(self) -> VelocityModelAuto: + return self._model + + @property + def guide(self) -> AutoGuideList: + return self._guide + + @staticmethod + def _get_fn_args_from_batch( + tensor_dict: Dict[str, torch.Tensor] + ) -> Tuple[ + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], + Dict[Any, Any], + ]: + u_obs = tensor_dict["U"] + s_obs = tensor_dict["X"] + c_obs = tensor_dict['atac'] + u_log_library = tensor_dict["u_lib_size"] + s_log_library = tensor_dict["s_lib_size"] + u_log_library_mean = tensor_dict["u_lib_size_mean"] + s_log_library_mean = tensor_dict["s_lib_size_mean"] + u_log_library_scale = tensor_dict["u_lib_size_scale"] + s_log_library_scale = tensor_dict["s_lib_size_scale"] + ind_x = tensor_dict["ind_x"].long().squeeze() + cell_state = tensor_dict.get("pyro_cell_state") + time_info = tensor_dict.get("time_info") + return ( + c_obs, + u_obs, + s_obs, + u_log_library, + s_log_library, + u_log_library_mean, + s_log_library_mean, + u_log_library_scale, + s_log_library_scale, + ind_x, + cell_state, + time_info, + ), {} diff --git a/src/pyrovelocity/tasks/train.py b/src/pyrovelocity/tasks/train.py index 2998fc2c8..926815d12 100644 --- a/src/pyrovelocity/tasks/train.py +++ b/src/pyrovelocity/tasks/train.py @@ -2,9 +2,7 @@ import os import uuid from pathlib import Path -from typing import Dict -from typing import Optional -from typing import Tuple +from typing import Dict, Optional, Tuple import matplotlib.pyplot as plt import mlflow @@ -21,7 +19,6 @@ from pyrovelocity.tasks.data import load_anndata_from_path from pyrovelocity.utils import print_anndata - logger = configure_logging(__name__) @@ -279,6 +276,7 @@ def check_shared_time(posterior_samples, adata): @beartype def train_model( adata: str | AnnData, + adata_atac: Optional[AnnData] = None, guide_type: str = "auto", model_type: str = "auto", batch_size: int = -1, @@ -305,6 +303,7 @@ def train_model( Args: adata (str | AnnData): Path to a file that can be read to an AnnData object or an AnnData object. + adata_atac (Optional[AnnData], optional): An anndata object with atac data, matching the default adata input with RNA data. guide_type (str, optional): The type of guide function for the Pyro model. Default is "auto". model_type (str, optional): The type of Pyro model. Default is "auto". batch_size (int, optional): Batch size for training. Default is -1, which indicates using the full dataset. @@ -347,16 +346,18 @@ def train_model( >>> copy_raw_counts(adata) >>> _, model, posterior_samples = train_model(adata, use_gpu="auto", seed=99, max_epochs=200, loss_plot_path=loss_plot_path) """ + if isinstance(adata, str): adata = load_anndata_from_path(adata) logger.info(f"AnnData object prior to model training") print_anndata(adata) - PyroVelocity.setup_anndata(adata) + PyroVelocity.setup_anndata(adata, adata_atac = adata_atac) model = PyroVelocity( adata, + adata_atac = adata_atac, likelihood=likelihood, model_type=model_type, guide_type=guide_type, diff --git a/src/pyrovelocity/tests/models/__init__.py b/src/pyrovelocity/tests/models/__init__.py new file mode 100644 index 000000000..555adbe72 --- /dev/null +++ b/src/pyrovelocity/tests/models/__init__.py @@ -0,0 +1 @@ +"""Unit test package for pyrovelocity.""" diff --git a/src/pyrovelocity/tests/models/test_transcription_dynamics.py b/src/pyrovelocity/tests/models/test_transcription_dynamics.py new file mode 100644 index 000000000..332e85f4f --- /dev/null +++ b/src/pyrovelocity/tests/models/test_transcription_dynamics.py @@ -0,0 +1,129 @@ +"""Tests for _transcription_dynamics_ functions.""" + +from pyrovelocity.models._transcription_dynamics import ( + atac_mrna_dynamics, + get_cell_parameters, + get_initial_states +) + +def test_get_cell_parameters(): + import torch + + n_cells = 4 + t = torch.arange(0, 120, 30).reshape(n_cells, 1) + t0_1 = torch.tensor((10.0, 10.0)) + dt_1 = torch.tensor((15.0, 15.0)) + dt_2 = torch.tensor((77.0, 53.0)) + dt_3 = torch.tensor((50.0, 70.0)) + alpha = torch.tensor((0.5, 0.3)) + alpha_off = torch.tensor(0.0) + k = torch.tensor((1.0, 1.0),(1.0,1.0)) + + output = get_cell_parameters( + t, t0_1, dt_1, dt_2, dt_3, alpha, alpha_off,k + ) + + correct_output = (torch.tensor([[0, 0], + [2, 2], + [2, 2], + [3, 3]]),torch.tensor([[0., 1., 1., 0., 0.], + [0., 1., 1., 1., 0.]]), torch.tensor([[0.0000, 0.0000, 0.5000, 0.5000, 0.0000], + [0.0000, 0.0000, 0.3000, 0.0000, 0.0000]]), torch.tensor([[ 0., 10., 25., 75., 102.], + [ 0., 10., 25., 78., 95.]]), torch.tensor([[0., 0.], + [1., 1.], + [1., 1.], + [0., 1.]]), torch.tensor([[0.0000, 0.0000], + [0.5000, 0.3000], + [0.5000, 0.3000], + [0.5000, 0.0000]]), torch.tensor([[ 0., 0.], + [ 5., 5.], + [35., 35.], + [15., 12.]])) + + for i in range(len(output)): + assert torch.allclose(output[i], correct_output[i], atol=1e-3), f"Output at index {i} is incorrect" + +def test_get_initial_states(): + import torch + + alpha_c = torch.tensor((0.1, 0.2)) + beta = torch.tensor((0.09, 0.11)) + gamma = torch.tensor((0.05, 0.06)) + state = torch.tensor([[0, 0],[2, 2],[2, 2],[3, 3]]) + k_c_state = torch.tensor([[0., 1., 1., 0., 0.], [0., 1., 1., 1., 0.]]) + alpha_state = torch.tensor([[0.0000, 0.0000, 0.5000, 0.5000, 0.0000],[0.0000, 0.0000, 0.3000, 0.0000, 0.0000]]) + t0_state = torch.tensor([[ 0., 10., 25., 75., 102.],[ 0., 10., 25., 78., 95.]]) + + output = get_initial_states( + t0_state, k_c_state, alpha_c, alpha_state, beta, gamma, state + ) + + correct_output = (torch.tensor([[0.0000, 0.0000], + [0.7769, 0.9502], + [0.7769, 0.9502], + [0.9985, 1.0000]]), + torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [5.4451, 2.7188]]), + torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [9.1791, 4.7921]])) + + for i in range(len(output)): + assert torch.allclose(output[i], correct_output[i], atol=1e-3), f"Output at index {i} is incorrect" + + +def test_atac_mrna_dynamics(): + import torch + + input = [torch.tensor([[ 0., 0.], + [ 5., 5.], + [35., 35.], + [15., 12.]]), torch.tensor([[0.0000, 0.0000], + [0.7769, 0.9502], + [0.7769, 0.9502], + [0.9985, 1.0000]]), torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [5.4451, 2.7188]]), torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [9.1791, 4.7921]]), torch.tensor([[0., 0.], + [1., 1.], + [1., 1.], + [0., 1.]]), torch.tensor([0.1000, 0.2000]), torch.tensor([[0.0000, 0.0000], + [0.5000, 0.3000], + [0.5000, 0.3000], + [0.5000, 0.0000]]), torch.tensor([0.0900, 0.1100]), torch.tensor([0.0500, 0.0600])] + + tau_vec = input[0] + c0_vec = input[1] + u0_vec = input[2] + s0_vec = input[3] + k_c_vec = input[4] + alpha_c = input[5] + alpha_vec = input[6] + beta = input[7] + gamma = input[8] + + output = atac_mrna_dynamics( + tau_vec, c0_vec, u0_vec, s0_vec, k_c_vec, alpha_c, alpha_vec, beta, gamma + ) + + correct_output = (torch.tensor([[0.0000, 0.0000], + [0.8647, 0.9817], + [0.9933, 1.0000], + [0.2228, 1.0000]]), torch.tensor([[0.0000, 0.0000], + [1.6662, 1.1191], + [5.1763, 2.6659], + [3.2144, 0.7263]]), torch.tensor([[0.0000, 0.0000], + [2.2120, 1.2959], + [8.2623, 4.3877], + [4.3359, 2.3326]])) + + for i in range(len(output)): + assert torch.allclose(output[i], correct_output[i], atol=1e-3), f"Output at index {i} is incorrect" + + diff --git a/src/pyrovelocity/tests/models/test_velocity_model.py b/src/pyrovelocity/tests/models/test_velocity_model.py new file mode 100644 index 000000000..ac624c5da --- /dev/null +++ b/src/pyrovelocity/tests/models/test_velocity_model.py @@ -0,0 +1,45 @@ +"""Tests for _velocity_model.py""" + +def test_MultiVelocityModelAuto(): + from pyrovelocity.models._velocity_model import MultiVelocityModelAuto + +def test_MultiVelocityModelAuto_get_atac_rna(): + from pyrovelocity.models._velocity_model import MultiVelocityModelAuto + import torch + + n_cells = 4 + u_scale = torch.tensor(1.0) + s_scale = torch.tensor(1.0) + t = torch.arange(0, 120, 30).reshape(n_cells,1) # cells, 1 + t0_1 = torch.tensor((10.0, 10.0)) + dt_1 = torch.tensor((15.0, 15.0)) + dt_2 = torch.tensor((77.0, 53.0)) + dt_3 = torch.tensor((50.0, 70.0)) + alpha_c = torch.tensor((0.1, 0.2)) + alpha = torch.tensor((0.5, 0.3)) + alpha_off = torch.tensor(0.0) + beta = torch.tensor((0.09, 0.11)) + gamma = torch.tensor((0.05, 0.06)) + + mod = MultiVelocityModelAuto(num_cells = n_cells, num_genes = 2) + output = MultiVelocityModelAuto.get_atac_rna( + mod, u_scale, s_scale, t, t0_1, dt_1, dt_2, dt_3, alpha_c, alpha, alpha_off, beta, gamma + ) + + correct_output = ((torch.tensor([[0.0000, 0.0000], + [0.8647, 0.9817], + [0.9933, 1.0000], + [0.2228, 1.0000]]), + torch.tensor([[0.0000, 0.0000], + [1.6662, 1.1191], + [5.1763, 2.6659], + [3.2144, 0.7263]]), + torch.tensor([[0.0000, 0.0000], + [2.2120, 1.2959], + [8.2623, 4.3877], + [4.3359, 2.3326]]))) + + for i in range(len(output)): + assert torch.allclose(output[i], correct_output[i], atol=1e-3), f"Output at index {i} is incorrect" + + diff --git a/src/pyrovelocity/tests/tasks/__init__.py b/src/pyrovelocity/tests/tasks/__init__.py new file mode 100644 index 000000000..555adbe72 --- /dev/null +++ b/src/pyrovelocity/tests/tasks/__init__.py @@ -0,0 +1 @@ +"""Unit test package for pyrovelocity.""" diff --git a/src/pyrovelocity/tests/tasks/test_train_model.py b/src/pyrovelocity/tests/tasks/test_train_model.py new file mode 100644 index 000000000..7f516485a --- /dev/null +++ b/src/pyrovelocity/tests/tasks/test_train_model.py @@ -0,0 +1,20 @@ +"""Tests for `pyrovelocity._train_model` task.""" + +from pyrovelocity.tasks.preprocess import copy_raw_counts +from pyrovelocity.tasks.train import train_model +from pyrovelocity.utils import generate_sample_data + + +def test_train_model(tmp_path): + loss_plot_path = str(tmp_path) + "/loss_plot_docs.png" + print(loss_plot_path) + adata = generate_sample_data(random_seed=99) + copy_raw_counts(adata) + _, model, posterior_samples = train_model( + adata, + adata_atac=None, + use_gpu="auto", + seed=99, + max_epochs=200, + loss_plot_path=loss_plot_path, + )