Skip to content

Commit

Permalink
Passing callbacks example (#68)
Browse files Browse the repository at this point in the history
* initial draft

* sloving with different callbacks

* callback example

---------

Co-authored-by: Arina Danilina <[email protected]>
  • Loading branch information
ArinaDanilina and Arina Danilina authored Mar 26, 2024
1 parent 47ec2fb commit 0dbb64a
Showing 1 changed file with 374 additions and 0 deletions.
374 changes: 374 additions & 0 deletions examples/problems/1100_callback.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,374 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Passing callbacks in {meth}`~moscot.problems.spatial.AlignmentProblem.prepare`\n",
"\n",
"In this example, we show how to use different callbacks.\n",
"\n",
"The `callback` argument that states which computation should be run on {attr}`~anndata.AnnData.X` to get the joint cost.\n",
"\n",
":::{seealso}\n",
"- See {doc}`200_custom_cost_matrices` for an on how to use custom matrices and pass `joint_attr`, `x_attr` and `y_attr` in the {meth}`~moscot.problems.generic.FGWProblem.prepare` method.\n",
":::"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports and data loading"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"warnings.simplefilter(action=\"ignore\", category=FutureWarning)\n",
"\n",
"from moscot import datasets\n",
"from moscot.problems.space import AlignmentProblem"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AnnData object with n_obs × n_vars = 1200 × 500\n",
" obs: 'batch'\n",
" uns: 'batch_colors'\n",
" obsm: 'spatial'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"adata = datasets.sim_align()\n",
"adata"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When `normalize_spatial=True` is passed, the spatial coordinates are normalized by standardizing them."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Normalizing spatial coordinates of `x`. \n",
"\u001b[34mINFO \u001b[0m Normalizing spatial coordinates of `y`. \n",
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n",
"\u001b[34mINFO \u001b[0m Normalizing spatial coordinates of `x`. \n",
"\u001b[34mINFO \u001b[0m Normalizing spatial coordinates of `y`. \n"
]
}
],
"source": [
"ap = AlignmentProblem(adata=adata)\n",
"ap = ap.prepare(batch_key=\"batch\", policy=\"sequential\", normalize_spatial=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ap[(\"1\", \"2\")].x.data_src.std() == 1"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n"
]
}
],
"source": [
"ap = ap.prepare(batch_key=\"batch\", policy=\"sequential\", normalize_spatial=False)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ap[(\"1\", \"2\")].x.data_src.std() == 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can pass `callback=\"local-pca\"` to run on {attr}`~anndata.AnnData.X` to get the joint cost."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n",
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n"
]
}
],
"source": [
"ap = ap.prepare(\n",
" batch_key=\"batch\",\n",
" policy=\"sequential\",\n",
" normalize_spatial=False,\n",
" callback=\"local-pca\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m2\u001b[0m` problems \n",
"\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m400\u001b[0m, \u001b[1;36m400\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m400\u001b[0m, \u001b[1;36m400\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n"
]
},
{
"data": {
"text/plain": [
"{('0', '1'): OTTOutput[shape=(400, 400), cost=1.0932, converged=True],\n",
" ('1', '2'): OTTOutput[shape=(400, 400), cost=1.1171, converged=True]}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ap.solve()\n",
"ap.solutions"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n",
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n"
]
}
],
"source": [
"ap = ap.prepare(\n",
" batch_key=\"batch\",\n",
" policy=\"sequential\",\n",
" normalize_spatial=False,\n",
" callback=\"spatial-norm\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m2\u001b[0m` problems \n",
"\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m400\u001b[0m, \u001b[1;36m400\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n",
"\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m400\u001b[0m, \u001b[1;36m400\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n"
]
},
{
"data": {
"text/plain": [
"{('0', '1'): OTTOutput[shape=(400, 400), cost=1.0932, converged=True],\n",
" ('1', '2'): OTTOutput[shape=(400, 400), cost=1.1171, converged=True]}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ap.solve()\n",
"ap.solutions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use graphs, we can either pass `callback=\"graph-construction\"`"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n",
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n"
]
}
],
"source": [
"ap = ap.prepare(\n",
" batch_key=\"batch\",\n",
" policy=\"sequential\",\n",
" normalize_spatial=False,\n",
" callback=\"graph-construction\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m2\u001b[0m` problems \n",
"\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m400\u001b[0m, \u001b[1;36m400\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n",
"\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m400\u001b[0m, \u001b[1;36m400\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n"
]
},
{
"data": {
"text/plain": [
"{('0', '1'): OTTOutput[shape=(400, 400), cost=1.0932, converged=True],\n",
" ('1', '2'): OTTOutput[shape=(400, 400), cost=1.1171, converged=True]}"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ap.solve()\n",
"ap.solutions"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "moscot",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 0dbb64a

Please sign in to comment.