generated from theislab/sc_analysis_template
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial draft * sloving with different callbacks * callback example --------- Co-authored-by: Arina Danilina <[email protected]>
- Loading branch information
1 parent
47ec2fb
commit 0dbb64a
Showing
1 changed file
with
374 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,374 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Passing callbacks in {meth}`~moscot.problems.spatial.AlignmentProblem.prepare`\n", | ||
"\n", | ||
"In this example, we show how to use different callbacks.\n", | ||
"\n", | ||
"The `callback` argument that states which computation should be run on {attr}`~anndata.AnnData.X` to get the joint cost.\n", | ||
"\n", | ||
":::{seealso}\n", | ||
"- See {doc}`200_custom_cost_matrices` for an on how to use custom matrices and pass `joint_attr`, `x_attr` and `y_attr` in the {meth}`~moscot.problems.generic.FGWProblem.prepare` method.\n", | ||
":::" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Imports and data loading" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import warnings\n", | ||
"\n", | ||
"warnings.simplefilter(action=\"ignore\", category=FutureWarning)\n", | ||
"\n", | ||
"from moscot import datasets\n", | ||
"from moscot.problems.space import AlignmentProblem" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"AnnData object with n_obs × n_vars = 1200 × 500\n", | ||
" obs: 'batch'\n", | ||
" uns: 'batch_colors'\n", | ||
" obsm: 'spatial'" | ||
] | ||
}, | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"adata = datasets.sim_align()\n", | ||
"adata" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"When `normalize_spatial=True` is passed, the spatial coordinates are normalized by standardizing them." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[34mINFO \u001b[0m Normalizing spatial coordinates of `x`. \n", | ||
"\u001b[34mINFO \u001b[0m Normalizing spatial coordinates of `y`. \n", | ||
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n", | ||
"\u001b[34mINFO \u001b[0m Normalizing spatial coordinates of `x`. \n", | ||
"\u001b[34mINFO \u001b[0m Normalizing spatial coordinates of `y`. \n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"ap = AlignmentProblem(adata=adata)\n", | ||
"ap = ap.prepare(batch_key=\"batch\", policy=\"sequential\", normalize_spatial=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"True" | ||
] | ||
}, | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"ap[(\"1\", \"2\")].x.data_src.std() == 1" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"ap = ap.prepare(batch_key=\"batch\", policy=\"sequential\", normalize_spatial=False)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"False" | ||
] | ||
}, | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"ap[(\"1\", \"2\")].x.data_src.std() == 1" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We can pass `callback=\"local-pca\"` to run on {attr}`~anndata.AnnData.X` to get the joint cost." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n", | ||
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"ap = ap.prepare(\n", | ||
" batch_key=\"batch\",\n", | ||
" policy=\"sequential\",\n", | ||
" normalize_spatial=False,\n", | ||
" callback=\"local-pca\",\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m2\u001b[0m` problems \n", | ||
"\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m400\u001b[0m, \u001b[1;36m400\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m400\u001b[0m, \u001b[1;36m400\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{('0', '1'): OTTOutput[shape=(400, 400), cost=1.0932, converged=True],\n", | ||
" ('1', '2'): OTTOutput[shape=(400, 400), cost=1.1171, converged=True]}" | ||
] | ||
}, | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"ap.solve()\n", | ||
"ap.solutions" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n", | ||
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"ap = ap.prepare(\n", | ||
" batch_key=\"batch\",\n", | ||
" policy=\"sequential\",\n", | ||
" normalize_spatial=False,\n", | ||
" callback=\"spatial-norm\",\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m2\u001b[0m` problems \n", | ||
"\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m400\u001b[0m, \u001b[1;36m400\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n", | ||
"\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m400\u001b[0m, \u001b[1;36m400\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{('0', '1'): OTTOutput[shape=(400, 400), cost=1.0932, converged=True],\n", | ||
" ('1', '2'): OTTOutput[shape=(400, 400), cost=1.1171, converged=True]}" | ||
] | ||
}, | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"ap.solve()\n", | ||
"ap.solutions" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"To use graphs, we can either pass `callback=\"graph-construction\"`" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n", | ||
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"ap = ap.prepare(\n", | ||
" batch_key=\"batch\",\n", | ||
" policy=\"sequential\",\n", | ||
" normalize_spatial=False,\n", | ||
" callback=\"graph-construction\",\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m2\u001b[0m` problems \n", | ||
"\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m400\u001b[0m, \u001b[1;36m400\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n", | ||
"\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m400\u001b[0m, \u001b[1;36m400\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{('0', '1'): OTTOutput[shape=(400, 400), cost=1.0932, converged=True],\n", | ||
" ('1', '2'): OTTOutput[shape=(400, 400), cost=1.1171, converged=True]}" | ||
] | ||
}, | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"ap.solve()\n", | ||
"ap.solutions" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "moscot", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |