From 1e3a531c6e2387230f26fdf452fa108bc03d94d6 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 10 Dec 2024 23:26:27 +0100 Subject: [PATCH] initializer update --- .../200_linear_problems_advanced.ipynb | 20 ++++++++++++------- tutorials/500_spatiotemporal.ipynb | 2 +- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/examples/solvers/200_linear_problems_advanced.ipynb b/examples/solvers/200_linear_problems_advanced.ipynb index cc40b28..732c928 100644 --- a/examples/solvers/200_linear_problems_advanced.ipynb +++ b/examples/solvers/200_linear_problems_advanced.ipynb @@ -34,7 +34,8 @@ "outputs": [], "source": [ "from moscot import datasets\n", - "from moscot.problems.generic import SinkhornProblem" + "from moscot.problems.generic import SinkhornProblem\n", + "from ott.initializers.linear import initializers_lr as init_lr_lib" ] }, { @@ -105,6 +106,7 @@ "output_type": "stream", "text": [ "\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n", + "\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m1\u001b[0m` problems \n", "\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n" ] } @@ -114,7 +116,7 @@ "sp = sp.prepare(key=\"day\")\n", "\n", "ik = {\"min_iterations\": 5, \"max_iterations\": 200}\n", - "sp = sp.solve(epsilon=0, rank=3, initializer=\"k-means\", initializer_kwargs=ik)" + "sp = sp.solve(epsilon=0, rank=3, initializer=init_lr_lib.KMeansInitializer(rank=3), initializer_kwargs=ik)" ] }, { @@ -145,6 +147,7 @@ "name": "stdout", "output_type": "stream", "text": [ + "\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m1\u001b[0m` problems \n", "\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'solved'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n", "\u001b[31mWARNING \u001b[0m Solver did not converge \n" ] @@ -179,13 +182,14 @@ "name": "stdout", "output_type": "stream", "text": [ + "\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m1\u001b[0m` problems \n", "\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'solved'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n", "\u001b[31mWARNING \u001b[0m Solver did not converge \n" ] } ], "source": [ - "sp = sp.solve(epsilon=0, rank=3, initializer=\"random\", max_iterations=30, gamma=1000)" + "sp = sp.solve(epsilon=0, rank=3, initializer=init_lr_lib.RandomInitializer(rank=3), max_iterations=30, gamma=1000)" ] }, { @@ -198,12 +202,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'solved'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n" + "\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m1\u001b[0m` problems \n", + "\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'solved'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n", + "\u001b[31mWARNING \u001b[0m Solver did not converge \n" ] } ], "source": [ - "sp = sp.solve(epsilon=0, rank=3, initializer=\"random\", max_iterations=30, gamma=10)" + "sp = sp.solve(epsilon=0, rank=3, initializer=init_lr_lib.RandomInitializer(rank=3), max_iterations=30, gamma=10)" ] }, { @@ -233,7 +239,7 @@ "kernelspec": { "display_name": "moscot", "language": "python", - "name": "moscot" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -245,7 +251,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/tutorials/500_spatiotemporal.ipynb b/tutorials/500_spatiotemporal.ipynb index cee9e34..797838c 100644 --- a/tutorials/500_spatiotemporal.ipynb +++ b/tutorials/500_spatiotemporal.ipynb @@ -276,7 +276,7 @@ "source": [ "As the data is large, we use low-rank optimal transport to decrease the computational complexity {cite}`scetbon:21a`. \n", "\n", - "To solve the problem we pass the following arguments, `alpha` (in $(0, 1]$), which defines the influence of the spatial coordinates as opposed to the single-cell data. `epsilon`, the entropy parameter, and `initializer='rank2'` to improve the speed of convergence.\n", + "To solve the problem we pass the following arguments, `alpha` (in $(0, 1]$), which defines the influence of the spatial coordinates as opposed to the single-cell data and `epsilon`, the entropy parameter to improve the speed of convergence.\n", "\n", ":::{seealso}\n", "- See {doc}`../examples/solvers/200_linear_problems_advanced` on how to modify low-rank parameters. \n",