Skip to content

Commit

Permalink
fgw instead of gw
Browse files Browse the repository at this point in the history
  • Loading branch information
Arina Danilina committed Mar 14, 2024
1 parent b89576b commit c1b9957
Showing 1 changed file with 35 additions and 64 deletions.
99 changes: 35 additions & 64 deletions examples/problems/200_custom_cost_matrices.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
"warnings.simplefilter(\"ignore\", FutureWarning)\n",
"\n",
"from moscot import datasets\n",
"from moscot.problems.generic import GWProblem\n",
"from moscot.problems.generic import FGWProblem\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
Expand Down Expand Up @@ -110,22 +110,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n",
" \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n",
" \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n",
" \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n",
" \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m, \u001b[32m'0-2'\u001b[0m, \u001b[32m'1-2'\u001b[0m, \u001b[32m'2-2'\u001b[0m, \u001b[32m'3-2'\u001b[0m, \u001b[32m'4-2'\u001b[0m, \u001b[32m'5-2'\u001b[0m, \u001b[32m'6-2'\u001b[0m, \u001b[32m'7-2'\u001b[0m, \n",
" \u001b[32m'8-2'\u001b[0m, \u001b[32m'9-2'\u001b[0m, \u001b[32m'10-2'\u001b[0m, \u001b[32m'11-2'\u001b[0m, \u001b[32m'12-2'\u001b[0m, \u001b[32m'13-2'\u001b[0m, \u001b[32m'14-2'\u001b[0m, \u001b[32m'15-2'\u001b[0m, \u001b[32m'16-2'\u001b[0m, \n",
" \u001b[32m'17-2'\u001b[0m, \u001b[32m'18-2'\u001b[0m, \u001b[32m'19-2'\u001b[0m\u001b[1m]\u001b[0m, \n",
" \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n",
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n",
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n"
]
},
{
"data": {
"text/plain": [
"GWProblem[('1', '2'), ('0', '1')]"
"FGWProblem[('1', '2'), ('0', '1')]"
]
},
"execution_count": 3,
Expand All @@ -134,9 +126,9 @@
}
],
"source": [
"gwp = GWProblem(adata)\n",
"gwp = gwp.prepare(key=\"batch\", x_attr=\"X_pca\", y_attr=\"X_pca\")\n",
"gwp"
"fgw = FGWProblem(adata)\n",
"fgw = fgw.prepare(key=\"batch\", x_attr=\"X_pca\", y_attr=\"X_pca\")\n",
"fgw"
]
},
{
Expand Down Expand Up @@ -168,8 +160,8 @@
"outputs": [],
"source": [
"rng = np.random.default_rng(seed=42)\n",
"obs_names_0 = gwp[\"0\", \"1\"].adata_src.obs_names\n",
"obs_names_1 = gwp[\"0\", \"1\"].adata_tgt.obs_names\n",
"obs_names_0 = fgw[\"0\", \"1\"].adata_src.obs_names\n",
"obs_names_1 = fgw[\"0\", \"1\"].adata_tgt.obs_names\n",
"\n",
"cost_linear_01 = np.abs(rng.normal(size=(len(obs_names_0), len(obs_names_1))))\n",
"cost_quad_0 = np.abs(rng.normal(size=(len(obs_names_0), len(obs_names_0))))\n",
Expand Down Expand Up @@ -198,9 +190,9 @@
"metadata": {},
"outputs": [],
"source": [
"gwp[\"0\", \"1\"].set_xy(cm_linear, tag=\"cost_matrix\")\n",
"gwp[\"0\", \"1\"].set_x(cm_quad_0, tag=\"cost_matrix\")\n",
"gwp[\"0\", \"1\"].set_y(cm_quad_1, tag=\"cost_matrix\")"
"fgw[\"0\", \"1\"].set_xy(cm_linear, tag=\"cost_matrix\")\n",
"fgw[\"0\", \"1\"].set_x(cm_quad_0, tag=\"cost_matrix\")\n",
"fgw[\"0\", \"1\"].set_y(cm_quad_1, tag=\"cost_matrix\")"
]
},
{
Expand All @@ -211,7 +203,7 @@
"source": [
"When solving the problem, the custom cost matrices will be used for the\n",
"problem mapping from batch `'0'` to batch `'1'`, while the problem mapping from batch `'1'` to\n",
"batch `'2'` is still using the information passed in {meth}`~moscot.problems.generic.GWProblem.prepare`."
"batch `'2'` is still using the information passed in {meth}`~moscot.problems.generic.FGWProblem.prepare`."
]
},
{
Expand Down Expand Up @@ -253,7 +245,7 @@
}
],
"source": [
"obs_names_2 = gwp[\"1\", \"2\"].adata_tgt.obs_names\n",
"obs_names_2 = fgw[\"1\", \"2\"].adata_tgt.obs_names\n",
"\n",
"cost_linear_12 = np.abs(rng.normal(size=(len(obs_names_1), len(obs_names_2))))\n",
"cost_quad_2 = np.abs(rng.normal(size=(len(obs_names_2), len(obs_names_2))))\n",
Expand Down Expand Up @@ -286,7 +278,7 @@
"metadata": {},
"source": [
"We need to specify where to fetch the custom cost matrices in the\n",
"{meth}`~moscot.problems.generic.GWProblem.prepare` methods. If we want to only\n",
"{meth}`~moscot.problems.generic.FGWProblem.prepare` methods. If we want to only\n",
"use the linear custom cost matrix, we need to modify the `joint_attr` as follows:"
]
},
Expand All @@ -295,25 +287,10 @@
"execution_count": 7,
"id": "4b99031c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n",
" \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n",
" \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n",
" \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n",
" \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m, \u001b[32m'0-2'\u001b[0m, \u001b[32m'1-2'\u001b[0m, \u001b[32m'2-2'\u001b[0m, \u001b[32m'3-2'\u001b[0m, \u001b[32m'4-2'\u001b[0m, \u001b[32m'5-2'\u001b[0m, \u001b[32m'6-2'\u001b[0m, \u001b[32m'7-2'\u001b[0m, \n",
" \u001b[32m'8-2'\u001b[0m, \u001b[32m'9-2'\u001b[0m, \u001b[32m'10-2'\u001b[0m, \u001b[32m'11-2'\u001b[0m, \u001b[32m'12-2'\u001b[0m, \u001b[32m'13-2'\u001b[0m, \u001b[32m'14-2'\u001b[0m, \u001b[32m'15-2'\u001b[0m, \u001b[32m'16-2'\u001b[0m, \n",
" \u001b[32m'17-2'\u001b[0m, \u001b[32m'18-2'\u001b[0m, \u001b[32m'19-2'\u001b[0m\u001b[1m]\u001b[0m, \n",
" \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n"
]
}
],
"outputs": [],
"source": [
"joint_attr = {\"key\": \"cost_matrices\", \"tag\": \"cost_matrix\"}\n",
"gwp = gwp.prepare(key=\"batch\", joint_attr=joint_attr, x_attr=\"X_pca\", y_attr=\"X_pca\")"
"fgw = fgw.prepare(key=\"batch\", joint_attr=joint_attr, x_attr=\"X_pca\", y_attr=\"X_pca\")"
]
},
{
Expand All @@ -332,18 +309,14 @@
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n",
" \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n",
" \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n",
" \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n",
" \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m, \u001b[32m'0-2'\u001b[0m, \u001b[32m'1-2'\u001b[0m, \u001b[32m'2-2'\u001b[0m, \u001b[32m'3-2'\u001b[0m, \u001b[32m'4-2'\u001b[0m, \u001b[32m'5-2'\u001b[0m, \u001b[32m'6-2'\u001b[0m, \u001b[32m'7-2'\u001b[0m, \n",
" \u001b[32m'8-2'\u001b[0m, \u001b[32m'9-2'\u001b[0m, \u001b[32m'10-2'\u001b[0m, \u001b[32m'11-2'\u001b[0m, \u001b[32m'12-2'\u001b[0m, \u001b[32m'13-2'\u001b[0m, \u001b[32m'14-2'\u001b[0m, \u001b[32m'15-2'\u001b[0m, \u001b[32m'16-2'\u001b[0m, \n",
" \u001b[32m'17-2'\u001b[0m, \u001b[32m'18-2'\u001b[0m, \u001b[32m'19-2'\u001b[0m\u001b[1m]\u001b[0m, \n",
" \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n"
]
"data": {
"text/plain": [
"OTProblem[stage='prepared', shape=(20, 20)]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
Expand All @@ -359,7 +332,8 @@
" \"tag\": \"cost_matrix\",\n",
" \"cost\": \"custom\",\n",
"}\n",
"gwp = gwp.prepare(key=\"batch\", joint_attr=\"X_pca\", x_attr=x_attr, y_attr=y_attr)"
"fgw = fgw.prepare(key=\"batch\", joint_attr=\"X_pca\", x_attr=x_attr, y_attr=y_attr)\n",
"fgw[(\"0\", \"1\")]"
]
},
{
Expand All @@ -378,22 +352,19 @@
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n",
" \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n",
" \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n",
" \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n",
" \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m, \u001b[32m'0-2'\u001b[0m, \u001b[32m'1-2'\u001b[0m, \u001b[32m'2-2'\u001b[0m, \u001b[32m'3-2'\u001b[0m, \u001b[32m'4-2'\u001b[0m, \u001b[32m'5-2'\u001b[0m, \u001b[32m'6-2'\u001b[0m, \u001b[32m'7-2'\u001b[0m, \n",
" \u001b[32m'8-2'\u001b[0m, \u001b[32m'9-2'\u001b[0m, \u001b[32m'10-2'\u001b[0m, \u001b[32m'11-2'\u001b[0m, \u001b[32m'12-2'\u001b[0m, \u001b[32m'13-2'\u001b[0m, \u001b[32m'14-2'\u001b[0m, \u001b[32m'15-2'\u001b[0m, \u001b[32m'16-2'\u001b[0m, \n",
" \u001b[32m'17-2'\u001b[0m, \u001b[32m'18-2'\u001b[0m, \u001b[32m'19-2'\u001b[0m\u001b[1m]\u001b[0m, \n",
" \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n"
]
"data": {
"text/plain": [
"OTProblem[stage='prepared', shape=(20, 20)]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gwp = gwp.prepare(key=\"batch\", joint_attr=joint_attr, x_attr=x_attr, y_attr=y_attr)"
"fgw = fgw.prepare(key=\"batch\", joint_attr=joint_attr, x_attr=x_attr, y_attr=y_attr)\n",
"fgw[(\"1\", \"2\")]"
]
}
],
Expand Down

0 comments on commit c1b9957

Please sign in to comment.