diff --git a/examples/problems/100_tagged_arrays.ipynb b/examples/problems/100_tagged_arrays.ipynb index 15eec8e..8fd3879 100644 --- a/examples/problems/100_tagged_arrays.ipynb +++ b/examples/problems/100_tagged_arrays.ipynb @@ -36,6 +36,10 @@ "metadata": {}, "outputs": [], "source": [ + "import warnings\n", + "\n", + "warnings.simplefilter(\"ignore\", FutureWarning)\n", + "\n", "from moscot import datasets\n", "from moscot.problems.generic import GWProblem\n", "\n", @@ -58,7 +62,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "dec86f59", "metadata": {}, "outputs": [ @@ -72,7 +76,7 @@ " varm: 'PCs'" ] }, - "execution_count": 3, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -96,13 +100,26 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "16f0c3a9", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n", + " \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n", + " \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n", + " \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n", + " \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m\u001b[1m]\u001b[0m, \n", + " \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n" + ] + } + ], "source": [ "gwp = GWProblem(adata)\n", - "gwp = gwp.prepare(key=\"batch\", joint_attr=\"X_pca\", GW_x=\"X_pca\", GW_y=\"X_pca\")" + "gwp = gwp.prepare(key=\"batch\", x_attr=\"X_pca\", y_attr=\"X_pca\", joint_attr=\"X_pca\")" ] }, { @@ -116,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "e4ca2aea", "metadata": {}, "outputs": [ @@ -137,10 +154,10 @@ " [ 2.024, -1.597, -0.591, ..., -0.114, 0.29 , -0.332],\n", " [-0.938, 2.426, -0.128, ..., 0.194, 0.829, -0.438],\n", " [ 2.709, -2.885, -0.925, ..., 0.315, 0.334, 0.078]],\n", - " dtype=float32), tag='point_cloud', cost=)" + " dtype=float32), tag=, cost=)" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -162,17 +179,18 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "7b70c639", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "('point_cloud', )" + "(,\n", + " )" ] }, - "execution_count": 6, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -198,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "e50ce79b", "metadata": {}, "outputs": [ @@ -223,7 +241,7 @@ " dtype=float32))" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -245,7 +263,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "e3d67c03", "metadata": {}, "outputs": [ @@ -263,7 +281,7 @@ " None)" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -293,14 +311,14 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "9b1a71ba", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "('point_cloud',\n", + "(,\n", " ArrayView([[-1.076, -3.572, -0.315, ..., 0.492, -0.226, 0.194],\n", " [ 0.852, -1.033, 3.834, ..., 0.252, -0.115, 0.243],\n", " [-0.411, -2.689, -1.863, ..., -0.347, 0.601, 0.005],\n", @@ -311,7 +329,7 @@ " dtype=float32))" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -331,17 +349,17 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "id": "a723c855", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "('cost_matrix', None)" + "(, None)" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -369,7 +387,7 @@ "kernelspec": { "display_name": "moscot", "language": "python", - "name": "moscot" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -381,7 +399,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/examples/problems/200_custom_cost_matrices.ipynb b/examples/problems/200_custom_cost_matrices.ipynb index 6f2a751..3f648bf 100644 --- a/examples/problems/200_custom_cost_matrices.ipynb +++ b/examples/problems/200_custom_cost_matrices.ipynb @@ -39,6 +39,10 @@ "metadata": {}, "outputs": [], "source": [ + "import warnings\n", + "\n", + "warnings.simplefilter(\"ignore\", FutureWarning)\n", + "\n", "from moscot import datasets\n", "from moscot.problems.generic import GWProblem\n", "\n", @@ -102,6 +106,22 @@ "id": "36ef5120", "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n", + " \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n", + " \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n", + " \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n", + " \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m, \u001b[32m'0-2'\u001b[0m, \u001b[32m'1-2'\u001b[0m, \u001b[32m'2-2'\u001b[0m, \u001b[32m'3-2'\u001b[0m, \u001b[32m'4-2'\u001b[0m, \u001b[32m'5-2'\u001b[0m, \u001b[32m'6-2'\u001b[0m, \u001b[32m'7-2'\u001b[0m, \n", + " \u001b[32m'8-2'\u001b[0m, \u001b[32m'9-2'\u001b[0m, \u001b[32m'10-2'\u001b[0m, \u001b[32m'11-2'\u001b[0m, \u001b[32m'12-2'\u001b[0m, \u001b[32m'13-2'\u001b[0m, \u001b[32m'14-2'\u001b[0m, \u001b[32m'15-2'\u001b[0m, \u001b[32m'16-2'\u001b[0m, \n", + " \u001b[32m'17-2'\u001b[0m, \u001b[32m'18-2'\u001b[0m, \u001b[32m'19-2'\u001b[0m\u001b[1m]\u001b[0m, \n", + " \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n", + "\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n", + "\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n" + ] + }, { "data": { "text/plain": [ @@ -115,7 +135,7 @@ ], "source": [ "gwp = GWProblem(adata)\n", - "gwp = gwp.prepare(key=\"batch\", joint_attr=\"X_pca\", GW_x=\"X_pca\", GW_y=\"X_pca\")\n", + "gwp = gwp.prepare(key=\"batch\", x_attr=\"X_pca\", y_attr=\"X_pca\")\n", "gwp" ] }, @@ -275,10 +295,25 @@ "execution_count": 7, "id": "4b99031c", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n", + " \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n", + " \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n", + " \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n", + " \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m, \u001b[32m'0-2'\u001b[0m, \u001b[32m'1-2'\u001b[0m, \u001b[32m'2-2'\u001b[0m, \u001b[32m'3-2'\u001b[0m, \u001b[32m'4-2'\u001b[0m, \u001b[32m'5-2'\u001b[0m, \u001b[32m'6-2'\u001b[0m, \u001b[32m'7-2'\u001b[0m, \n", + " \u001b[32m'8-2'\u001b[0m, \u001b[32m'9-2'\u001b[0m, \u001b[32m'10-2'\u001b[0m, \u001b[32m'11-2'\u001b[0m, \u001b[32m'12-2'\u001b[0m, \u001b[32m'13-2'\u001b[0m, \u001b[32m'14-2'\u001b[0m, \u001b[32m'15-2'\u001b[0m, \u001b[32m'16-2'\u001b[0m, \n", + " \u001b[32m'17-2'\u001b[0m, \u001b[32m'18-2'\u001b[0m, \u001b[32m'19-2'\u001b[0m\u001b[1m]\u001b[0m, \n", + " \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n" + ] + } + ], "source": [ "joint_attr = {\"key\": \"cost_matrices\", \"tag\": \"cost_matrix\"}\n", - "gwp = gwp.prepare(key=\"batch\", joint_attr=joint_attr, GW_x=\"X_pca\", GW_y=\"X_pca\")" + "gwp = gwp.prepare(key=\"batch\", joint_attr=joint_attr, x_attr=\"X_pca\", y_attr=\"X_pca\")" ] }, { @@ -287,7 +322,7 @@ "id": "b3ed4f02", "metadata": {}, "source": [ - "If we want to use only quadratic custom cost matrices, we need to modify `GW_x` and `GW_y`." + "If we want to use only quadratic custom cost matrices, we need to modify `x_attr` and `y_attr`." ] }, { @@ -295,11 +330,36 @@ "execution_count": 8, "id": "7c6f8eb6", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n", + " \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n", + " \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n", + " \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n", + " \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m, \u001b[32m'0-2'\u001b[0m, \u001b[32m'1-2'\u001b[0m, \u001b[32m'2-2'\u001b[0m, \u001b[32m'3-2'\u001b[0m, \u001b[32m'4-2'\u001b[0m, \u001b[32m'5-2'\u001b[0m, \u001b[32m'6-2'\u001b[0m, \u001b[32m'7-2'\u001b[0m, \n", + " \u001b[32m'8-2'\u001b[0m, \u001b[32m'9-2'\u001b[0m, \u001b[32m'10-2'\u001b[0m, \u001b[32m'11-2'\u001b[0m, \u001b[32m'12-2'\u001b[0m, \u001b[32m'13-2'\u001b[0m, \u001b[32m'14-2'\u001b[0m, \u001b[32m'15-2'\u001b[0m, \u001b[32m'16-2'\u001b[0m, \n", + " \u001b[32m'17-2'\u001b[0m, \u001b[32m'18-2'\u001b[0m, \u001b[32m'19-2'\u001b[0m\u001b[1m]\u001b[0m, \n", + " \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n" + ] + } + ], "source": [ - "GW_x = {\"attr\": \"obsp\", \"key\": \"cost_matrices\", \"tag\": \"cost_matrix\", \"cost\": \"custom\"}\n", - "GW_y = {\"attr\": \"obsp\", \"key\": \"cost_matrices\", \"tag\": \"cost_matrix\", \"cost\": \"custom\"}\n", - "gwp = gwp.prepare(key=\"batch\", joint_attr=\"X_pca\", GW_x=GW_x, GW_y=GW_y)" + "x_attr = {\n", + " \"attr\": \"obsp\",\n", + " \"key\": \"cost_matrices\",\n", + " \"tag\": \"cost_matrix\",\n", + " \"cost\": \"custom\",\n", + "}\n", + "y_attr = {\n", + " \"attr\": \"obsp\",\n", + " \"key\": \"cost_matrices\",\n", + " \"tag\": \"cost_matrix\",\n", + " \"cost\": \"custom\",\n", + "}\n", + "gwp = gwp.prepare(key=\"batch\", joint_attr=\"X_pca\", x_attr=x_attr, y_attr=y_attr)" ] }, { @@ -316,9 +376,24 @@ "execution_count": 9, "id": "cbb8b363", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n", + " \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n", + " \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n", + " \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n", + " \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m, \u001b[32m'0-2'\u001b[0m, \u001b[32m'1-2'\u001b[0m, \u001b[32m'2-2'\u001b[0m, \u001b[32m'3-2'\u001b[0m, \u001b[32m'4-2'\u001b[0m, \u001b[32m'5-2'\u001b[0m, \u001b[32m'6-2'\u001b[0m, \u001b[32m'7-2'\u001b[0m, \n", + " \u001b[32m'8-2'\u001b[0m, \u001b[32m'9-2'\u001b[0m, \u001b[32m'10-2'\u001b[0m, \u001b[32m'11-2'\u001b[0m, \u001b[32m'12-2'\u001b[0m, \u001b[32m'13-2'\u001b[0m, \u001b[32m'14-2'\u001b[0m, \u001b[32m'15-2'\u001b[0m, \u001b[32m'16-2'\u001b[0m, \n", + " \u001b[32m'17-2'\u001b[0m, \u001b[32m'18-2'\u001b[0m, \u001b[32m'19-2'\u001b[0m\u001b[1m]\u001b[0m, \n", + " \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n" + ] + } + ], "source": [ - "gwp = gwp.prepare(key=\"batch\", joint_attr=joint_attr, GW_x=GW_x, GW_y=GW_y)" + "gwp = gwp.prepare(key=\"batch\", joint_attr=joint_attr, x_attr=x_attr, y_attr=y_attr)" ] } ], @@ -344,7 +419,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.11.5" }, "vscode": { "interpreter": { diff --git a/examples/solvers/300_quad_problems_basic.ipynb b/examples/solvers/300_quad_problems_basic.ipynb index d96b882..6214c74 100644 --- a/examples/solvers/300_quad_problems_basic.ipynb +++ b/examples/solvers/300_quad_problems_basic.ipynb @@ -33,6 +33,10 @@ "metadata": {}, "outputs": [], "source": [ + "import warnings\n", + "\n", + "warnings.simplefilter(\"ignore\", FutureWarning)\n", + "\n", "from moscot import datasets\n", "from moscot.problems.generic import GWProblem\n", "\n", @@ -96,10 +100,37 @@ "name": "stdout", "output_type": "stream", "text": [ + "\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n", + " \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n", + " \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n", + " \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n", + " \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m\u001b[1m]\u001b[0m, \n", + " \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n", "\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n", + "\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m1\u001b[0m` problems \n", + "\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n", + " \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n", + " \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n", + " \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n", + " \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m\u001b[1m]\u001b[0m, \n", + " \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n", + "\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m1\u001b[0m` problems \n", "\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n", - "\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n", - "0.021854\n" + "max difference: 0.021854\n" ] } ], @@ -107,16 +138,16 @@ "gwp = GWProblem(adata)\n", "gwp = gwp.prepare(\n", " key=\"batch\",\n", - " GW_x={\"attr\": \"obsm\", \"key\": \"spatial\"},\n", - " GW_y={\"attr\": \"obsm\", \"key\": \"spatial\"},\n", + " x_attr={\"attr\": \"obsm\", \"key\": \"spatial\"},\n", + " y_attr={\"attr\": \"obsm\", \"key\": \"spatial\"},\n", ")\n", "gwp = gwp.solve(alpha=1.0, epsilon=1e-1)\n", "\n", "fgwp = GWProblem(adata)\n", "fgwp = fgwp.prepare(\n", " key=\"batch\",\n", - " GW_x={\"attr\": \"obsm\", \"key\": \"spatial\"},\n", - " GW_y={\"attr\": \"obsm\", \"key\": \"spatial\"},\n", + " x_attr={\"attr\": \"obsm\", \"key\": \"spatial\"},\n", + " y_attr={\"attr\": \"obsm\", \"key\": \"spatial\"},\n", " joint_attr=\"X_pca\",\n", ")\n", "fgwp = fgwp.solve(epsilon=1e-1, alpha=0.5)\n", @@ -127,7 +158,7 @@ " - fgwp[\"0\", \"1\"].solution.transport_matrix\n", " )\n", ")\n", - "print(f\"{max_difference:.6f}\")" + "print(f\"max difference: {max_difference:.6f}\")" ] }, { @@ -142,7 +173,7 @@ "Whenever the dataset is very large, the computational complexity can be\n", "reduced by setting `rank` to a positive integer {cite}`scetbon:21a`. In this\n", "case, `epsilon` can also be set to $0$, while only the balanced case\n", - "($\\text{tau}_a = \\text{tau_b} = 1$) is supported. Moreover, the data has to be provided\n", + "($\\text{tau}_a = \\text{tau}_b = 1$) is supported. Moreover, the data has to be provided\n", "as point clouds, i.e., no precomputed cost matrix can be passed." ] }, @@ -156,6 +187,7 @@ "name": "stdout", "output_type": "stream", "text": [ + "\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m1\u001b[0m` problems \n", "\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'solved'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n" ] } @@ -185,7 +217,7 @@ "kernelspec": { "display_name": "moscot", "language": "python", - "name": "moscot" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -197,7 +229,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/examples/solvers/400_quad_problems_advanced.ipynb b/examples/solvers/400_quad_problems_advanced.ipynb index de3655f..7c18971 100644 --- a/examples/solvers/400_quad_problems_advanced.ipynb +++ b/examples/solvers/400_quad_problems_advanced.ipynb @@ -33,6 +33,10 @@ "metadata": {}, "outputs": [], "source": [ + "import warnings\n", + "\n", + "warnings.simplefilter(\"ignore\", FutureWarning)\n", + "\n", "from moscot import datasets\n", "from moscot.problems.generic import GWProblem\n", "\n", @@ -84,6 +88,12 @@ "name": "stdout", "output_type": "stream", "text": [ + "\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n", + " \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n", + " \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n", + " \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n", + " \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m\u001b[1m]\u001b[0m, \n", + " \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n", "\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n" ] }, @@ -102,8 +112,8 @@ "gwp = GWProblem(adata)\n", "gwp = gwp.prepare(\n", " key=\"batch\",\n", - " GW_x={\"attr\": \"obsm\", \"key\": \"spatial\"},\n", - " GW_y={\"attr\": \"obsm\", \"key\": \"spatial\"},\n", + " x_attr={\"attr\": \"obsm\", \"key\": \"spatial\"},\n", + " y_attr={\"attr\": \"obsm\", \"key\": \"spatial\"},\n", ")\n", "gwp" ] @@ -159,7 +169,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n", + "\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m1\u001b[0m` problems \n", + "\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "\u001b[31mWARNING \u001b[0m Solver did not converge \n" ] } @@ -195,8 +219,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'solved'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n", - "\u001b[31mWARNING \u001b[0m Solver did not converge \n" + "\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m1\u001b[0m` problems \n", + "\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'solved'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n" ] } ], @@ -245,7 +269,7 @@ "kernelspec": { "display_name": "moscot", "language": "python", - "name": "moscot" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -257,7 +281,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.11.5" } }, "nbformat": 4,