feat(train_model): add atac layer argument #542

Status: Draft. Wants to merge 23 commits into base branch beta.

Commits (23):
9d2f67d
added atac_layer argument to train_model tasks and made tests for it
AlexanderAivazidis Jul 8, 2024
81ed87c
fix(vscode): disable automatic python env activation
AlexanderAivazidis Jul 8, 2024
f16cc6d
corrected ATAC argument type from layer name to anndata object
AlexanderAivazidis Jul 10, 2024
112b44d
feat[PyroVelocity]: added atac_data to setup_anndata method
AlexanderAivazidis Jul 10, 2024
7b78188
fix[PyroVelocity]: missing colon in Args description
AlexanderAivazidis Jul 10, 2024
a36cae7
feat(VelocityTrainingMixin): Added atac data to train_faster method
AlexanderAivazidis Jul 10, 2024
cfee317
feat(_velocity_module): Added a MultiVelocityModule for multiome data.
AlexanderAivazidis Jul 10, 2024
365ba7b
feat(_velocity_model): Added a MultiVelocityModelAuto class for multi…
AlexanderAivazidis Jul 11, 2024
a857453
feat(_transcription_dynamics): Added function for multiome dynamics.
AlexanderAivazidis Jul 11, 2024
5a4dbbe
feat(_test_transcription_dynamics): Unit tests for transcription dyna…
AlexanderAivazidis Jul 22, 2024
8368ce0
feat(_test_velocity_model): Unit tests for velocity model
AlexanderAivazidis Jul 22, 2024
b5508f9
feat(.gitignore): Added example_notebooks directory to .gitignore file.
AlexanderAivazidis Jul 22, 2024
fe53c90
fix[_trainer_]: checking for existence of atac data
AlexanderAivazidis Jul 22, 2024
e208085
fix(_velocity): Save existence of atac data in adata
AlexanderAivazidis Jul 22, 2024
e349bb7
fix(_velocity_model): LogNormal instead of Normal likelihood for atac…
AlexanderAivazidis Jul 22, 2024
21a83ac
fix(_trainer.py): Properly processing atac data
AlexanderAivazidis Jul 26, 2024
7a04607
fix(_transcription_dynamics): Ensuring no inplace tensor operations i…
AlexanderAivazidis Jul 26, 2024
d5dcc98
fix(_velocity): Handling atac data.
AlexanderAivazidis Jul 26, 2024
d9f0636
fix(_velocity_module): Removed rates from multivariateNormalGuide, be…
AlexanderAivazidis Jul 26, 2024
c03a247
fix(train): Ensure atac data is handled properly.
AlexanderAivazidis Jul 26, 2024
637588a
feat(_transcription_dynamics): Added latent discrete parameter for mo…
AlexanderAivazidis Jul 26, 2024
1e096dd
feat(_velocity_model): Sampling latent discrete parameter for modelli…
AlexanderAivazidis Jul 26, 2024
54cb35c
feat(_test_transcription_dynamics): Adapted test to include latent di…
AlexanderAivazidis Jul 26, 2024
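Taken together, the commit messages describe threading an ATAC modality through `setup_anndata`, the velocity module, and the trainer. The snippet below is only a hedged usage sketch inferred from those messages: the `atac_data` keyword and the `adata.uns["atac"]` flag do appear in this PR, but the exact call signatures and file paths here are assumptions, not a documented API.

```python
import anndata as ad

from pyrovelocity.models import PyroVelocity

# Hypothetical multiome workflow sketched from the commit messages in this PR.
adata_rna = ad.read_h5ad("multiome_rna.h5ad")    # placeholder paths
adata_atac = ad.read_h5ad("multiome_atac.h5ad")

# Commit 112b44d: "added atac_data to setup_anndata method".
# The keyword name follows that message; the final signature may differ.
PyroVelocity.setup_anndata(adata_rna, atac_data=adata_atac)

model = PyroVelocity(adata_rna)

# Commit a36cae7 adds ATAC handling to train_faster; per the _trainer.py diff
# below, the multiome path is taken when adata.uns["atac"] is truthy.
losses = model.train_faster(max_epochs=1000, log_every=100)
```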
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
 #
 /archive/
+example_notebooks/*

 #
 .DS_Store
1 change: 1 addition & 0 deletions .vscode/settings.json
@@ -56,6 +56,7 @@
   "search.followSymlinks": false,
   "terminal.integrated.fontSize": 14,
   "terminal.integrated.scrollback": 100000,
+  "python.terminal.activateEnvironment": false,
   "workbench.colorTheme": "Catppuccin Mocha",
   "workbench.iconTheme": "vscode-icons",
   // Passing --no-cov to pytestArgs is required to respect breakpoints
3 changes: 2 additions & 1 deletion src/pyrovelocity/models/__init__.py
@@ -7,13 +7,14 @@
 from pyrovelocity.models._deterministic_simulation import (
     solve_transcription_splicing_model_analytical,
 )
-from pyrovelocity.models._transcription_dynamics import mrna_dynamics
+from pyrovelocity.models._transcription_dynamics import mrna_dynamics, atac_mrna_dynamics
 from pyrovelocity.models._velocity import PyroVelocity


 __all__ = [
     deterministic_transcription_splicing_probabilistic_model,
     mrna_dynamics,
+    atac_mrna_dynamics,
     PyroVelocity,
     solve_transcription_splicing_model,
     solve_transcription_splicing_model_analytical,
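The new `atac_mrna_dynamics` export corresponds to commit a857453 ("Added function for multiome dynamics"). As rough orientation only, the toy model below sketches the kind of coupled chromatin/mRNA kinetics such a function typically describes; the function name, signature, and Euler integration are illustrative assumptions and do not reproduce the actual pyrovelocity implementation.

```python
import torch

def toy_atac_mrna_dynamics(tau, c0, u0, s0, k_c, alpha_c, alpha, beta, gamma):
    """Illustrative multiome kinetics (hypothetical, not pyrovelocity's API).

    Chromatin accessibility c relaxes toward alpha_c at rate k_c; transcription
    of unspliced mRNA u is gated by c; u splices into s, which degrades at gamma.
    """
    dt = 0.01
    c, u, s = c0, u0, s0
    for _ in range(int(tau / dt)):
        dc = k_c * (alpha_c - c)   # chromatin opening/closing
        du = alpha * c - beta * u  # transcription gated by accessibility, minus splicing
        ds = beta * u - gamma * s  # splicing in, degradation out
        c, u, s = c + dt * dc, u + dt * du, s + dt * ds
    return c, u, s

# Example: chromatin opens toward full accessibility while u and s accumulate.
c, u, s = toy_atac_mrna_dynamics(
    tau=2.0,
    c0=torch.tensor(0.1), u0=torch.tensor(0.0), s0=torch.tensor(0.0),
    k_c=torch.tensor(1.0), alpha_c=torch.tensor(1.0),
    alpha=torch.tensor(2.0), beta=torch.tensor(1.0), gamma=torch.tensor(0.5),
)
```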
163 changes: 115 additions & 48 deletions src/pyrovelocity/models/_trainer.py
@@ -271,8 +271,8 @@ def train_faster(
            if scipy.sparse.issparse(self.adata.layers["raw_spliced"])
            else self.adata.layers["raw_spliced"],
            dtype=torch.float32,
        ).to(device)

        epsilon = 1e-6

        log_u_library_size = np.log(
Expand Down Expand Up @@ -335,60 +335,127 @@ def train_faster(

losses = []
patience = patient_init
for step in range(max_epochs):
if cell_state is None:
elbos = (
svi.step(
u,
s,
u_library.reshape(-1, 1),
s_library.reshape(-1, 1),
u_library_mean.reshape(-1, 1),
s_library_mean.reshape(-1, 1),
u_library_scale.reshape(-1, 1),
s_library_scale.reshape(-1, 1),
None,
None,

if not self.adata.uns['atac']:

for step in range(max_epochs):
if cell_state is None:
elbos = (
svi.step(
u,
s,
u_library.reshape(-1, 1),
s_library.reshape(-1, 1),
u_library_mean.reshape(-1, 1),
s_library_mean.reshape(-1, 1),
u_library_scale.reshape(-1, 1),
s_library_scale.reshape(-1, 1),
None,
None,
)
/ normalizer
)
/ normalizer
)
else:
elbos = (
svi.step(
u,
s,
u_library.reshape(-1, 1),
s_library.reshape(-1, 1),
u_library_mean.reshape(-1, 1),
s_library_mean.reshape(-1, 1),
u_library_scale.reshape(-1, 1),
s_library_scale.reshape(-1, 1),
None,
cell_state.reshape(-1, 1),
else:
elbos = (
svi.step(
u,
s,
u_library.reshape(-1, 1),
s_library.reshape(-1, 1),
u_library_mean.reshape(-1, 1),
s_library_mean.reshape(-1, 1),
u_library_scale.reshape(-1, 1),
s_library_scale.reshape(-1, 1),
None,
cell_state.reshape(-1, 1),
)
/ normalizer
)
if (step == 0) or (
((step + 1) % log_every == 0) and ((step + 1) < max_epochs)
):
mlflow.log_metric("-ELBO", -elbos, step=step + 1)
logger.info(
f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}"
)
if step > log_every:
if (losses[-1] - elbos) < losses[-1] * patient_improve:
patience -= 1
else:
patience = patient_init
if patience <= 0:
break
losses.append(elbos)

else:

c = torch.tensor(
np.array(
self.adata.layers["atac"].toarray(), dtype="float32"
)
if scipy.sparse.issparse(self.adata.layers["atac"])
else self.adata.layers["atac"],
dtype=torch.float32,
).to(device)


for step in range(max_epochs):
if cell_state is None:
elbos = (
svi.step(
c,
u,
s,
u_library.reshape(-1, 1),
s_library.reshape(-1, 1),
u_library_mean.reshape(-1, 1),
s_library_mean.reshape(-1, 1),
u_library_scale.reshape(-1, 1),
s_library_scale.reshape(-1, 1),
None,
None,
)
/ normalizer
)
/ normalizer
)
if (step == 0) or (
((step + 1) % log_every == 0) and ((step + 1) < max_epochs)
):
mlflow.log_metric("-ELBO", -elbos, step=step + 1)
logger.info(
f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}"
)
if step > log_every:
if (losses[-1] - elbos) < losses[-1] * patient_improve:
patience -= 1
else:
patience = patient_init
if patience <= 0:
break
losses.append(elbos)
elbos = (
svi.step(
c,
u,
s,
u_library.reshape(-1, 1),
s_library.reshape(-1, 1),
u_library_mean.reshape(-1, 1),
s_library_mean.reshape(-1, 1),
u_library_scale.reshape(-1, 1),
s_library_scale.reshape(-1, 1),
None,
cell_state.reshape(-1, 1),
)
/ normalizer
)
if (step == 0) or (
((step + 1) % log_every == 0) and ((step + 1) < max_epochs)
):
mlflow.log_metric("-ELBO", -elbos, step=step + 1)
logger.info(
f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}"
)
if step > log_every:
if (losses[-1] - elbos) < losses[-1] * patient_improve:
patience -= 1
else:
patience = patient_init
if patience <= 0:
break
losses.append(elbos)

mlflow.log_metric("-ELBO", -elbos, step=step + 1)
mlflow.log_metric("real_epochs", step + 1)
logger.info(
f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}"
)
return losses
return losses

def train_faster_with_batch(
self,
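The train_faster hunk above repeats the same sparse-to-dense conversion for the raw_spliced, raw_unspliced, and atac layers before moving them to the device. The helper below is a self-contained sketch of that pattern; the name `layer_to_tensor` is hypothetical and not part of pyrovelocity.

```python
import numpy as np
import scipy.sparse
import torch

def layer_to_tensor(layer, device="cpu"):
    """Convert an AnnData layer (sparse or dense) to a float32 tensor on device,
    mirroring the conversion pattern used in train_faster above."""
    dense = layer.toarray() if scipy.sparse.issparse(layer) else np.asarray(layer)
    return torch.tensor(dense, dtype=torch.float32).to(device)

# Example with a small sparse matrix standing in for adata.layers["atac"]:
atac_layer = scipy.sparse.random(5, 3, density=0.4, format="csr", dtype=np.float32)
c = layer_to_tensor(atac_layer)
```

Factoring the conversion out this way would also make it easier to collapse the duplicated RNA-only and multiome epoch loops into a single loop that prepends `c` to the `svi.step` arguments only when ATAC data is present.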