From 6e63862cff5aa145d7425b841b046bf7d8e25d75 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Dec 2024 21:26:03 +0000 Subject: [PATCH 1/2] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.8.2 → v0.8.4](https://github.com/astral-sh/ruff-pre-commit/compare/v0.8.2...v0.8.4) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7f762b1..b8cb1a2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,7 @@ repos: hooks: - id: pyproject-fmt - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.2 + rev: v0.8.4 hooks: - id: ruff types_or: [python, pyi, jupyter] From 55d9144ebb597374f9c9664fbd409ba435e4012f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Dec 2024 21:27:06 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../0_format_Xenium_sdata.ipynb | 21 +- .../1_run_segmentation_free.ipynb | 67 +- .../2_process_cells.ipynb | 58 +- .../3_quantify_exRNA.ipynb | 62 +- .../4_explore_comunication.ipynb | 120 +-- .../6_identify_patterns.ipynb | 41 +- .../7_image_params.ipynb | 123 ++- .../8_exrna_signature.ipynb | 42 +- .../spatialdata_tutorials/run_sainsc.ipynb | 155 ++-- src/troutpy/__init__.py | 6 +- src/troutpy/_utils.py | 10 - src/troutpy/pl/__init__.py | 21 +- src/troutpy/pl/plotting.py | 701 +++++++++++------- src/troutpy/pp/__init__.py | 4 +- src/troutpy/pp/compute.py | 93 +-- src/troutpy/pp/format.py | 59 +- src/troutpy/read/__init__.py | 2 +- src/troutpy/tl/NMF.py | 154 ++-- src/troutpy/tl/__init__.py | 24 +- src/troutpy/tl/estimate_density.py | 41 +- src/troutpy/tl/interactions.py | 93 ++- src/troutpy/tl/quantify_xrna.py | 261 +++---- src/troutpy/tl/segmentation_free.py | 56 +- src/troutpy/tl/source_cell.py | 178 +++-- src/troutpy/tl/target_cell.py | 75 +- 25 files changed, 1356 insertions(+), 1111 deletions(-) diff --git a/notebooks/spatialdata_tutorials/0_format_Xenium_sdata.ipynb b/notebooks/spatialdata_tutorials/0_format_Xenium_sdata.ipynb index 6b83106..9000539 100644 --- a/notebooks/spatialdata_tutorials/0_format_Xenium_sdata.ipynb +++ b/notebooks/spatialdata_tutorials/0_format_Xenium_sdata.ipynb @@ -13,8 +13,8 @@ "metadata": {}, "outputs": [], "source": [ - "import spatialdata_io\n", - "import spatialdata as sd" + "import spatialdata as sd\n", + "import spatialdata_io" ] }, { @@ -39,8 +39,8 @@ } ], "source": [ - "path='/media/sergio/Discovair_final/Xenium_Prime_Mouse_Brain_Coronal_FF_outs'\n", - "sdata=spatialdata_io.xenium(path)" + "path = \"/media/sergio/Discovair_final/Xenium_Prime_Mouse_Brain_Coronal_FF_outs\"\n", + "sdata = spatialdata_io.xenium(path)" ] }, { @@ -66,7 +66,7 @@ } ], "source": [ - "outpath='/media/sergio/Discovair_final/mousebrain_prime.zarr'\n", + "outpath = \"/media/sergio/Discovair_final/mousebrain_prime.zarr\"\n", "sdata.write(outpath)" ] }, @@ -130,7 +130,7 @@ } ], "source": [ - "xenium_path='/media/sergio/Discovair_final/mousebrain_prime.zarr'\n", + "xenium_path = \"/media/sergio/Discovair_final/mousebrain_prime.zarr\"\n", "sdata = sd.read_zarr(xenium_path)\n", "sdata" ] @@ -179,7 +179,7 @@ " axes=[\"x\", \"y\"],\n", " min_coordinate=[17500, 0],\n", " max_coordinate=[35000, 15000],\n", - " target_coordinate_system='global',\n", + " target_coordinate_system=\"global\",\n", ")\n", "\n", "cropped_sdata" @@ -209,8 +209,7 @@ } ], "source": [ - "import spatialdata_plot\n", - "cropped_sdata.pl.render_images(\"morphology_focus\").pl.show( title=\"Morphology image\")" + "cropped_sdata.pl.render_images(\"morphology_focus\").pl.show(title=\"Morphology image\")" ] }, { @@ -232,8 +231,8 @@ } ], "source": [ - "xenium_path_cropped='/media/sergio/Discovair_final/mousebrain_prime_half.zarr'\n", - "cropped_sdata.write(xenium_path_cropped,overwrite=True)" + "xenium_path_cropped = \"/media/sergio/Discovair_final/mousebrain_prime_half.zarr\"\n", + "cropped_sdata.write(xenium_path_cropped, overwrite=True)" ] } ], diff --git a/notebooks/spatialdata_tutorials/1_run_segmentation_free.ipynb b/notebooks/spatialdata_tutorials/1_run_segmentation_free.ipynb index f043ed1..b4b64ea 100644 --- a/notebooks/spatialdata_tutorials/1_run_segmentation_free.ipynb +++ b/notebooks/spatialdata_tutorials/1_run_segmentation_free.ipynb @@ -6,13 +6,12 @@ "metadata": {}, "outputs": [], "source": [ - "import spatialdata_io\n", - "import spatialdata as sd\n", - "import pandas as pd\n", - "import spatialdata_plot\n", "import sys\n", + "\n", + "import spatialdata as sd\n", + "\n", "sys.path.append(\"../../src\")\n", - "import troutpy " + "import troutpy" ] }, { @@ -40,8 +39,8 @@ } ], "source": [ - "xenium_path_cropped='/media/sergio/Discovair_final/mousebrain_prime_crop.zarr'\n", - "sdata= sd.read_zarr(xenium_path_cropped)" + "xenium_path_cropped = \"/media/sergio/Discovair_final/mousebrain_prime_crop.zarr\"\n", + "sdata = sd.read_zarr(xenium_path_cropped)" ] }, { @@ -81,11 +80,12 @@ } ], "source": [ - "#define points2regions_params\n", - "points2regions_params={'num_clusters':100, 'pixel_width':0.4, 'pixel_smoothing':3.5}\n", - "#run_segmentation_free\n", - "troutpy.tl.segmentation_free_clustering(sdata, params=points2regions_params, \n", - " x='x', y='y', feature_name='feature_name', transcript_id='transcript_id')" + "# define points2regions_params\n", + "points2regions_params = {\"num_clusters\": 100, \"pixel_width\": 0.4, \"pixel_smoothing\": 3.5}\n", + "# run_segmentation_free\n", + "troutpy.tl.segmentation_free_clustering(\n", + " sdata, params=points2regions_params, x=\"x\", y=\"y\", feature_name=\"feature_name\", transcript_id=\"transcript_id\"\n", + ")" ] }, { @@ -110,7 +110,7 @@ } ], "source": [ - "troutpy.pp.define_extracellular(sdata, method='segmentation_free',min_prop_of_extracellular=0.8) " + "troutpy.pp.define_extracellular(sdata, method=\"segmentation_free\", min_prop_of_extracellular=0.8)" ] }, { @@ -130,7 +130,7 @@ } ], "source": [ - "len(sdata.points['transcripts']['extracellular'].compute())" + "len(sdata.points[\"transcripts\"][\"extracellular\"].compute())" ] }, { @@ -151,7 +151,8 @@ ], "source": [ "import numpy as np\n", - "np.sum(sdata.points['transcripts']['extracellular'].compute())" + "\n", + "np.sum(sdata.points[\"transcripts\"][\"extracellular\"].compute())" ] }, { @@ -173,8 +174,8 @@ } ], "source": [ - "xenium_path_cropped='/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions.zarr'\n", - "sdata.write(xenium_path_cropped,overwrite=True)" + "xenium_path_cropped = \"/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions.zarr\"\n", + "sdata.write(xenium_path_cropped, overwrite=True)" ] }, { @@ -184,13 +185,35 @@ "outputs": [], "source": [ "### work a bit into better functions for plotting here\n", - "plot_data=sdata.points['transcripts'][['segmentation_free_clusters','overlaps_cell','overlaps_nucleus']].compute()\n", - "troutpy.pl.plot_crosstab(plot_data,xvar='segmentation_free_clusters',yvar='overlaps_cell',normalize=True,axis=1,kind='bar',figsize=(20,7),stacked=True,cmap='coolwarm',sortby=True)\n", - "troutpy.pl.plot_crosstab(plot_data,xvar='segmentation_free_clusters',yvar='overlaps_nucleus',normalize=True,axis=1,kind='bar',figsize=(20,7),stacked=True,cmap='coolwarm',sortby=1)\n", + "plot_data = sdata.points[\"transcripts\"][[\"segmentation_free_clusters\", \"overlaps_cell\", \"overlaps_nucleus\"]].compute()\n", + "troutpy.pl.plot_crosstab(\n", + " plot_data,\n", + " xvar=\"segmentation_free_clusters\",\n", + " yvar=\"overlaps_cell\",\n", + " normalize=True,\n", + " axis=1,\n", + " kind=\"bar\",\n", + " figsize=(20, 7),\n", + " stacked=True,\n", + " cmap=\"coolwarm\",\n", + " sortby=True,\n", + ")\n", + "troutpy.pl.plot_crosstab(\n", + " plot_data,\n", + " xvar=\"segmentation_free_clusters\",\n", + " yvar=\"overlaps_nucleus\",\n", + " normalize=True,\n", + " axis=1,\n", + " kind=\"bar\",\n", + " figsize=(20, 7),\n", + " stacked=True,\n", + " cmap=\"coolwarm\",\n", + " sortby=1,\n", + ")\n", "\n", "\n", - "input_data=sdata.points['transcripts'][['missegmentation_associated']].compute()\n", - "troutpy.pl.pie_of_positive(input_data,groupby='missegmentation_associated',save=True)" + "input_data = sdata.points[\"transcripts\"][[\"missegmentation_associated\"]].compute()\n", + "troutpy.pl.pie_of_positive(input_data, groupby=\"missegmentation_associated\", save=True)" ] } ], diff --git a/notebooks/spatialdata_tutorials/2_process_cells.ipynb b/notebooks/spatialdata_tutorials/2_process_cells.ipynb index 8f8c539..b799cb6 100644 --- a/notebooks/spatialdata_tutorials/2_process_cells.ipynb +++ b/notebooks/spatialdata_tutorials/2_process_cells.ipynb @@ -21,19 +21,13 @@ } ], "source": [ - "import os\n", - "import numpy as np\n", - "import pandas as pd\n", - "import seaborn as sns\n", - "import matplotlib.pyplot as plt\n", - "import spatialdata_io\n", - "import spatialdata as sd\n", + "# while not pip installable, add path to file\n", + "import sys\n", + "\n", "import scanpy as sc\n", + "import spatialdata as sd\n", "\n", - "# while not pip installable, add path to file \n", - "import sys\n", - "sys.path.append(\"../../src\")\n", - "import troutpy \n" + "sys.path.append(\"../../src\")" ] }, { @@ -61,9 +55,9 @@ } ], "source": [ - "xenium_path_cropped='/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions.zarr'\n", - "output_path='/media/sergio/Discovair_final/analysis_crop'\n", - "sdata=sd.read_zarr(xenium_path_cropped)" + "xenium_path_cropped = \"/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions.zarr\"\n", + "output_path = \"/media/sergio/Discovair_final/analysis_crop\"\n", + "sdata = sd.read_zarr(xenium_path_cropped)" ] }, { @@ -72,7 +66,7 @@ "metadata": {}, "outputs": [], "source": [ - "adata=sdata['table']" + "adata = sdata[\"table\"]" ] }, { @@ -81,7 +75,7 @@ "metadata": {}, "outputs": [], "source": [ - "adata.raw=adata\n", + "adata.raw = adata\n", "sc.pp.filter_cells(adata, min_genes=100)\n", "sc.pp.filter_genes(adata, min_cells=3)" ] @@ -155,7 +149,7 @@ ], "source": [ "sc.tl.umap(adata)\n", - "sc.tl.leiden(adata,key_added='leiden')" + "sc.tl.leiden(adata, key_added=\"leiden\")" ] }, { @@ -175,7 +169,7 @@ } ], "source": [ - "sc.pl.umap(adata,color=\"leiden\",size=30)" + "sc.pl.umap(adata, color=\"leiden\", size=30)" ] }, { @@ -204,9 +198,7 @@ ], "source": [ "sc.tl.rank_genes_groups(adata, groupby=\"leiden\", method=\"wilcoxon\")\n", - "sc.pl.rank_genes_groups_dotplot(\n", - " adata, groupby=\"leiden\", standard_scale=\"var\", n_genes=5\n", - ")" + "sc.pl.rank_genes_groups_dotplot(adata, groupby=\"leiden\", standard_scale=\"var\", n_genes=5)" ] }, { @@ -226,7 +218,7 @@ } ], "source": [ - "sc.pl.spatial(adata,color='leiden',spot_size=20)" + "sc.pl.spatial(adata, color=\"leiden\", spot_size=20)" ] }, { @@ -242,8 +234,22 @@ "metadata": {}, "outputs": [], "source": [ - "anndict={'0':'CA','1':'Astro','2':'Oligo','3':'3','4':'4','5':'5','6':'INH','7':'7','8':'','9':'9','10':'10','11':'11','12':'12'}\n", - "adata.obs['cell type']=adata.obs['leiden'].map(anndict)" + "anndict = {\n", + " \"0\": \"CA\",\n", + " \"1\": \"Astro\",\n", + " \"2\": \"Oligo\",\n", + " \"3\": \"3\",\n", + " \"4\": \"4\",\n", + " \"5\": \"5\",\n", + " \"6\": \"INH\",\n", + " \"7\": \"7\",\n", + " \"8\": \"\",\n", + " \"9\": \"9\",\n", + " \"10\": \"10\",\n", + " \"11\": \"11\",\n", + " \"12\": \"12\",\n", + "}\n", + "adata.obs[\"cell type\"] = adata.obs[\"leiden\"].map(anndict)" ] }, { @@ -266,8 +272,8 @@ } ], "source": [ - "xenium_path_cropped='/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions_annotated.zarr'\n", - "sdata.write(xenium_path_cropped,overwrite=True)" + "xenium_path_cropped = \"/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions_annotated.zarr\"\n", + "sdata.write(xenium_path_cropped, overwrite=True)" ] } ], diff --git a/notebooks/spatialdata_tutorials/3_quantify_exRNA.ipynb b/notebooks/spatialdata_tutorials/3_quantify_exRNA.ipynb index d24b444..dd4f914 100644 --- a/notebooks/spatialdata_tutorials/3_quantify_exRNA.ipynb +++ b/notebooks/spatialdata_tutorials/3_quantify_exRNA.ipynb @@ -21,18 +21,13 @@ } ], "source": [ - "import os\n", - "import numpy as np\n", - "import pandas as pd\n", - "import seaborn as sns\n", - "import matplotlib.pyplot as plt\n", - "import spatialdata_io\n", + "# while not pip installable, add path to file\n", + "import sys\n", + "\n", "import spatialdata as sd\n", "\n", - "# while not pip installable, add path to file \n", - "import sys\n", "sys.path.append(\"../../src\")\n", - "import troutpy \n" + "import troutpy" ] }, { @@ -67,8 +62,8 @@ } ], "source": [ - "xenium_path_cropped='/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions_annotated.zarr'\n", - "sdata=sd.read_zarr(xenium_path_cropped)" + "xenium_path_cropped = \"/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions_annotated.zarr\"\n", + "sdata = sd.read_zarr(xenium_path_cropped)" ] }, { @@ -108,11 +103,22 @@ } ], "source": [ - "control_codewords = ['negative_control_probe','unassigned_codeword', 'deprecated_codeword', \n", - " 'genomic_control_probe', 'negative_control_codeword']\n", + "control_codewords = [\n", + " \"negative_control_probe\",\n", + " \"unassigned_codeword\",\n", + " \"deprecated_codeword\",\n", + " \"genomic_control_probe\",\n", + " \"negative_control_codeword\",\n", + "]\n", "\n", - "troutpy.tl.quantify_overexpression(sdata,layer='transcripts',codeword_column=\"codeword_category\",\n", - " control_codewords=control_codewords,gene_id_column=\"feature_name\",percentile_threshold=99.99)\n" + "troutpy.tl.quantify_overexpression(\n", + " sdata,\n", + " layer=\"transcripts\",\n", + " codeword_column=\"codeword_category\",\n", + " control_codewords=control_codewords,\n", + " gene_id_column=\"feature_name\",\n", + " percentile_threshold=99.99,\n", + ")" ] }, { @@ -161,7 +167,7 @@ } ], "source": [ - "troutpy.tl.spatial_variability(sdata, gene_id_key='feature_name', n_neighbors=10,binsize=20)" + "troutpy.tl.spatial_variability(sdata, gene_id_key=\"feature_name\", n_neighbors=10, binsize=20)" ] }, { @@ -185,7 +191,15 @@ } ], "source": [ - "troutpy.tl.spatial_colocalization(sdata, coords_keys=['x', 'y'], gene_id_key='feature_name',resolution=1000,binsize=5, threshold_colocalized=1 ,copy=False)" + "troutpy.tl.spatial_colocalization(\n", + " sdata,\n", + " coords_keys=[\"x\", \"y\"],\n", + " gene_id_key=\"feature_name\",\n", + " resolution=1000,\n", + " binsize=5,\n", + " threshold_colocalized=1,\n", + " copy=False,\n", + ")" ] }, { @@ -212,7 +226,7 @@ } ], "source": [ - "sdata['xrna_metadata'].var['control_probe'].unique()" + "sdata[\"xrna_metadata\"].var[\"control_probe\"].unique()" ] }, { @@ -235,8 +249,8 @@ } ], "source": [ - "xenium_path_cropped='/media/sergio/Discovair_final/mousebrain_prime_crop_quantified.zarr'\n", - "sdata.write(xenium_path_cropped,overwrite=True)" + "xenium_path_cropped = \"/media/sergio/Discovair_final/mousebrain_prime_crop_quantified.zarr\"\n", + "sdata.write(xenium_path_cropped, overwrite=True)" ] }, { @@ -252,9 +266,11 @@ "metadata": {}, "outputs": [], "source": [ - "exrna_metrics=sdata['xrna_metadata'].var\n", - "exrna_metrics=exrna_metrics[exrna_metrics['count']>120]\n", - "exrna_metrics_filt=exrna_metrics.loc[:,['logfoldratio_over_noise','logfoldratio_extracellular','moran_I','proportion_of_colocalized']]" + "exrna_metrics = sdata[\"xrna_metadata\"].var\n", + "exrna_metrics = exrna_metrics[exrna_metrics[\"count\"] > 120]\n", + "exrna_metrics_filt = exrna_metrics.loc[\n", + " :, [\"logfoldratio_over_noise\", \"logfoldratio_extracellular\", \"moran_I\", \"proportion_of_colocalized\"]\n", + "]" ] } ], diff --git a/notebooks/spatialdata_tutorials/4_explore_comunication.ipynb b/notebooks/spatialdata_tutorials/4_explore_comunication.ipynb index 25a7d0e..ce84604 100644 --- a/notebooks/spatialdata_tutorials/4_explore_comunication.ipynb +++ b/notebooks/spatialdata_tutorials/4_explore_comunication.ipynb @@ -21,18 +21,14 @@ } ], "source": [ - "import os\n", - "import numpy as np\n", + "# while not pip installable, add path to file\n", + "import sys\n", + "\n", "import pandas as pd\n", - "import seaborn as sns\n", - "import matplotlib.pyplot as plt\n", - "import spatialdata_io\n", "import spatialdata as sd\n", "\n", - "# while not pip installable, add path to file \n", - "import sys \n", "sys.path.append(\"../../src\")\n", - "import troutpy " + "import troutpy" ] }, { @@ -60,9 +56,9 @@ } ], "source": [ - "xenium_path_cropped='/media/sergio/Discovair_final/mousebrain_prime_crop_quantified.zarr'\n", - "output_path='/media/sergio/Discovair_final/analysis_crop'\n", - "sdata=sd.read_zarr(xenium_path_cropped)" + "xenium_path_cropped = \"/media/sergio/Discovair_final/mousebrain_prime_crop_quantified.zarr\"\n", + "output_path = \"/media/sergio/Discovair_final/analysis_crop\"\n", + "sdata = sd.read_zarr(xenium_path_cropped)" ] }, { @@ -97,10 +93,10 @@ "outputs": [], "source": [ "# work on visualization\n", - "#filtered_proportions = source_proportion[(source_proportion > 0.1).any(axis=1)]\n", - "#absent_source_proportions = source_proportion[np.max(source_proportion,axis=1)<0.1]\n", - "#troutpy.pl.sorted_heatmap(source_proportion, output_path,filename=\"Heatmap_source_cells_by_gene.pdf\",cmap='Blues',vmax=0.2)\n", - "#troutpy.pl.sorted_heatmap(absent_source_proportions, output_path,filename=\"Heatmap_source_cells_by_gene.pdf\",cmap='Blues',vmax=1,figsize=(4,3))" + "# filtered_proportions = source_proportion[(source_proportion > 0.1).any(axis=1)]\n", + "# absent_source_proportions = source_proportion[np.max(source_proportion,axis=1)<0.1]\n", + "# troutpy.pl.sorted_heatmap(source_proportion, output_path,filename=\"Heatmap_source_cells_by_gene.pdf\",cmap='Blues',vmax=0.2)\n", + "# troutpy.pl.sorted_heatmap(absent_source_proportions, output_path,filename=\"Heatmap_source_cells_by_gene.pdf\",cmap='Blues',vmax=1,figsize=(4,3))" ] }, { @@ -124,7 +120,9 @@ } ], "source": [ - "troutpy.tl.distance_to_source_cell(sdata, xcellcoord='x_centroid', ycellcoord='y_centroid',gene_id_column='feature_name')" + "troutpy.tl.distance_to_source_cell(\n", + " sdata, xcellcoord=\"x_centroid\", ycellcoord=\"y_centroid\", gene_id_column=\"feature_name\"\n", + ")" ] }, { @@ -142,7 +140,9 @@ } ], "source": [ - "troutpy.tl.compute_distant_cells_prop(sdata, layer='transcripts', gene_id_column='feature_name', threshold=30,copy=False)" + "troutpy.tl.compute_distant_cells_prop(\n", + " sdata, layer=\"transcripts\", gene_id_column=\"feature_name\", threshold=30, copy=False\n", + ")" ] }, { @@ -152,7 +152,7 @@ "outputs": [], "source": [ "## work on plotting\n", - "#troutpy.pl.proportion_above_threshold(proportions_above_threshold, top_percentile=0.01, bottom_percentile=0.01, figsize=(5, 10), bar_color=\"orange\",save=True,output_path=output_path)" + "# troutpy.pl.proportion_above_threshold(proportions_above_threshold, top_percentile=0.01, bottom_percentile=0.01, figsize=(5, 10), bar_color=\"orange\",save=True,output_path=output_path)" ] }, { @@ -181,8 +181,16 @@ ], "source": [ "# Calculate closest cells and distances\n", - "troutpy.tl.calculate_target_cells(sdata, layer='transcripts', xcoord='x', \n", - "ycoord='y',xcellcoord='x_centroid',ycellcoord='y_centroid',celltype_key='cell type',copy=False)" + "troutpy.tl.calculate_target_cells(\n", + " sdata,\n", + " layer=\"transcripts\",\n", + " xcoord=\"x\",\n", + " ycoord=\"y\",\n", + " xcellcoord=\"x_centroid\",\n", + " ycellcoord=\"y_centroid\",\n", + " celltype_key=\"cell type\",\n", + " copy=False,\n", + ")" ] }, { @@ -191,9 +199,9 @@ "metadata": {}, "outputs": [], "source": [ - "### work on plotting \n", + "### work on plotting\n", "##troutpy.pl.sorted_heatmap(target_proportion, output_path,filename=\"Heatmap_target_cells_by_gene.pdf\",cmap='Reds',vmax=1)\n", - "#troutpy.pl.coupled_scatter(sdata,layer='extracellular_transcripts_enriched',save=False,transcript_group='distance_to_source_cell',size=3,vmax=40,figsize=(6,4))" + "# troutpy.pl.coupled_scatter(sdata,layer='extracellular_transcripts_enriched',save=False,transcript_group='distance_to_source_cell',size=3,vmax=40,figsize=(6,4))" ] }, { @@ -216,8 +224,8 @@ } ], "source": [ - "xenium_path_cropped='/media/sergio/Discovair_final/mousebrain_prime_crop_communication.zarr'\n", - "sdata.write(xenium_path_cropped,overwrite=True)" + "xenium_path_cropped = \"/media/sergio/Discovair_final/mousebrain_prime_crop_communication.zarr\"\n", + "sdata.write(xenium_path_cropped, overwrite=True)" ] }, { @@ -256,8 +264,8 @@ } ], "source": [ - "xenium_path_cropped='/media/sergio/Discovair_final/mousebrain_prime_crop_communication.zarr'\n", - "sdata=sd.read_zarr(xenium_path_cropped)" + "xenium_path_cropped = \"/media/sergio/Discovair_final/mousebrain_prime_crop_communication.zarr\"\n", + "sdata = sd.read_zarr(xenium_path_cropped)" ] }, { @@ -266,10 +274,11 @@ "metadata": {}, "outputs": [], "source": [ - "def communicating_genes_per_celltype(sdata,proportion_threshold: float = 0.2,cell_type_key='cell type'):\n", + "def communicating_genes_per_celltype(sdata, proportion_threshold: float = 0.2, cell_type_key=\"cell type\"):\n", " \"\"\"Computes the number of significant genes exchanged between source and target cell types based on their proportions in the dataset.\n", "\n", - " Parameters:\n", + " Parameters\n", + " ----------\n", " - sdata: AnnData object\n", " A spatial data object containing the 'table' and 'xrna_metadata' components.\n", " - proportion_threshold: float, optional (default=0.2)\n", @@ -277,22 +286,31 @@ " - cell_type_key: str, optional (default='cell type')\n", " The key in `adata.obs` that contains the cell type annotations.\n", "\n", - " Returns:\n", + " Returns\n", + " -------\n", " - number_interactions_df: pandas.DataFrame\n", " A DataFrame where rows represent source cell types, columns represent target cell types, and values indicate the number of significant genes shared between them.\n", " \"\"\"\n", - " adata=sdata['table']\n", - " source_proportions=pd.DataFrame(sdata['xrna_metadata'].varm['source'],index=sdata['xrna_metadata'].var.index,columns=adata.obs[cell_type_key].unique().dropna())\n", - " target_proportions=pd.DataFrame(sdata['xrna_metadata'].varm['target'],index=sdata['xrna_metadata'].var.index,columns=adata.obs[cell_type_key].unique().dropna())\n", + " adata = sdata[\"table\"]\n", + " source_proportions = pd.DataFrame(\n", + " sdata[\"xrna_metadata\"].varm[\"source\"],\n", + " index=sdata[\"xrna_metadata\"].var.index,\n", + " columns=adata.obs[cell_type_key].unique().dropna(),\n", + " )\n", + " target_proportions = pd.DataFrame(\n", + " sdata[\"xrna_metadata\"].varm[\"target\"],\n", + " index=sdata[\"xrna_metadata\"].var.index,\n", + " columns=adata.obs[cell_type_key].unique().dropna(),\n", + " )\n", "\n", " # filter the source and target cell types by defining significant proportions\n", - " source_binary = (source_proportions > proportion_threshold)\n", - " target_binary = (target_proportions > proportion_threshold)\n", - " \n", + " source_binary = source_proportions > proportion_threshold\n", + " target_binary = target_proportions > proportion_threshold\n", + "\n", " # prepare dataframe to store the number of exchanged genes\n", - " number_interactions_df = pd.DataFrame(index=source_binary.columns,columns=target_binary.columns)\n", + " number_interactions_df = pd.DataFrame(index=source_binary.columns, columns=target_binary.columns)\n", "\n", - " # loop through the source and target cell types to compute the number of \n", + " # loop through the source and target cell types to compute the number of\n", " # exchanged genes\n", " for col in source_binary.columns:\n", " sig_gene_source = source_binary.index[source_binary[col]]\n", @@ -300,10 +318,10 @@ " sig_gene_target = target_binary.index[target_binary[col2]]\n", " number_interactions_df.loc[col, col2] = len(set(sig_gene_source).intersection(sig_gene_target))\n", "\n", - " number_interactions_df=number_interactions_df[number_interactions_df.index]\n", - " number_interactions_df.columns.name='Target cell type' \n", - " number_interactions_df.index.name='Source cell type' \n", - " return number_interactions_df\n" + " number_interactions_df = number_interactions_df[number_interactions_df.index]\n", + " number_interactions_df.columns.name = \"Target cell type\"\n", + " number_interactions_df.index.name = \"Source cell type\"\n", + " return number_interactions_df" ] }, { @@ -312,7 +330,7 @@ "metadata": {}, "outputs": [], "source": [ - "communications_ngenes=communicating_genes_per_celltype(sdata,proportion_threshold = 0.6,cell_type_key='cell type') " + "communications_ngenes = communicating_genes_per_celltype(sdata, proportion_threshold=0.6, cell_type_key=\"cell type\")" ] }, { @@ -332,7 +350,7 @@ } ], "source": [ - "troutpy.pl.sorted_heatmap(communications_ngenes.astype(float),save=False,figsize=(5,5))" + "troutpy.pl.sorted_heatmap(communications_ngenes.astype(float), save=False, figsize=(5, 5))" ] }, { @@ -353,7 +371,7 @@ ], "source": [ "# WORK ON THIS FUNCTION SINCE IT'S CURRENTLY NOT WORKING WELL\n", - "#interactions_Vxn=troutpy.tl.get_gene_interaction_strength(source_proportion, target_proportion,gene_symbol=\"Gfap\")" + "# interactions_Vxn=troutpy.tl.get_gene_interaction_strength(source_proportion, target_proportion,gene_symbol=\"Gfap\")" ] }, { @@ -381,7 +399,9 @@ } ], "source": [ - "troutpy.pl.spatial_interactions(sdata,layer= 'extracellular_transcripts_enriched', gene = 'Arc',gene_key= 'feature_name',figsize=(5,5))" + "troutpy.pl.spatial_interactions(\n", + " sdata, layer=\"extracellular_transcripts_enriched\", gene=\"Arc\", gene_key=\"feature_name\", figsize=(5, 5)\n", + ")" ] }, { @@ -409,7 +429,15 @@ } ], "source": [ - "troutpy.pl.interactions_with_arrows(sdata,layer= 'extracellular_transcripts_enriched', gene = 'Kif5a',gene_key= 'feature_name',figsize=(7,7),dpi=100,size=5)" + "troutpy.pl.interactions_with_arrows(\n", + " sdata,\n", + " layer=\"extracellular_transcripts_enriched\",\n", + " gene=\"Kif5a\",\n", + " gene_key=\"feature_name\",\n", + " figsize=(7, 7),\n", + " dpi=100,\n", + " size=5,\n", + ")" ] } ], diff --git a/notebooks/spatialdata_tutorials/6_identify_patterns.ipynb b/notebooks/spatialdata_tutorials/6_identify_patterns.ipynb index 4fdc368..d4ef45b 100644 --- a/notebooks/spatialdata_tutorials/6_identify_patterns.ipynb +++ b/notebooks/spatialdata_tutorials/6_identify_patterns.ipynb @@ -21,18 +21,13 @@ } ], "source": [ - "import os\n", - "import numpy as np\n", - "import pandas as pd\n", - "import seaborn as sns\n", - "import matplotlib.pyplot as plt\n", - "import spatialdata_io\n", + "# while not pip installable, add path to file\n", + "import sys\n", + "\n", "import spatialdata as sd\n", "\n", - "# while not pip installable, add path to file \n", - "import sys \n", "sys.path.append(\"../../src\")\n", - "import troutpy " + "import troutpy" ] }, { @@ -60,9 +55,9 @@ } ], "source": [ - "xenium_path_cropped='/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions_annotated.zarr'\n", - "output_path='/media/sergio/Discovair_final/analysis_crop'\n", - "sdata=sd.read_zarr(xenium_path_cropped)" + "xenium_path_cropped = \"/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions_annotated.zarr\"\n", + "output_path = \"/media/sergio/Discovair_final/analysis_crop\"\n", + "sdata = sd.read_zarr(xenium_path_cropped)" ] }, { @@ -78,8 +73,16 @@ "metadata": {}, "outputs": [], "source": [ - "sdata=troutpy.tl.nmf(sdata, layer='extracellular_transcripts', feature_key='feature_name', bin_key='bin_id', \n", - " density_table_key='segmentation_free_table', n_components=10, subsample_percentage=0.1,all=False)" + "sdata = troutpy.tl.nmf(\n", + " sdata,\n", + " layer=\"extracellular_transcripts\",\n", + " feature_key=\"feature_name\",\n", + " bin_key=\"bin_id\",\n", + " density_table_key=\"segmentation_free_table\",\n", + " n_components=10,\n", + " subsample_percentage=0.1,\n", + " all=False,\n", + ")" ] }, { @@ -89,7 +92,7 @@ "outputs": [], "source": [ "## in this part we plot the factors in exrna\n", - "#troutpy.pl.nmf_factors_exrna_cells_W(sdata,nmf_adata_key,saving_path=output_path,save=False,spot_size=5) " + "# troutpy.pl.nmf_factors_exrna_cells_W(sdata,nmf_adata_key,saving_path=output_path,save=False,spot_size=5)" ] }, { @@ -109,7 +112,9 @@ } ], "source": [ - "troutpy.pl.nmf_gene_contributions(sdata,nmf_adata_key='nmf_data', vmin=0.0, vmax=0.02,saving_path=output_path,save=False,figsize=(4,5))" + "troutpy.pl.nmf_gene_contributions(\n", + " sdata, nmf_adata_key=\"nmf_data\", vmin=0.0, vmax=0.02, saving_path=output_path, save=False, figsize=(4, 5)\n", + ")" ] }, { @@ -136,7 +141,7 @@ } ], "source": [ - "sdata = troutpy.tl.apply_exrna_factors_to_cells(sdata,layer_factors='nmf_data')" + "sdata = troutpy.tl.apply_exrna_factors_to_cells(sdata, layer_factors=\"nmf_data\")" ] }, { @@ -246,7 +251,7 @@ } ], "source": [ - "troutpy.pl.paired_nmf_factors(sdata,figsize=(7,7),n_factors=10)" + "troutpy.pl.paired_nmf_factors(sdata, figsize=(7, 7), n_factors=10)" ] } ], diff --git a/notebooks/spatialdata_tutorials/7_image_params.ipynb b/notebooks/spatialdata_tutorials/7_image_params.ipynb index 347c08f..1e4d917 100644 --- a/notebooks/spatialdata_tutorials/7_image_params.ipynb +++ b/notebooks/spatialdata_tutorials/7_image_params.ipynb @@ -15,19 +15,17 @@ } ], "source": [ - "import os\n", + "# while not pip installable, add path to file\n", + "import sys\n", + "\n", + "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", - "import matplotlib.pyplot as plt\n", - "import spatialdata_io\n", "import spatialdata as sd\n", "from tqdm import tqdm\n", "\n", - "# while not pip installable, add path to file \n", - "import sys \n", - "sys.path.append(\"../../src\")\n", - "import troutpy " + "sys.path.append(\"../../src\")" ] }, { @@ -55,9 +53,9 @@ } ], "source": [ - "xenium_path_cropped='/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions_annotated.zarr'\n", - "output_path='/media/sergio/Discovair_final/analysis_crop'\n", - "sdata=sd.read_zarr(xenium_path_cropped)" + "xenium_path_cropped = \"/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions_annotated.zarr\"\n", + "output_path = \"/media/sergio/Discovair_final/analysis_crop\"\n", + "sdata = sd.read_zarr(xenium_path_cropped)" ] }, { @@ -66,24 +64,24 @@ "metadata": {}, "outputs": [], "source": [ - "imarray=sdata.images['morphology_focus']['scale0'].image.compute()\n", - "chnames=list(sdata.images['morphology_focus']['scale0'].image.c.values)\n", - "transcripts=sdata.points['extracellular_transcripts_enriched'].compute()\n", + "imarray = sdata.images[\"morphology_focus\"][\"scale0\"].image.compute()\n", + "chnames = list(sdata.images[\"morphology_focus\"][\"scale0\"].image.c.values)\n", + "transcripts = sdata.points[\"extracellular_transcripts_enriched\"].compute()\n", "\n", - "minx=transcripts.x.min()\n", - "maxx=transcripts.x.max()\n", - "transcripts_size_x=maxx-minx\n", - "image_size_x=imarray.shape[2]\n", - "multi_factor_x=image_size_x/transcripts_size_x\n", + "minx = transcripts.x.min()\n", + "maxx = transcripts.x.max()\n", + "transcripts_size_x = maxx - minx\n", + "image_size_x = imarray.shape[2]\n", + "multi_factor_x = image_size_x / transcripts_size_x\n", "\n", - "miny=transcripts.y.min()\n", - "maxy=transcripts.y.max()\n", - "transcripts_size_y=maxy-miny\n", - "image_size_y=imarray.shape[1]\n", - "multi_factor_y=image_size_y/transcripts_size_y\n", + "miny = transcripts.y.min()\n", + "maxy = transcripts.y.max()\n", + "transcripts_size_y = maxy - miny\n", + "image_size_y = imarray.shape[1]\n", + "multi_factor_y = image_size_y / transcripts_size_y\n", "\n", - "transcripts['x_scaled']=(transcripts['x']-minx)*multi_factor_x\n", - "transcripts['y_scaled']=(transcripts['y']-miny)*multi_factor_y" + "transcripts[\"x_scaled\"] = (transcripts[\"x\"] - minx) * multi_factor_x\n", + "transcripts[\"y_scaled\"] = (transcripts[\"y\"] - miny) * multi_factor_y" ] }, { @@ -113,8 +111,8 @@ } ], "source": [ - "plt.imshow(imarray[0,:,:])\n", - "plt.scatter(transcripts.x_scaled,transcripts.y_scaled,s=0.01,color='red')" + "plt.imshow(imarray[0, :, :])\n", + "plt.scatter(transcripts.x_scaled, transcripts.y_scaled, s=0.01, color=\"red\")" ] }, { @@ -131,16 +129,16 @@ } ], "source": [ - "import xarray as xr\n", - "import numpy as np\n", "from scipy.ndimage import gaussian_filter, zoom\n", "\n", + "\n", "# Example: Assuming `data_array` is your xArray DataArray with dims (channels, height, width)\n", "def process_xarray(data_array, sigma=5, downsize_factor=0.1):\n", " \"\"\"\n", " Processes a 3D xarray by applying Gaussian smoothing and downscaling on each 2D slice.\n", "\n", - " Parameters:\n", + " Parameters\n", + " ----------\n", " - data_array: xarray.DataArray\n", " A 3D xarray where the first dimension represents slices, and the next two dimensions\n", " represent 2D spatial data.\n", @@ -149,30 +147,32 @@ " - downsize_factor: float, optional (default=0.1)\n", " The factor by which each 2D slice is resized during downscaling.\n", "\n", - " Returns:\n", + " Returns\n", + " -------\n", " - smoothed_and_downsized: numpy.ndarray\n", " A 3D numpy array with processed slices, where the first dimension corresponds to\n", " the original slicing dimension, and the remaining dimensions are the downsized\n", " 2D spatial data.\n", "\n", - " Notes:\n", + " Notes\n", + " -----\n", " - The function loops through each slice of the input data, applies Gaussian smoothing,\n", " downsizes the smoothed array, and stacks the results back into a single 3D array.\n", " - Ensure that `data_array` has the correct shape and data types for processing.\n", " \"\"\"\n", " processed_slices = []\n", - " \n", + "\n", " for i in range(data_array.shape[0]): # Loop over the first dimension\n", " slice_2d = data_array[i, :, :].values # Extract the 2D slice\n", - " \n", + "\n", " # Apply KDE (Gaussian smoothing)\n", " smoothed = gaussian_filter(slice_2d, sigma=sigma)\n", - " \n", + "\n", " # Downsize the array\n", " downscaled = zoom(smoothed, downsize_factor)\n", - " \n", + "\n", " processed_slices.append(downscaled)\n", - " \n", + "\n", " smoothed_and_downsized = np.stack(processed_slices, axis=0)\n", " return smoothed_and_downsized\n", "\n", @@ -182,7 +182,7 @@ "smoothed_and_downsized = process_xarray(imarray, sigma=5, downsize_factor=0.5)\n", "\n", "# Check the new shape\n", - "print(smoothed_and_downsized.shape)\n" + "print(smoothed_and_downsized.shape)" ] }, { @@ -206,9 +206,9 @@ } ], "source": [ - "allg=[]\n", + "allg = []\n", "for indi in tqdm(transcripts.index):\n", - " allg.append(imarray[:,int(transcripts.loc[indi,'y_scaled'])-1,int(transcripts.loc[indi,'x_scaled'])-1])" + " allg.append(imarray[:, int(transcripts.loc[indi, \"y_scaled\"]) - 1, int(transcripts.loc[indi, \"x_scaled\"]) - 1])" ] }, { @@ -225,7 +225,6 @@ } ], "source": [ - "import numpy as np\n", "from tqdm import tqdm\n", "\n", "# Initialize a list to hold the pixel intensities for all transcripts\n", @@ -233,9 +232,9 @@ "\n", "for indi in tqdm(transcripts.index):\n", " # Get the x and y coordinates from the DataFrame\n", - " y_coord = int(transcripts.loc[indi, 'y_scaled'])\n", - " x_coord = int(transcripts.loc[indi, 'x_scaled'])\n", - " \n", + " y_coord = int(transcripts.loc[indi, \"y_scaled\"])\n", + " x_coord = int(transcripts.loc[indi, \"x_scaled\"])\n", + "\n", " # Ensure coordinates are within the bounds of the image array\n", " if 0 <= y_coord < imarray.shape[1] and 0 <= x_coord < imarray.shape[2]:\n", " # Extract the pixel intensity for all channels at the given coordinates\n", @@ -252,7 +251,7 @@ "for i in range(pixel_intensities.shape[1]): # Iterate over the number of channels\n", " transcripts[chnames[i]] = pixel_intensities[:, i]\n", "\n", - "# The transcripts DataFrame now has a column for each channel's intensity\n" + "# The transcripts DataFrame now has a column for each channel's intensity" ] }, { @@ -261,9 +260,9 @@ "metadata": {}, "outputs": [], "source": [ - "tran2=transcripts.loc[:,chnames+['feature_name']]\n", - "tran2['feature_name']=tran2['feature_name'].astype(str)\n", - "mean_intensity=tran2.groupby('feature_name').mean()" + "tran2 = transcripts.loc[:, chnames + [\"feature_name\"]]\n", + "tran2[\"feature_name\"] = tran2[\"feature_name\"].astype(str)\n", + "mean_intensity = tran2.groupby(\"feature_name\").mean()" ] }, { @@ -293,10 +292,10 @@ } ], "source": [ - "## turn into plotting \n", - "mean_intensity_norm=mean_intensity.subtract(mean_intensity.min(axis=0),axis=1)\n", - "mean_intensity_norm=mean_intensity_norm.div(mean_intensity_norm.max(axis=0),axis=1)\n", - "sns.clustermap(mean_intensity_norm.fillna(0).astype(float),figsize=(6,8))" + "## turn into plotting\n", + "mean_intensity_norm = mean_intensity.subtract(mean_intensity.min(axis=0), axis=1)\n", + "mean_intensity_norm = mean_intensity_norm.div(mean_intensity_norm.max(axis=0), axis=1)\n", + "sns.clustermap(mean_intensity_norm.fillna(0).astype(float), figsize=(6, 8))" ] }, { @@ -326,11 +325,11 @@ } ], "source": [ - "selection='ATP1A1/CD45/E-Cadherin'\n", - "mean_intensity_sorted=mean_intensity.sort_values(by=selection)\n", - "intensity_edges=pd.concat([mean_intensity_sorted.head(10),mean_intensity_sorted.tail(10)])\n", - "plt.figure(figsize=(5,5))\n", - "sns.scatterplot(y=intensity_edges.index,x=intensity_edges[selection])" + "selection = \"ATP1A1/CD45/E-Cadherin\"\n", + "mean_intensity_sorted = mean_intensity.sort_values(by=selection)\n", + "intensity_edges = pd.concat([mean_intensity_sorted.head(10), mean_intensity_sorted.tail(10)])\n", + "plt.figure(figsize=(5, 5))\n", + "sns.scatterplot(y=intensity_edges.index, x=intensity_edges[selection])" ] }, { @@ -360,11 +359,11 @@ } ], "source": [ - "selection='AlphaSMA/Vimentin'\n", - "mean_intensity_sorted=mean_intensity.sort_values(by=selection)\n", - "intensity_edges=pd.concat([mean_intensity_sorted.head(10),mean_intensity_sorted.tail(10)])\n", - "plt.figure(figsize=(5,5))\n", - "sns.scatterplot(y=intensity_edges.index,x=intensity_edges[selection])" + "selection = \"AlphaSMA/Vimentin\"\n", + "mean_intensity_sorted = mean_intensity.sort_values(by=selection)\n", + "intensity_edges = pd.concat([mean_intensity_sorted.head(10), mean_intensity_sorted.tail(10)])\n", + "plt.figure(figsize=(5, 5))\n", + "sns.scatterplot(y=intensity_edges.index, x=intensity_edges[selection])" ] } ], diff --git a/notebooks/spatialdata_tutorials/8_exrna_signature.ipynb b/notebooks/spatialdata_tutorials/8_exrna_signature.ipynb index 42f10ef..5bb564e 100644 --- a/notebooks/spatialdata_tutorials/8_exrna_signature.ipynb +++ b/notebooks/spatialdata_tutorials/8_exrna_signature.ipynb @@ -19,14 +19,11 @@ } ], "source": [ - "from gseapy import Biomart\n", "import os\n", - "import numpy as np\n", + "\n", "import pandas as pd\n", - "import seaborn as sns\n", - "import matplotlib.pyplot as plt\n", - "import spatialdata_io\n", - "import spatialdata as sd" + "import spatialdata as sd\n", + "from gseapy import Biomart" ] }, { @@ -35,7 +32,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "bm = Biomart()" ] }, @@ -45,9 +41,10 @@ "metadata": {}, "outputs": [], "source": [ - "#queries ={'gene_id': ['ACTA2'] } # need to be a dict object\n", - "results = bm.query(dataset='mmusculus_gene_ensembl',\n", - " attributes=['ensembl_gene_id', 'external_gene_name', 'entrezgene_id', 'go_id'])" + "# queries ={'gene_id': ['ACTA2'] } # need to be a dict object\n", + "results = bm.query(\n", + " dataset=\"mmusculus_gene_ensembl\", attributes=[\"ensembl_gene_id\", \"external_gene_name\", \"entrezgene_id\", \"go_id\"]\n", + ")" ] }, { @@ -122,7 +119,7 @@ ], "source": [ "### biological process= extracellular_transport\n", - "results[results['go_id']=='GO:0006858']" + "results[results[\"go_id\"] == \"GO:0006858\"]" ] }, { @@ -132,7 +129,7 @@ "outputs": [], "source": [ "### cellular_component= extracellular_vesicle\n", - "extracellular_vesicle=results[results['go_id']=='GO:1903561']" + "extracellular_vesicle = results[results[\"go_id\"] == \"GO:1903561\"]" ] }, { @@ -167,9 +164,9 @@ } ], "source": [ - "xenium_path_cropped='/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions_annotated.zarr'\n", - "output_path='/media/sergio/Discovair_final/analysis_crop'\n", - "sdata=sd.read_zarr(xenium_path_cropped)" + "xenium_path_cropped = \"/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions_annotated.zarr\"\n", + "output_path = \"/media/sergio/Discovair_final/analysis_crop\"\n", + "sdata = sd.read_zarr(xenium_path_cropped)" ] }, { @@ -178,7 +175,7 @@ "metadata": {}, "outputs": [], "source": [ - "adata=sdata['table']" + "adata = sdata[\"table\"]" ] }, { @@ -188,8 +185,9 @@ "outputs": [], "source": [ "import scanpy as sc\n", - "exvesicle=adata.var.index[adata.var.index.isin(extracellular_vesicle['external_gene_name'])]\n", - "sc.tl.score_genes(adata,exvesicle,score_name='exrna')" + "\n", + "exvesicle = adata.var.index[adata.var.index.isin(extracellular_vesicle[\"external_gene_name\"])]\n", + "sc.tl.score_genes(adata, exvesicle, score_name=\"exrna\")" ] }, { @@ -209,7 +207,7 @@ } ], "source": [ - "sc.pl.umap(adata,color=['exrna','cell type'])" + "sc.pl.umap(adata, color=[\"exrna\", \"cell type\"])" ] }, { @@ -229,7 +227,7 @@ } ], "source": [ - "sc.pl.spatial(adata,color='exrna',spot_size=20)" + "sc.pl.spatial(adata, color=\"exrna\", spot_size=20)" ] }, { @@ -258,7 +256,7 @@ "metadata": {}, "outputs": [], "source": [ - "exrna_prop=pd.read_parquet(os.path.join(output_path,'extracellular_proportion_of_transcripts.parquet'))" + "exrna_prop = pd.read_parquet(os.path.join(output_path, \"extracellular_proportion_of_transcripts.parquet\"))" ] }, { @@ -336,7 +334,7 @@ } ], "source": [ - "exrna_prop.loc[exvesicle,:]" + "exrna_prop.loc[exvesicle, :]" ] }, { diff --git a/notebooks/spatialdata_tutorials/run_sainsc.ipynb b/notebooks/spatialdata_tutorials/run_sainsc.ipynb index 960bfa5..4c9eb6e 100644 --- a/notebooks/spatialdata_tutorials/run_sainsc.ipynb +++ b/notebooks/spatialdata_tutorials/run_sainsc.ipynb @@ -21,14 +21,12 @@ } ], "source": [ - "import spatialdata_io\n", - "import spatialdata as sd\n", "import sys\n", + "\n", + "import spatialdata as sd\n", + "\n", "sys.path.append(\"../../src\")\n", - "import troutpy \n", - "import pandas as pd\n", - "import spatialdata_plot\n", - "import polars as pls" + "import pandas as pd" ] }, { @@ -56,8 +54,8 @@ } ], "source": [ - "xenium_path_cropped='/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions_annotated.zarr'\n", - "sdata= sd.read_zarr(xenium_path_cropped)" + "xenium_path_cropped = \"/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions_annotated.zarr\"\n", + "sdata = sd.read_zarr(xenium_path_cropped)" ] }, { @@ -66,8 +64,8 @@ "metadata": {}, "outputs": [], "source": [ - "transcripts=sdata.points['transcripts'][['feature_name','x','y','codeword_category']].compute()\n", - "transcripts=transcripts.reset_index(drop=True)" + "transcripts = sdata.points[\"transcripts\"][[\"feature_name\", \"x\", \"y\", \"codeword_category\"]].compute()\n", + "transcripts = transcripts.reset_index(drop=True)" ] }, { @@ -76,9 +74,8 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "transcripts=transcripts.rename({\"feature_name\": \"gene\", \"x\": \"x\", \"y\": \"y\"})\n", - "transcripts=transcripts[transcripts['codeword_category']=='predesigned_gene']" + "transcripts = transcripts.rename({\"feature_name\": \"gene\", \"x\": \"x\", \"y\": \"y\"})\n", + "transcripts = transcripts[transcripts[\"codeword_category\"] == \"predesigned_gene\"]" ] }, { @@ -87,7 +84,7 @@ "metadata": {}, "outputs": [], "source": [ - "transcripts.columns=[\"gene\",\"x\", \"y\",\"codeword_category\"]" + "transcripts.columns = [\"gene\", \"x\", \"y\", \"codeword_category\"]" ] }, { @@ -96,8 +93,8 @@ "metadata": {}, "outputs": [], "source": [ - "savpath='/media/sergio/Discovair_final/trans.csv.gz'\n", - "transcripts.to_csv(savpath,compression='gzip')" + "savpath = \"/media/sergio/Discovair_final/trans.csv.gz\"\n", + "transcripts.to_csv(savpath, compression=\"gzip\")" ] }, { @@ -107,20 +104,18 @@ "outputs": [], "source": [ "import polars as pl\n", - "savpath='/media/sergio/Discovair_final/trans.csv.gz'\n", + "\n", + "savpath = \"/media/sergio/Discovair_final/trans.csv.gz\"\n", "xenium_file = savpath\n", "n_threads = 16\n", "\n", "# Read xenium file, rename columns and filter blanks/controls\n", - "transcripts = (\n", - " pl.read_csv(\n", - " xenium_file,\n", - " columns=[\"gene\", \"x\", \"y\"],\n", - " schema_overrides={\"gene\": pl.Categorical},\n", - " n_threads=n_threads,\n", - " )\n", - " .filter(~pl.col(\"gene\").cast(pl.Utf8).str.contains(\"(BLANK|NegControl)\"))\n", - ")" + "transcripts = pl.read_csv(\n", + " xenium_file,\n", + " columns=[\"gene\", \"x\", \"y\"],\n", + " schema_overrides={\"gene\": pl.Categorical},\n", + " n_threads=n_threads,\n", + ").filter(~pl.col(\"gene\").cast(pl.Utf8).str.contains(\"(BLANK|NegControl)\"))" ] }, { @@ -130,6 +125,7 @@ "outputs": [], "source": [ "from sainsc import LazyKDE\n", + "\n", "embryo = LazyKDE.from_dataframe(transcripts, resolution=1000, binsize=3, n_threads=n_threads)" ] }, @@ -259,7 +255,7 @@ "metadata": {}, "outputs": [], "source": [ - "adata=sdata['table']" + "adata = sdata[\"table\"]" ] }, { @@ -268,8 +264,8 @@ "metadata": {}, "outputs": [], "source": [ - "expr=adata.to_df()\n", - "expr['cell type']=adata.obs['cell type']" + "expr = adata.to_df()\n", + "expr[\"cell type\"] = adata.obs[\"cell type\"]" ] }, { @@ -287,7 +283,7 @@ } ], "source": [ - "signatures=expr.groupby('cell type').mean().transpose()" + "signatures = expr.groupby(\"cell type\").mean().transpose()" ] }, { @@ -307,7 +303,7 @@ "metadata": {}, "outputs": [], "source": [ - "embryo.assign_celltype(signatures,log=True)" + "embryo.assign_celltype(signatures, log=True)" ] }, { @@ -330,7 +326,7 @@ "import colorcet as cc\n", "import seaborn as sns\n", "\n", - "cmap = dict(zip(embryo.celltypes, sns.color_palette(cc.glasbey, n_colors=len(embryo.celltypes))))\n", + "cmap = dict(zip(embryo.celltypes, sns.color_palette(cc.glasbey, n_colors=len(embryo.celltypes)), strict=False))\n", "\n", "_ = embryo.plot_celltype_map(cmap=cmap)" ] @@ -403,8 +399,9 @@ ], "source": [ "import matplotlib.pyplot as plt\n", - "plt.figure(figsize=(10,10))\n", - "plt.imshow(embryo.cosine_similarity.transpose(),cmap='nipy_spectral',vmax=0.9)" + "\n", + "plt.figure(figsize=(10, 10))\n", + "plt.imshow(embryo.cosine_similarity.transpose(), cmap=\"nipy_spectral\", vmax=0.9)" ] }, { @@ -413,12 +410,14 @@ "metadata": {}, "outputs": [], "source": [ - "celltype=embryo.celltype_map.flatten()\n", - "assignment_score=embryo.assignment_score.flatten()\n", - "cosine_similarity=embryo.cosine_similarity.flatten()\n", - "output_df=pd.DataFrame({'cell type':celltype,'assignment_score':assignment_score,'cosine_similarity':cosine_similarity})\n", - "num2ct=dict(zip(range(0,len(embryo.celltypes)),embryo.celltypes))\n", - "output_df['cell type']=output_df['cell type'].map(num2ct)" + "celltype = embryo.celltype_map.flatten()\n", + "assignment_score = embryo.assignment_score.flatten()\n", + "cosine_similarity = embryo.cosine_similarity.flatten()\n", + "output_df = pd.DataFrame(\n", + " {\"cell type\": celltype, \"assignment_score\": assignment_score, \"cosine_similarity\": cosine_similarity}\n", + ")\n", + "num2ct = dict(zip(range(0, len(embryo.celltypes)), embryo.celltypes, strict=False))\n", + "output_df[\"cell type\"] = output_df[\"cell type\"].map(num2ct)" ] }, { @@ -427,7 +426,7 @@ "metadata": {}, "outputs": [], "source": [ - "filt=list((output_df['cell type']=='Astro') & (output_df['cosine_similarity']>0.5))" + "filt = list((output_df[\"cell type\"] == \"Astro\") & (output_df[\"cosine_similarity\"] > 0.5))" ] }, { @@ -444,13 +443,14 @@ } ], "source": [ - "from tqdm import tqdm\n", "import numpy as np\n", - "allres=np.zeros([np.sum(filt),len(embryo.counts.genes())])\n", - "n=0\n", + "from tqdm import tqdm\n", + "\n", + "allres = np.zeros([np.sum(filt), len(embryo.counts.genes())])\n", + "n = 0\n", "for g in tqdm(embryo.counts.genes()):\n", - " allres[:,n]=embryo.counts.get(g).todense().flatten()[filt]\n", - " n=n+1" + " allres[:, n] = embryo.counts.get(g).todense().flatten()[filt]\n", + " n = n + 1" ] }, { @@ -460,8 +460,9 @@ "outputs": [], "source": [ "import scanpy as sc\n", - "adata=sc.AnnData(allres)\n", - "adata.var.index=embryo.counts.genes()" + "\n", + "adata = sc.AnnData(allres)\n", + "adata.var.index = embryo.counts.genes()" ] }, { @@ -479,8 +480,8 @@ "metadata": {}, "outputs": [], "source": [ - "adata.obs['x']=x_coords.flatten()[filt]\n", - "adata.obs['y']=y_coords.flatten()[filt]" + "adata.obs[\"x\"] = x_coords.flatten()[filt]\n", + "adata.obs[\"y\"] = y_coords.flatten()[filt]" ] }, { @@ -489,7 +490,7 @@ "metadata": {}, "outputs": [], "source": [ - "adata.obs['cosine_similarity']=output_df['cosine_similarity'][filt]" + "adata.obs[\"cosine_similarity\"] = output_df[\"cosine_similarity\"][filt]" ] }, { @@ -498,7 +499,7 @@ "metadata": {}, "outputs": [], "source": [ - "adata.obsm['spatial']=np.array(adata.obs.loc[:,['y','x']])" + "adata.obsm[\"spatial\"] = np.array(adata.obs.loc[:, [\"y\", \"x\"]])" ] }, { @@ -518,7 +519,7 @@ } ], "source": [ - "sc.pl.spatial(adata,spot_size=1)" + "sc.pl.spatial(adata, spot_size=1)" ] }, { @@ -549,7 +550,7 @@ "metadata": {}, "outputs": [], "source": [ - "adata.obs['total_counts']=np.sum(adata.X,axis=1)" + "adata.obs[\"total_counts\"] = np.sum(adata.X, axis=1)" ] }, { @@ -558,7 +559,7 @@ "metadata": {}, "outputs": [], "source": [ - "sc.pp.filter_cells(adata,min_counts=5)" + "sc.pp.filter_cells(adata, min_counts=5)" ] }, { @@ -567,7 +568,7 @@ "metadata": {}, "outputs": [], "source": [ - "sc.pp.filter_cells(adata,min_genes=3)" + "sc.pp.filter_cells(adata, min_genes=3)" ] }, { @@ -624,7 +625,7 @@ } ], "source": [ - "plt.hist(adata.obs['n_counts'],bins=20)" + "plt.hist(adata.obs[\"n_counts\"], bins=20)" ] }, { @@ -633,7 +634,7 @@ "metadata": {}, "outputs": [], "source": [ - "adata.layers['raw']=adata.X" + "adata.layers[\"raw\"] = adata.X" ] }, { @@ -642,8 +643,8 @@ "metadata": {}, "outputs": [], "source": [ - "#sc.pp.normalize_total(adata)\n", - "#sc.pp.log1p(adata)" + "# sc.pp.normalize_total(adata)\n", + "# sc.pp.log1p(adata)" ] }, { @@ -678,7 +679,7 @@ "metadata": {}, "outputs": [], "source": [ - "seed=42\n", + "seed = 42\n", "sc.pp.pca(adata)" ] }, @@ -688,7 +689,7 @@ "metadata": {}, "outputs": [], "source": [ - "sc.pp.neighbors(adata, random_state=seed,n_pcs=0,n_neighbors=10)" + "sc.pp.neighbors(adata, random_state=seed, n_pcs=0, n_neighbors=10)" ] }, { @@ -735,7 +736,7 @@ "metadata": {}, "outputs": [], "source": [ - "adata.obs.to_csv('/media/sergio/Discovair_final/leiden_clust.csv')" + "adata.obs.to_csv(\"/media/sergio/Discovair_final/leiden_clust.csv\")" ] }, { @@ -755,7 +756,7 @@ } ], "source": [ - "sc.pl.spatial(adata,spot_size=0.7,color='leiden')" + "sc.pl.spatial(adata, spot_size=0.7, color=\"leiden\")" ] }, { @@ -1017,10 +1018,11 @@ ], "source": [ "import matplotlib.pyplot as plt\n", - "for n,g in output_df.groupby('cell type'):\n", + "\n", + "for n, g in output_df.groupby(\"cell type\"):\n", " print(n)\n", - " plt.figure(figsize=(2,1))\n", - " ss=plt.hist(g['cosine_similarity'],bins=100)\n", + " plt.figure(figsize=(2, 1))\n", + " ss = plt.hist(g[\"cosine_similarity\"], bins=100)\n", " plt.title(n)\n", " plt.show()" ] @@ -1130,8 +1132,7 @@ } ], "source": [ - "\n", - "seed=42\n", + "seed = 42\n", "sc.pp.neighbors(cellproxy_adata, random_state=seed)\n", "sc.tl.umap(cellproxy_adata, min_dist=0.1, random_state=seed)\n", "sc.tl.leiden(cellproxy_adata, resolution=1.5, random_state=seed)\n", @@ -1169,8 +1170,12 @@ "outputs": [], "source": [ "cmap_denovo = dict(\n", - " zip(cellproxy_adata.obs[\"leiden\"].cat.categories,\n", - " cellproxy_adata.uns[\"leiden_colors\"],))\n", + " zip(\n", + " cellproxy_adata.obs[\"leiden\"].cat.categories,\n", + " cellproxy_adata.uns[\"leiden_colors\"],\n", + " strict=False,\n", + " )\n", + ")\n", "\n", "_ = embryo.plot_celltype_map(cmap=cmap_denovo)" ] @@ -1201,8 +1206,8 @@ } ], "source": [ - "xenium_path_cropped='/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions.zarr'\n", - "sdata.write(xenium_path_cropped,overwrite=True)" + "xenium_path_cropped = \"/media/sergio/Discovair_final/mousebrain_prime_crop_points2regions.zarr\"\n", + "sdata.write(xenium_path_cropped, overwrite=True)" ] }, { @@ -1238,7 +1243,9 @@ } ], "source": [ - "sdata.pl.render_points(\"transcripts\",color=\"points2region\",size=5.0).pl.show(title=f\"Points2region\", coordinate_systems=\"global\", figsize=(20, 20))" + "sdata.pl.render_points(\"transcripts\", color=\"points2region\", size=5.0).pl.show(\n", + " title=\"Points2region\", coordinate_systems=\"global\", figsize=(20, 20)\n", + ")" ] } ], diff --git a/src/troutpy/__init__.py b/src/troutpy/__init__.py index eb7bcef..477f7d0 100644 --- a/src/troutpy/__init__.py +++ b/src/troutpy/__init__.py @@ -1,7 +1,7 @@ from importlib.metadata import version -from . import pl, pp, tl,read +from . import pl, pp, read, tl -__all__ = ["pl", "pp", "tl","read"] +__all__ = ["pl", "pp", "tl", "read"] -#__version__ = version("troutpy") +# __version__ = version("troutpy") diff --git a/src/troutpy/_utils.py b/src/troutpy/_utils.py index 37f44b6..e69de29 100644 --- a/src/troutpy/_utils.py +++ b/src/troutpy/_utils.py @@ -1,10 +0,0 @@ -import os -import scanpy as sc -import pandas as pd -import numpy as np -import matplotlib.pyplot as plt -import seaborn as sns -import os -import spatialdata as sd - - diff --git a/src/troutpy/pl/__init__.py b/src/troutpy/pl/__init__.py index 1839f7a..6634749 100644 --- a/src/troutpy/pl/__init__.py +++ b/src/troutpy/pl/__init__.py @@ -1,4 +1,17 @@ -from .plotting import sorted_heatmap,coupled_scatter,heatmap,plot_crosstab,pie_of_positive,genes_over_noise -from .plotting import moranI_histogram,proportion_above_threshold,nmf_factors_exrna_cells_W -from .plotting import nmf_gene_contributions,apply_exrnaH_to_cellular_to_create_cellularW,paired_nmf_factors -from .plotting import W,spatial_interactions,interactions_with_arrows \ No newline at end of file +from .plotting import ( + W, + apply_exrnaH_to_cellular_to_create_cellularW, + coupled_scatter, + genes_over_noise, + heatmap, + interactions_with_arrows, + moranI_histogram, + nmf_factors_exrna_cells_W, + nmf_gene_contributions, + paired_nmf_factors, + pie_of_positive, + plot_crosstab, + proportion_above_threshold, + sorted_heatmap, + spatial_interactions, +) diff --git a/src/troutpy/pl/plotting.py b/src/troutpy/pl/plotting.py index 09083fd..d5399b5 100644 --- a/src/troutpy/pl/plotting.py +++ b/src/troutpy/pl/plotting.py @@ -1,25 +1,37 @@ +import os +from collections.abc import Sequence +from pathlib import Path + +import matplotlib.pyplot as plt import numpy as np import pandas as pd -import os import scanpy as sc import seaborn as sns -import matplotlib.pyplot as plt -from tqdm import tqdm -from troutpy.pp.compute import compute_crosstab -from typing import Optional, Union, Sequence, Tuple from anndata import AnnData -from matplotlib.colors import Colormap, Normalize -from pathlib import Path +from matplotlib.colors import Colormap -def sorted_heatmap(celltype_by_feature, output_path:str='',filename:str="Heatmap_target_cells_by_gene",format='pdf',cmap='viridis',vmax=None,save=False,figsize=(10, 10)): +from troutpy.pp.compute import compute_crosstab + + +def sorted_heatmap( + celltype_by_feature, + output_path: str = "", + filename: str = "Heatmap_target_cells_by_gene", + format="pdf", + cmap="viridis", + vmax=None, + save=False, + figsize=(10, 10), +): """ Plots the heatmap of target cells by gene. - Parameters: + Parameters + ---------- celltype_by_feature (pd.DataFrame): DataFrame showing the fraction of each feature by cell type. outpath_dummy (str): Path to save the output plots. """ - figures_path = os.path.join(output_path, 'figures') + figures_path = os.path.join(output_path, "figures") os.makedirs(figures_path, exist_ok=True) # Sort by maximum feature in cell types @@ -30,20 +42,35 @@ def sorted_heatmap(celltype_by_feature, output_path:str='',filename:str="Heatmap # Heatmap plot plt.figure(figsize=figsize) sns.heatmap(celltype_by_feature, cmap=cmap, vmax=vmax) - plt.ylabel(f'{celltype_by_feature.index.name}') - plt.xlabel(f'{celltype_by_feature.columns.name}') + plt.ylabel(f"{celltype_by_feature.index.name}") + plt.xlabel(f"{celltype_by_feature.columns.name}") plt.title(filename) if save: - plt.savefig(os.path.join(figures_path, f'{filename}.{format}')) - -def coupled_scatter(sdata, layer='extracellular_transcripts', output_path:str='', transcript_group='distance_to_source_cell', - save=True, format='pdf', xcoord='x', ycoord='y', xcellcoord='x_centroid', ycellcoord='y_centroid', - colormap='Blues', size=2, color_cells='red', figsize=(10, 7), vmax=None): + plt.savefig(os.path.join(figures_path, f"{filename}.{format}")) + + +def coupled_scatter( + sdata, + layer="extracellular_transcripts", + output_path: str = "", + transcript_group="distance_to_source_cell", + save=True, + format="pdf", + xcoord="x", + ycoord="y", + xcellcoord="x_centroid", + ycellcoord="y_centroid", + colormap="Blues", + size=2, + color_cells="red", + figsize=(10, 7), + vmax=None, +): """Plots a scatter plot of transcript locations and cell centroids, coloring the transcripts by a specific feature (e.g., distance to the closest cell) and optionally saving the plot to a file. This function creates a scatter plot where transcripts are plotted according to their spatial coordinates (x, y), and their color represents a feature, such as the distance to the nearest cell. Cell centroids are overlaid on the plot with a specified color. The plot can be saved to a specified file path. - Parameters: + Parameters ---------- sdata : dict-like spatial data object A spatial data object that contains transcript and cell information. The relevant data is accessed from: @@ -78,58 +105,78 @@ def coupled_scatter(sdata, layer='extracellular_transcripts', output_path:str='' The size of the figure in inches (width, height). This controls the dimensions of the plot (default: (10, 7)). vmax : float, optional The upper limit for the colormap. If provided, this limits the color scale to values below `vmax` (default: None). - Returns: + + Returns ------- None The function generates a scatter plot and optionally saves it to the specified output path. - Notes: + Notes ----- - The transcript data and cell centroid data are extracted from `sdata`. - The `vmax` parameter allows control over the maximum value of the color scale for better visualization control. - The plot is saved in the specified format and at the specified output path if `save=True`. """ - # Copy the AnnData object for cell data - adata = sdata['table'].copy() + adata = sdata["table"].copy() # Use raw layer for transcript data - adata.X = sdata['table'].layers['raw'] + adata.X = sdata["table"].layers["raw"] # Extract x, y centroid coordinates from the cell data - adata.obs['x_centroid'] = [sp[0] for sp in adata.obsm['spatial']] - adata.obs['y_centroid'] = [sp[1] for sp in adata.obsm['spatial']] + adata.obs["x_centroid"] = [sp[0] for sp in adata.obsm["spatial"]] + adata.obs["y_centroid"] = [sp[1] for sp in adata.obsm["spatial"]] # Extract transcript data from the specified layer transcripts = sdata.points[layer].compute() # Create output directory if it doesn't exist - figures_path = os.path.join(output_path, 'figures') + figures_path = os.path.join(output_path, "figures") os.makedirs(figures_path, exist_ok=True) # Create the scatter plot plt.figure(figsize=figsize) # Plot transcript locations, colored by the selected feature (transcript_group) - plt.scatter(transcripts[xcoord], transcripts[ycoord], c=transcripts[transcript_group], s=size*0.1, cmap=colormap, vmax=vmax) + plt.scatter( + transcripts[xcoord], + transcripts[ycoord], + c=transcripts[transcript_group], + s=size * 0.1, + cmap=colormap, + vmax=vmax, + ) # Plot cell centroids plt.scatter(adata.obs[xcellcoord], adata.obs[ycellcoord], s=size, color=color_cells) # Set plot title - plt.title(f'{transcript_group}') + plt.title(f"{transcript_group}") # Save the plot if specified if save: plt.savefig(os.path.join(figures_path, f"Scatter_{transcript_group}_{colormap}.{format}")) - -def heatmap(data, output_path: str = '', save: bool = False, figsize=None, tag: str = '', title: str = None, - cmap: str = "RdBu_r", annot: bool = False, cbar: bool = True, vmax=None, vmin=0, - row_cluster: bool = True, col_cluster: bool = True): + + +def heatmap( + data, + output_path: str = "", + save: bool = False, + figsize=None, + tag: str = "", + title: str = None, + cmap: str = "RdBu_r", + annot: bool = False, + cbar: bool = True, + vmax=None, + vmin=0, + row_cluster: bool = True, + col_cluster: bool = True, +): """Generate a clustered heatmap from the given data and optionally save it to a file. - Parameters: - ----------- + Parameters + ---------- data : pandas.DataFrame or numpy.ndarray The data to visualize as a heatmap. Rows and columns will be clustered if specified. output_path : str, optional @@ -157,13 +204,13 @@ def heatmap(data, output_path: str = '', save: bool = False, figsize=None, tag: col_cluster : bool, optional Whether to perform hierarchical clustering on columns. Defaults to True. - Returns: - -------- + Returns + ------- None Displays the heatmap and optionally saves it to a file. - Notes: - ------ + Notes + ----- - If `save` is True, the heatmap will be saved as a PDF file in the `output_path/figures` directory. - Clustering is performed using seaborn's `clustermap` function. @@ -173,23 +220,44 @@ def heatmap(data, output_path: str = '', save: bool = False, figsize=None, tag: """ if figsize is None: figsize = (data.shape[1] / 3, (data.shape[0] / 7) + 2) - g = sns.clustermap(data, cmap=cmap, annot=annot, figsize=figsize, vmax=vmax, vmin=vmin, - col_cluster=col_cluster, row_cluster=row_cluster) - g.fig.suptitle(title) + g = sns.clustermap( + data, + cmap=cmap, + annot=annot, + figsize=figsize, + vmax=vmax, + vmin=vmin, + col_cluster=col_cluster, + row_cluster=row_cluster, + ) + g.fig.suptitle(title) if save: - figures_path = os.path.join(output_path, 'figures') + figures_path = os.path.join(output_path, "figures") os.makedirs(figures_path, exist_ok=True) plt.savefig(os.path.join(figures_path, "heatmap_" + tag + ".pdf")) plt.show() -def plot_crosstab(data, xvar: str = '', yvar: str = '', normalize=True, axis=1, kind='barh', - save=True, figures_path: str = '', stacked=True, figsize=(6, 10), - cmap='viridis', saving_format='pdf', sortby=None): + +def plot_crosstab( + data, + xvar: str = "", + yvar: str = "", + normalize=True, + axis=1, + kind="barh", + save=True, + figures_path: str = "", + stacked=True, + figsize=(6, 10), + cmap="viridis", + saving_format="pdf", + sortby=None, +): """ Plot a cross-tabulation between two variables in a dataset and visualize it as either a bar plot, horizontal bar plot, or heatmap. - Parameters: - ----------- + Parameters + ---------- data : pd.DataFrame Input dataset containing the variables for the cross-tabulation. @@ -233,63 +301,63 @@ def plot_crosstab(data, xvar: str = '', yvar: str = '', normalize=True, axis=1, sortby : str, optional (default: None) The column or row to sort the cross-tabulated data by before plotting. - Returns: - -------- + Returns + ------- None This function generates a plot and optionally saves it to a file. """ - # Compute the crosstab data crosstab_data = compute_crosstab(data, xvar=xvar, yvar=yvar) - + # Normalize the data if required if normalize: crosstab_data = crosstab_data.div(crosstab_data.sum(axis=axis), axis=0) - normtag = 'normalize' + normtag = "normalize" else: - normtag = 'raw' - + normtag = "raw" + # Sort the data if needed if sortby is not None: crosstab_data = crosstab_data.sort_values(by=sortby) - + # Generate the plot filename plot_filename = f"{kind}_{xvar}_{yvar}_{normtag}_{cmap}.{saving_format}" - + # Plot based on the selected kind - if kind == 'barh': + if kind == "barh": plt.figure() - crosstab_data.plot(kind='barh', stacked=stacked, figsize=figsize, width=0.99, colormap=cmap) - plt.title(f'{xvar}_vs_{yvar}') + crosstab_data.plot(kind="barh", stacked=stacked, figsize=figsize, width=0.99, colormap=cmap) + plt.title(f"{xvar}_vs_{yvar}") if save: plt.savefig(os.path.join(figures_path, plot_filename)) plt.show() - elif kind == 'bar': + elif kind == "bar": plt.figure() - crosstab_data.plot(kind='bar', stacked=stacked, figsize=figsize, width=0.99, colormap=cmap) - plt.title(f'{xvar}_vs_{yvar}') + crosstab_data.plot(kind="bar", stacked=stacked, figsize=figsize, width=0.99, colormap=cmap) + plt.title(f"{xvar}_vs_{yvar}") if save: plt.savefig(os.path.join(figures_path, plot_filename)) plt.show() - elif kind == 'heatmap': + elif kind == "heatmap": plt.figure() sns.heatmap(crosstab_data, figsize=figsize, cmap=cmap) - plt.title(f'{xvar}_vs_{yvar}') + plt.title(f"{xvar}_vs_{yvar}") if save: plt.savefig(os.path.join(figures_path, plot_filename)) plt.show() - elif kind == 'clustermap': + elif kind == "clustermap": plt.figure() sns.clustermap(crosstab_data, figsize=figsize, cmap=cmap) - plt.title(f'{xvar}_vs_{yvar}') + plt.title(f"{xvar}_vs_{yvar}") if save: plt.savefig(os.path.join(figures_path, plot_filename)) plt.show() -def pie_of_positive(data, groupby: str = '', figures_path: str = '', save: bool = True): + +def pie_of_positive(data, groupby: str = "", figures_path: str = "", save: bool = True): """ Generates a pie chart showing the proportion of positive and negative values for a specified categorical variable in the data. @@ -309,34 +377,37 @@ def pie_of_positive(data, groupby: str = '', figures_path: str = '', save: bool None The function generates and either saves or displays a pie chart, depending on the value of the `save` parameter. """ - plt.figure() - y = np.array([np.sum(~data[groupby]), np.sum(data[groupby] )]) + y = np.array([np.sum(~data[groupby]), np.sum(data[groupby])]) mylabels = [f"{groupby}=False", f"{groupby}=True"] - - plt.pie(y, labels=mylabels, colors=['#a0b7e0', '#c5e493']) - plt.title(f'Proportion of {groupby}') - + + plt.pie(y, labels=mylabels, colors=["#a0b7e0", "#c5e493"]) + plt.title(f"Proportion of {groupby}") + if save: plot_filename = f"pie_positivity_{groupby}_.pdf" plt.savefig(os.path.join(figures_path, plot_filename)) -def genes_over_noise(sdata, scores_by_genes,layer='extracellular_transcripts', output_path:str='',save=True,format:str='pdf'): + +def genes_over_noise( + sdata, scores_by_genes, layer="extracellular_transcripts", output_path: str = "", save=True, format: str = "pdf" +): """Function that plots log fold change per gene over noise using a boxplot. - - Parameters: + + Parameters + ---------- - data_quantified: DataFrame containing the extracellular transcript data, including feature names and codeword categories. - scores_by_genes: DataFrame containing gene scores with feature names and log fold ratios. - output_path: Path to save the figure. """ - data_quantified=sdata.points[layer].compute() + data_quantified = sdata.points[layer].compute() # Create the output directory for figures if it doesn't exist PATH_FIGURES = os.path.join(output_path, "figures") os.makedirs(PATH_FIGURES, exist_ok=True) # Map feature names to codeword categories - feature2codeword = dict(zip(data_quantified['feature_name'], data_quantified['codeword_category'])) - scores_by_genes['codeword_category'] = scores_by_genes['feature_name'].map(feature2codeword) + feature2codeword = dict(zip(data_quantified["feature_name"], data_quantified["codeword_category"], strict=False)) + scores_by_genes["codeword_category"] = scores_by_genes["feature_name"].map(feature2codeword) # Plot the boxplot sns.boxplot( @@ -348,16 +419,19 @@ def genes_over_noise(sdata, scores_by_genes,layer='extracellular_transcripts', o # Plot the reference line at x = 0 plt.plot([0, 0], [*plt.gca().get_ylim()], "r--") if save: - # Save the figure - plt.savefig(os.path.join(PATH_FIGURES, f"boxplot_log_fold_change_per_gene{format}"), bbox_inches="tight", pad_inches=0) + # Save the figure + plt.savefig( + os.path.join(PATH_FIGURES, f"boxplot_log_fold_change_per_gene{format}"), bbox_inches="tight", pad_inches=0 + ) # Show the plot plt.show() -def moranI_histogram(svg_df, save=True, figures_path: str = '', bins: int = 200, format: str = 'pdf'): + +def moranI_histogram(svg_df, save=True, figures_path: str = "", bins: int = 200, format: str = "pdf"): """Plots the distribution of Moran's I scores from a DataFrame. - Parameters: - ----------- + Parameters + ---------- svg_df : pandas.DataFrame DataFrame containing a column 'I' with Moran's I scores. save : bool, optional, default=True @@ -369,51 +443,54 @@ def moranI_histogram(svg_df, save=True, figures_path: str = '', bins: int = 200, format : str, optional, default='pdf' Format in which to save the figure (e.g., 'pdf', 'png'). - Returns: - -------- + Returns + ------- None """ # Check if figures_path exists if saving the figure if save and figures_path: if not os.path.exists(figures_path): raise ValueError(f"The provided path '{figures_path}' does not exist.") - + # Plot the distribution plt.figure(figsize=(8, 6)) - plt.hist(svg_df.sort_values(by='I', ascending=False)['I'], bins=bins) + plt.hist(svg_df.sort_values(by="I", ascending=False)["I"], bins=bins) plt.xlabel("Moran's I") plt.ylabel("Frequency") plt.title("Distribution of Moran's I Scores") - + # Save the plot if requested if save: - file_name = os.path.join(figures_path, f'barplot_moranI_by_gene.{format}') + file_name = os.path.join(figures_path, f"barplot_moranI_by_gene.{format}") plt.savefig(file_name, format=format) print(f"Plot saved to: {file_name}") - + plt.show() + def proportion_above_threshold( - df, - threshold_col='proportion_above_threshold', - feature_col='feature_name', - top_percentile=0.05, - bottom_percentile=0.05, - specific_transcripts=None, - figsize=(4, 10), - orientation='h', - bar_color="black", - title='Proportion of distant exRNa (>30um) from source', - xlabel='Proportion above threshold', - ylabel='Feature', + df, + threshold_col="proportion_above_threshold", + feature_col="feature_name", + top_percentile=0.05, + bottom_percentile=0.05, + specific_transcripts=None, + figsize=(4, 10), + orientation="h", + bar_color="black", + title="Proportion of distant exRNa (>30um) from source", + xlabel="Proportion above threshold", + ylabel="Feature", save=False, - output_path:str='',format='pdf' + output_path: str = "", + format="pdf", ): """ Plots the top and bottom percentiles of features with the highest and lowest proportions above a threshold, or visualizes a specific list of transcripts. - Parameters: + Parameters + ---------- - df: DataFrame containing feature proportions. - threshold_col: Column name for proportions above the threshold (default: 'proportion_above_threshold'). - feature_col: Column name for feature names (default: 'feature_name'). @@ -427,46 +504,55 @@ def proportion_above_threshold( - xlabel: Label for the x-axis (default: 'Proportion above threshold'). - ylabel: Label for the y-axis (default: 'Feature'). """ - df=df[~df[threshold_col].isna()] + df = df[~df[threshold_col].isna()] print(df.shape) # Filter for top and bottom percentiles if no specific transcripts are provided if specific_transcripts is None: top_cutoff = df[threshold_col].quantile(1 - top_percentile) bottom_cutoff = df[threshold_col].quantile(bottom_percentile) - plot_data = pd.concat([ - df[df[threshold_col] >= top_cutoff], # Top percentile - df[df[threshold_col] <= bottom_cutoff] # Bottom percentile - ]) + plot_data = pd.concat( + [ + df[df[threshold_col] >= top_cutoff], # Top percentile + df[df[threshold_col] <= bottom_cutoff], # Bottom percentile + ] + ) else: plot_data = df[df[feature_col].isin(specific_transcripts)] # Plot plt.figure(figsize=figsize) - if orientation=='h': - plt.barh(plot_data['feature_name'],plot_data[threshold_col],color=bar_color) - if orientation=='v': - plt.bar(plot_data['feature_name'],plot_data[threshold_col],color=bar_color) - + if orientation == "h": + plt.barh(plot_data["feature_name"], plot_data[threshold_col], color=bar_color) + if orientation == "v": + plt.bar(plot_data["feature_name"], plot_data[threshold_col], color=bar_color) + plt.title(title) plt.xlabel(xlabel) plt.ylabel(ylabel) - figures_path = os.path.join(output_path, 'figures') + figures_path = os.path.join(output_path, "figures") os.makedirs(figures_path, exist_ok=True) - filename=f'barplot_distant_from_source_min{bottom_percentile}_max{top_percentile}_{bar_color}' + filename = f"barplot_distant_from_source_min{bottom_percentile}_max{top_percentile}_{bar_color}" if save: - plt.savefig(os.path.join(figures_path, f'{filename}.{format}')) + plt.savefig(os.path.join(figures_path, f"{filename}.{format}")) plt.show() -def nmf_factors_exrna_cells_W(sdata, nmf_adata_key: str = 'nmf_data', save: bool = True, saving_path: str = '', - spot_size: int = 30, cmap: str = 'viridis'): + +def nmf_factors_exrna_cells_W( + sdata, + nmf_adata_key: str = "nmf_data", + save: bool = True, + saving_path: str = "", + spot_size: int = 30, + cmap: str = "viridis", +): """ Plot NMF factors for each cell in a spatial transcriptomics dataset. - This function extracts the NMF (Non-negative Matrix Factorization) factors from the specified AnnData object + This function extracts the NMF (Non-negative Matrix Factorization) factors from the specified AnnData object within the spatial data (`sdata`) and creates spatial plots for each factor. The plots can be displayed or saved to disk. - Parameters: - ----------- + Parameters + ---------- sdata : AnnData or SpatialData object A spatial transcriptomics dataset that contains the NMF factors in the specified key. nmf_adata_key : str, optional @@ -481,45 +567,67 @@ def nmf_factors_exrna_cells_W(sdata, nmf_adata_key: str = 'nmf_data', save: bool cmap : str, optional Colormap to use for the spatial plots. Defaults to 'viridis'. - Returns: - -------- + Returns + ------- None Displays the spatial plots for each NMF factor. If `save` is True, the plots are saved as PNG files. - Notes: - ------ + Notes + ----- - The NMF factors are expected to be stored in `adata.obsm['W_nmf']`, where `adata` is extracted from `sdata`. - A maximum of 20 factors is plotted by iterating through the columns of `W_nmf`. - When saving, each plot is named `spatialnmf{factor}.png` and stored in a `figures` directory inside `saving_path`. Example: -------- - >>> nmf_factors_exrna_cells_W(sdata, nmf_adata_key='nmf_data', save=True, saving_path='./results', spot_size=50, cmap='plasma') + >>> nmf_factors_exrna_cells_W( + ... sdata, nmf_adata_key="nmf_data", save=True, saving_path="./results", spot_size=50, cmap="plasma" + ... ) """ # Plot the factors for each cell in a spatial plot adata = sdata[nmf_adata_key] - W = adata.obsm['W_nmf'] + W = adata.obsm["W_nmf"] for factor in range(20): # Add the factor values to adata.obs for plotting - adata.obs[f'NMF_factor_{factor + 1}'] = W[:, factor] + adata.obs[f"NMF_factor_{factor + 1}"] = W[:, factor] # Plot spatial map of cells colored by this factor if save: - sc.pl.spatial(adata, color=f'NMF_factor_{factor + 1}', cmap=cmap, title=f'NMF Factor {factor + 1}', - spot_size=30, show=False) - plt.savefig(saving_path + '/figures/' + f'spatialnmf{factor}.png') + sc.pl.spatial( + adata, + color=f"NMF_factor_{factor + 1}", + cmap=cmap, + title=f"NMF Factor {factor + 1}", + spot_size=30, + show=False, + ) + plt.savefig(saving_path + "/figures/" + f"spatialnmf{factor}.png") plt.show() else: - sc.pl.spatial(adata, color=f'NMF_factor_{factor + 1}', cmap=cmap, title=f'NMF Factor {factor + 1}', - spot_size=spot_size) - -def nmf_gene_contributions(sdata, nmf_adata_key: str = 'nmf_data', save: bool = True, vmin: float = 0.0, vmax: float = 0.02, - saving_path: str = '', cmap: str = 'viridis', figsize: tuple = (5, 5)): + sc.pl.spatial( + adata, + color=f"NMF_factor_{factor + 1}", + cmap=cmap, + title=f"NMF Factor {factor + 1}", + spot_size=spot_size, + ) + + +def nmf_gene_contributions( + sdata, + nmf_adata_key: str = "nmf_data", + save: bool = True, + vmin: float = 0.0, + vmax: float = 0.02, + saving_path: str = "", + cmap: str = "viridis", + figsize: tuple = (5, 5), +): """Plot a heatmap of gene contributions to NMF factors. This function extracts the NMF (Non-negative Matrix Factorization) gene loadings matrix from the specified AnnData object within the spatial data (`sdata`), filters genes based on their maximum loading value, and plots a heatmap of the filtered loadings. - Parameters: - ----------- + Parameters + ---------- sdata : AnnData or SpatialData object A spatial transcriptomics dataset that contains the NMF factors in the specified key. nmf_adata_key : str, optional @@ -538,25 +646,27 @@ def nmf_gene_contributions(sdata, nmf_adata_key: str = 'nmf_data', save: bool = figsize : tuple, optional Size of the heatmap figure. Defaults to (5, 5). - Returns: - -------- + Returns + ------- None Displays a heatmap of gene contributions to NMF factors. If `save` is True, the heatmap is saved as a PDF file. - Notes: - ------ + Notes + ----- - The gene loadings matrix is expected to be stored in `adata.uns['H_nmf']`, where `adata` is extracted from `sdata`. - Genes with a maximum loading value greater than 0.05 are included in the heatmap. - The rows of the heatmap are sorted based on the factor with the highest contribution for each gene. Example: -------- - >>> nmf_gene_contributions(sdata, nmf_adata_key='nmf_data', save=True, saving_path='./results', cmap='plasma', figsize=(10, 8)) + >>> nmf_gene_contributions( + ... sdata, nmf_adata_key="nmf_data", save=True, saving_path="./results", cmap="plasma", figsize=(10, 8) + ... ) """ adata = sdata[nmf_adata_key] - loadings = pd.DataFrame(adata.uns['H_nmf'], columns=adata.var.index) + loadings = pd.DataFrame(adata.uns["H_nmf"], columns=adata.var.index) loadings_filtered = loadings.loc[:, np.max(loadings, axis=0) > 0.05].transpose() - figures_path = os.path.join(saving_path, 'figures') + figures_path = os.path.join(saving_path, "figures") os.makedirs(figures_path, exist_ok=True) # Sort by maximum feature in cell types @@ -572,38 +682,41 @@ def nmf_gene_contributions(sdata, nmf_adata_key: str = 'nmf_data', save: bool = plt.show() plt.close() # Close the figure to avoid memory issues + def apply_exrnaH_to_cellular_to_create_cellularW(adata_extracellular_with_nmf, adata_annotated_cellular): """Apply extracellular RNA NMF loadings (H) to cellular data to generate cellular NMF factors (W). This function transfers the gene loadings (H matrix) derived from extracellular RNA analysis to a cellular dataset. It calculates the new W matrix for cellular data by multiplying the gene expression values of the cellular dataset with the filtered H matrix. - Parameters: - ----------- + Parameters + ---------- adata_extracellular_with_nmf : AnnData - An AnnData object containing the extracellular RNA data with the NMF results. + An AnnData object containing the extracellular RNA data with the NMF results. The H matrix is expected to be stored in `adata.uns['H_nmf']`. adata_annotated_cellular : AnnData An AnnData object containing the cellular RNA data with annotated gene expression values. - Returns: - -------- + Returns + ------- AnnData The input `adata_annotated_cellular` object with the following updates: - Adds the calculated NMF factors (W matrix) as a DataFrame to `adata.obsm['factors']`. - Adds each NMF factor as individual columns in `adata.obs` with names `NMF_factor_1`, `NMF_factor_2`, etc. - Notes: - ------ + Notes + ----- - Only the genes common between the extracellular RNA data and the cellular data are used for the computation. - The gene intersection ensures compatibility between the NMF H matrix and the cellular gene expression matrix. Example: -------- - >>> adata_cellular = apply_exrnaH_to_cellular_to_create_cellularW(adata_extracellular_with_nmf, adata_annotated_cellular) + >>> adata_cellular = apply_exrnaH_to_cellular_to_create_cellularW( + ... adata_extracellular_with_nmf, adata_annotated_cellular + ... ) """ # Extract the H matrix (NMF gene loadings) from the extracellular data - H = adata_extracellular_with_nmf.uns['H_nmf'] + H = adata_extracellular_with_nmf.uns["H_nmf"] # Check the genes in both datasets genes_spots2region = adata_extracellular_with_nmf.var_names @@ -620,35 +733,34 @@ def apply_exrnaH_to_cellular_to_create_cellularW(adata_extracellular_with_nmf, a W_annotated = adata_annotated_cellular.X @ H_filtered.T # Store the W matrix in the obsm attribute as a DataFrame - adata_annotated_cellular.obsm['factors'] = pd.DataFrame( - W_annotated, index=adata_annotated_cellular.obs.index - ) + adata_annotated_cellular.obsm["factors"] = pd.DataFrame(W_annotated, index=adata_annotated_cellular.obs.index) # Add individual NMF factors to adata.obs for factor in range(W_annotated.shape[1]): - adata_annotated_cellular.obs[f'NMF_factor_{factor + 1}'] = W_annotated[:, factor] + adata_annotated_cellular.obs[f"NMF_factor_{factor + 1}"] = W_annotated[:, factor] return adata_annotated_cellular + def paired_nmf_factors( - sdata, - layer='nmf_data', + sdata, + layer="nmf_data", n_factors=5, # Number of NMF factors to plot figsize=(12, 6), # Size of the figure spot_size_exrna=5, # Spot size for extracellular transcripts spot_size_cells=10, # Spot size for cell map - cmap_exrna='YlGnBu', # Colormap for extracellular transcripts - cmap_cells='Reds', # Colormap for cells - vmax_exrna='p99', # Maximum value for color scale (extracellular) + cmap_exrna="YlGnBu", # Colormap for extracellular transcripts + cmap_cells="Reds", # Colormap for cells + vmax_exrna="p99", # Maximum value for color scale (extracellular) vmax_cells=None, # Maximum value for color scale (cells) save=False, - output_path:str='', - format='pdf' + output_path: str = "", + format="pdf", ): """ Plots the spatial distribution of NMF factors for extracellular transcripts and cells. - Parameters: + Parameters ---------- sdata : spatial data object The spatial data object containing both extracellular and cell data. @@ -680,57 +792,67 @@ def paired_nmf_factors( vmax_cells : str or float, optional Maximum value for cell color scale (default: None). """ - # Extract NMF data from sdata adata = sdata[layer] - adata_annotated = sdata['table'] - + adata_annotated = sdata["table"] + # Get the factors from the obsm attribute (NMF results) - factors = pd.DataFrame(adata.obsm['W_nmf'], index=adata.obs.index) - factors.columns = [f'NMF_factor_{fact+1}' for fact in factors.columns] - + factors = pd.DataFrame(adata.obsm["W_nmf"], index=adata.obs.index) + factors.columns = [f"NMF_factor_{fact+1}" for fact in factors.columns] + # Add each NMF factor to adata.obs for f in factors.columns: adata.obs[f] = factors[f] - + # Loop over the specified number of NMF factors and plot for factor in range(n_factors): - factor_name = f'NMF_factor_{factor + 1}' - + factor_name = f"NMF_factor_{factor + 1}" + # Create a figure with a single subplot for each factor fig, axs = plt.subplots(1, 1, figsize=figsize) - + # Plot the spatial distribution for extracellular transcripts sc.pl.spatial( - adata, color=factor_name, cmap=cmap_exrna, - title=f'NMF Factor {factor + 1} (Extracellular)', - ax=axs, show=False, spot_size=spot_size_exrna, vmax=vmax_exrna + adata, + color=factor_name, + cmap=cmap_exrna, + title=f"NMF Factor {factor + 1} (Extracellular)", + ax=axs, + show=False, + spot_size=spot_size_exrna, + vmax=vmax_exrna, ) - + # Overlay the cell spatial distribution sc.pl.spatial( - adata_annotated, color=factor_name, cmap=cmap_cells, - title=f'NMF Factor cell-red/exRNa-blue {factor + 1}', - ax=axs, show=False, spot_size=spot_size_cells, vmax=vmax_cells + adata_annotated, + color=factor_name, + cmap=cmap_cells, + title=f"NMF Factor cell-red/exRNa-blue {factor + 1}", + ax=axs, + show=False, + spot_size=spot_size_cells, + vmax=vmax_cells, ) if save: - figures_path = os.path.join(output_path, 'figures') + figures_path = os.path.join(output_path, "figures") os.makedirs(figures_path, exist_ok=True) - file_name = os.path.join(figures_path, f'Spatial_NMF Factor {factor + 1}.{format}') + file_name = os.path.join(figures_path, f"Spatial_NMF Factor {factor + 1}.{format}") plt.savefig(file_name) # Adjust layout and show the combined plot plt.tight_layout() plt.show() + def plot_nmf_factors_spatial(adata, n_factors, save=True): """ Plot spatial maps of cells colored by NMF factors. This function visualizes the spatial distribution of cells, colored by their corresponding NMF factor values, stored in `adata.obs`. It iterates over all specified NMF factors and generates spatial plots for each factor. - Parameters: - ----------- + Parameters + ---------- adata : AnnData An AnnData object containing the dataset with NMF factors already added as columns in `adata.obs`.Each factor should be named `NMF_factor_1`, `NMF_factor_2`, ..., `NMF_factor_n`. n_factors : int @@ -738,13 +860,13 @@ def plot_nmf_factors_spatial(adata, n_factors, save=True): save : bool, optional (default=True) If `True`, saves the plots to files with filenames `exo_to_cell_spatial_.png`. - Returns: - -------- + Returns + ------- None This function does not return anything but generates and optionally saves spatial plots. - Notes: - ------ + Notes + ----- - The plots are colored using the 'plasma' colormap. - The spot size for the spatial plots is set to 15 by default. - Files are saved in the current working directory unless specified otherwise using `sc.settings.figdir`. @@ -756,41 +878,42 @@ def plot_nmf_factors_spatial(adata, n_factors, save=True): for factor in range(n_factors): sc.pl.spatial( adata, - color=f'NMF_factor_{factor + 1}', - cmap='plasma', - title=f'NMF Factor {factor + 1}', + color=f"NMF_factor_{factor + 1}", + cmap="plasma", + title=f"NMF Factor {factor + 1}", spot_size=15, - save=f'exo_to_cell_spatial_{factor}.png' if save else None + save=f"exo_to_cell_spatial_{factor}.png" if save else None, ) + def spatial_interactions( sdata: AnnData, - layer: str = 'extracellular_transcripts_enriched', - gene: str = 'Arc', - gene_key: str = 'feature_name', - cell_id_key: str = 'cell_id', - color_target:str='blue', - color_source:str='red', - color_transcript:str='green', - spatial_key: str = 'spatial', - img: Optional[Union[bool, Sequence]] = None, - img_alpha: Optional[float] = None, - image_cmap: Optional[Colormap] = None, - size: Optional[Union[float, Sequence[float]]] = 8, + layer: str = "extracellular_transcripts_enriched", + gene: str = "Arc", + gene_key: str = "feature_name", + cell_id_key: str = "cell_id", + color_target: str = "blue", + color_source: str = "red", + color_transcript: str = "green", + spatial_key: str = "spatial", + img: bool | Sequence | None = None, + img_alpha: float | None = None, + image_cmap: Colormap | None = None, + size: float | Sequence[float] | None = 8, alpha: float = 0.6, - title: Optional[Union[str, Sequence[str]]] = None, - legend_loc: Optional[str] = 'best', - figsize: Tuple[float, float] = (10, 10), - dpi: Optional[int] = 100, - save: Optional[Union[str, Path]] = None, - **kwargs + title: str | Sequence[str] | None = None, + legend_loc: str | None = "best", + figsize: tuple[float, float] = (10, 10), + dpi: int | None = 100, + save: str | Path | None = None, + **kwargs, ): """ Visualizes the spatial interactions of extracellular RNA and associated cells. This function generates a scatter plot showing the positions of target cells, source cells, and extracellular RNA transcripts within a spatial omics dataset. The target and source cells are highlighted in different colors, while the RNA transcripts are shown as points at their respective positions. Optionally, a background image (e.g., tissue section) can be displayed. - Parameters: + Parameters ---------- sdata : AnnData An AnnData object containing the spatial omics data, including transcript expression and cell positions. @@ -852,7 +975,7 @@ def spatial_interactions( **kwargs : Additional keyword arguments Any additional arguments passed to the `scatter` or `imshow` functions for customizing plot appearance. - Returns: + Returns ------- None The function generates and displays (or saves) a scatter plot. @@ -860,58 +983,76 @@ def spatial_interactions( # Extract relevant data transcripts = sdata.points[layer] trans_filt = transcripts[transcripts[gene_key] == gene] - target_cells = trans_filt['closest_target_cell'].compute() - source_cells = trans_filt['closest_source_cell'].compute() - cell_positions = pd.DataFrame(sdata['table'].obsm[spatial_key], index=sdata.table.obs[cell_id_key], columns=['x', 'y']) + target_cells = trans_filt["closest_target_cell"].compute() + source_cells = trans_filt["closest_source_cell"].compute() + cell_positions = pd.DataFrame( + sdata["table"].obsm[spatial_key], index=sdata.table.obs[cell_id_key], columns=["x", "y"] + ) # Plotting plt.figure(figsize=figsize, dpi=dpi) if img is not None: plt.imshow(img, alpha=img_alpha, cmap=image_cmap, **kwargs) - plt.scatter(cell_positions['x'], cell_positions['y'], c='grey', s=0.6, alpha=alpha, **kwargs) - plt.scatter(cell_positions.loc[target_cells, 'x'], cell_positions.loc[target_cells, 'y'], c=color_target, s=size, label='Target Cells', **kwargs) - plt.scatter(cell_positions.loc[source_cells, 'x'], cell_positions.loc[source_cells, 'y'], c=color_source, s=size, label='Source Cells', **kwargs) - plt.scatter(trans_filt['x'], trans_filt['y'], c=color_transcript, s=size*0.4, label='Transcripts', **kwargs) - + plt.scatter(cell_positions["x"], cell_positions["y"], c="grey", s=0.6, alpha=alpha, **kwargs) + plt.scatter( + cell_positions.loc[target_cells, "x"], + cell_positions.loc[target_cells, "y"], + c=color_target, + s=size, + label="Target Cells", + **kwargs, + ) + plt.scatter( + cell_positions.loc[source_cells, "x"], + cell_positions.loc[source_cells, "y"], + c=color_source, + s=size, + label="Source Cells", + **kwargs, + ) + plt.scatter(trans_filt["x"], trans_filt["y"], c=color_transcript, s=size * 0.4, label="Transcripts", **kwargs) + # Titles and Legends plt.title(title or gene) plt.legend(loc=legend_loc) - plt.xlabel('X Position') - plt.ylabel('Y Position') + plt.xlabel("X Position") + plt.ylabel("Y Position") # Save the plot if path provided if save: plt.savefig(save) plt.show() + def interactions_with_arrows( sdata: AnnData, - layer: str = 'extracellular_transcripts_enriched', - gene: str = 'Arc', - gene_key: str = 'feature_name', - cell_id_key: str = 'cell_id', - color_target: str = 'blue', - color_source: str = 'red', - color_transcript: str = 'green', - spatial_key: str = 'spatial', - img: Optional[Union[bool, Sequence]] = None, - img_alpha: Optional[float] = None, - image_cmap: Optional[Colormap] = None, - size: Optional[Union[float, Sequence[float]]] = 8, + layer: str = "extracellular_transcripts_enriched", + gene: str = "Arc", + gene_key: str = "feature_name", + cell_id_key: str = "cell_id", + color_target: str = "blue", + color_source: str = "red", + color_transcript: str = "green", + spatial_key: str = "spatial", + img: bool | Sequence | None = None, + img_alpha: float | None = None, + image_cmap: Colormap | None = None, + size: float | Sequence[float] | None = 8, alpha: float = 0.6, - title: Optional[Union[str, Sequence[str]]] = None, - legend_loc: Optional[str] = 'best', - figsize: Tuple[float, float] = (10, 10), - dpi: Optional[int] = 100, - save: Optional[Union[str, Path]] = None, - **kwargs + title: str | Sequence[str] | None = None, + legend_loc: str | None = "best", + figsize: tuple[float, float] = (10, 10), + dpi: int | None = 100, + save: str | Path | None = None, + **kwargs, ): """ Visualizes interactions between source and target cells using arrows, along with transcript locations. - - The function plots arrows from source to target cells based on transcript proximity, color-coding source and target cells, and transcript locations. An optional image layer can be overlaid behind the plot. - Parameters: + The function plots arrows from source to target cells based on transcript proximity, color-coding source and target cells, and transcript locations. An optional image layer can be overlaid behind the plot. + + Parameters + ---------- sdata (AnnData): The AnnData object containing the spatial omics data. layer (str, optional): The key in `sdata` for the extracellular transcript layer to analyze. Default is 'extracellular_transcripts_enriched'. gene (str, optional): The gene of interest. Default is 'Arc'. @@ -933,19 +1074,22 @@ def interactions_with_arrows( save (Optional[Union[str, Path]], optional): If provided, the path where the plot will be saved. **kwargs: Additional arguments passed to the `scatter` and `imshow` functions for customization. - Returns: + Returns + ------- None: The function displays or saves a plot of interactions between cells and transcripts. - Notes: + Notes + ----- The plot will show arrows from source to target cells, with different colors for source, target, and transcript points. """ - # Extract relevant data transcripts = sdata.points[layer] trans_filt = transcripts[transcripts[gene_key] == gene] - target_cells = trans_filt['closest_target_cell'].compute() - source_cells = trans_filt['closest_source_cell'].compute() - cell_positions = pd.DataFrame(sdata['table'].obsm[spatial_key], index=sdata.table.obs[cell_id_key], columns=['x', 'y']) + target_cells = trans_filt["closest_target_cell"].compute() + source_cells = trans_filt["closest_source_cell"].compute() + cell_positions = pd.DataFrame( + sdata["table"].obsm[spatial_key], index=sdata.table.obs[cell_id_key], columns=["x", "y"] + ) # Plotting plt.figure(figsize=figsize, dpi=dpi) @@ -953,26 +1097,49 @@ def interactions_with_arrows( plt.imshow(img, alpha=img_alpha, cmap=image_cmap, **kwargs) # Plot arrows between each paired source and target cell - for source, target in zip(source_cells, target_cells): + for source, target in zip(source_cells, target_cells, strict=False): if source in cell_positions.index and target in cell_positions.index: if source != target: - x_start, y_start = cell_positions.loc[source, 'x'], cell_positions.loc[source, 'y'] - x_end, y_end = cell_positions.loc[target, 'x'], cell_positions.loc[target, 'y'] - plt.arrow(x_start, y_start, x_end - x_start, y_end - y_start, color='black', alpha=0.8, head_width=8, head_length=8) - + x_start, y_start = cell_positions.loc[source, "x"], cell_positions.loc[source, "y"] + x_end, y_end = cell_positions.loc[target, "x"], cell_positions.loc[target, "y"] + plt.arrow( + x_start, + y_start, + x_end - x_start, + y_end - y_start, + color="black", + alpha=0.8, + head_width=8, + head_length=8, + ) + # Plot source and target cells - plt.scatter(cell_positions['x'], cell_positions['y'], c='grey', s=0.6, alpha=alpha, **kwargs) - plt.scatter(cell_positions.loc[target_cells, 'x'], cell_positions.loc[target_cells, 'y'], c=color_target, s=size, label='Target Cells', **kwargs) - plt.scatter(cell_positions.loc[source_cells, 'x'], cell_positions.loc[source_cells, 'y'], c=color_source, s=size, label='Source Cells', **kwargs) - plt.scatter(trans_filt['x'], trans_filt['y'], c=color_transcript, s=size * 0.4, label='Transcripts', **kwargs) - + plt.scatter(cell_positions["x"], cell_positions["y"], c="grey", s=0.6, alpha=alpha, **kwargs) + plt.scatter( + cell_positions.loc[target_cells, "x"], + cell_positions.loc[target_cells, "y"], + c=color_target, + s=size, + label="Target Cells", + **kwargs, + ) + plt.scatter( + cell_positions.loc[source_cells, "x"], + cell_positions.loc[source_cells, "y"], + c=color_source, + s=size, + label="Source Cells", + **kwargs, + ) + plt.scatter(trans_filt["x"], trans_filt["y"], c=color_transcript, s=size * 0.4, label="Transcripts", **kwargs) + # Titles and Legends plt.title(title or gene) plt.legend(loc=legend_loc) - plt.xlabel('X Position') - plt.ylabel('Y Position') + plt.xlabel("X Position") + plt.ylabel("Y Position") # Save the plot if path provided if save: plt.savefig(save) - plt.show() \ No newline at end of file + plt.show() diff --git a/src/troutpy/pp/__init__.py b/src/troutpy/pp/__init__.py index 557b432..3c4e10e 100644 --- a/src/troutpy/pp/__init__.py +++ b/src/troutpy/pp/__init__.py @@ -1,2 +1,2 @@ -from .compute import compute_extracellular_counts,define_extracellular,compute_crosstab -from .format import format_adata \ No newline at end of file +from .compute import compute_crosstab, compute_extracellular_counts, define_extracellular +from .format import format_adata diff --git a/src/troutpy/pp/compute.py b/src/troutpy/pp/compute.py index 4687224..8c6ba28 100644 --- a/src/troutpy/pp/compute.py +++ b/src/troutpy/pp/compute.py @@ -1,49 +1,49 @@ -import os -import pandas as pd import numpy as np -import seaborn as sns -import matplotlib.pyplot as plt -import scanpy as sc +import pandas as pd import spatialdata as sd -from typing import List, Union, Tuple -def compute_extracellular_counts(data_extracell): # would be good to change the name of this function + +def compute_extracellular_counts(data_extracell): # would be good to change the name of this function """ Compute observed, expected, and fold ratio for extracellular transcript counts. - Parameters: + Parameters + ---------- data_extracell (pd.DataFrame): Data with extracellular transcripts. - Returns: + Returns + ------- pd.DataFrame: Dataframe with observed, expected counts, fold ratios, and gene categories. """ - extracellular_counts = data_extracell.groupby('feature_name').count() - extracellular_counts = pd.DataFrame({'observed': extracellular_counts.iloc[:, 0]}) - extracellular_counts['expected'] = int(extracellular_counts['observed'].sum() / extracellular_counts.shape[0]) - + extracellular_counts = data_extracell.groupby("feature_name").count() + extracellular_counts = pd.DataFrame({"observed": extracellular_counts.iloc[:, 0]}) + extracellular_counts["expected"] = int(extracellular_counts["observed"].sum() / extracellular_counts.shape[0]) + # Calculate fold ratios - extracellular_counts['fold_ratio'] = np.log(extracellular_counts['observed'] / extracellular_counts['expected']) - + extracellular_counts["fold_ratio"] = np.log(extracellular_counts["observed"] / extracellular_counts["expected"]) + # Map gene categories - gene2cat = dict(zip(data_extracell['feature_name'], data_extracell['codeword_category'])) - extracellular_counts['codeword_category'] = extracellular_counts.index.map(gene2cat) - + gene2cat = dict(zip(data_extracell["feature_name"], data_extracell["codeword_category"], strict=False)) + extracellular_counts["codeword_category"] = extracellular_counts.index.map(gene2cat) + return extracellular_counts + def define_extracellular( - sdata, - layer: str = 'transcripts', - method: str = 'segmentation_free', - min_prop_of_extracellular: float = 0.8, - unassigned_to_cell_tag: str = 'UNASSIGNED', - copy: bool = False + sdata, + layer: str = "transcripts", + method: str = "segmentation_free", + min_prop_of_extracellular: float = 0.8, + unassigned_to_cell_tag: str = "UNASSIGNED", + copy: bool = False, ): """ Define extracellular transcripts in spatial omics data. This function identifies extracellular transcripts based on the specified method and updates the spatial data object accordingly. - Parameters: + Parameters + ---------- sdata : SpatialData A spatial data object containing transcriptomic information. layer : str, optional (default: 'transcripts') @@ -54,52 +54,52 @@ def define_extracellular( - 'nuclei': Uses overlap with nuclear annotations to classify extracellular transcripts. - 'cells': Classifies transcripts not assigned to a cell as extracellular. min_prop_of_extracellular : float, optional (default: 0.8) - Minimum proportion of transcripts in a cluster required to be extracellular for + Minimum proportion of transcripts in a cluster required to be extracellular for it to be classified as such (used only with 'segmentation_free' method). unassigned_to_cell_tag : str, optional (default: 'UNASSIGNED') Tag indicating transcripts not assigned to any cell. copy : bool, optional (default: False) - If True, returns a copy of the updated spatial data. + If True, returns a copy of the updated spatial data. If False, updates the `sdata` object in-place. - Returns: + Returns + ------- Optional[SpatialData]: If `copy` is True, returns a copy of the updated `sdata` object. Otherwise, updates the `sdata` object in-place and returns None. - Notes: + Notes + ----- - The 'segmentation_free' method uses clustering results to determine extracellular transcripts. - The 'nuclei' method assumes transcripts outside nuclei are extracellular. - The 'cells' method classifies transcripts unassigned to cells as extracellular. Example: ```python - updated_sdata = define_extracellular( - sdata, method='segmentation_free', min_prop_of_extracellular=0.9, copy=True - ) + updated_sdata = define_extracellular(sdata, method="segmentation_free", min_prop_of_extracellular=0.9, copy=True) ``` """ # Compute the data layer data = sdata.points[layer].compute() # Method: Segmentation-free clustering - if method == 'segmentation_free': - data['overlaps_cell'] = (data['cell_id'] != unassigned_to_cell_tag).astype(int) - overlapping_cell = pd.crosstab(data['segmentation_free_clusters'], data['overlaps_cell']) + if method == "segmentation_free": + data["overlaps_cell"] = (data["cell_id"] != unassigned_to_cell_tag).astype(int) + overlapping_cell = pd.crosstab(data["segmentation_free_clusters"], data["overlaps_cell"]) # Compute proportions and define extracellular clusters cluster_totals = overlapping_cell.sum(axis=1) cluster_proportions = overlapping_cell.div(cluster_totals, axis=0) - extracellular_clusters = cluster_proportions[cluster_proportions.loc[:,0] >= min_prop_of_extracellular].index - data['extracellular'] = ~data['segmentation_free_clusters'].isin(extracellular_clusters) + extracellular_clusters = cluster_proportions[cluster_proportions.loc[:, 0] >= min_prop_of_extracellular].index + data["extracellular"] = ~data["segmentation_free_clusters"].isin(extracellular_clusters) # Method: Based on nuclei overlap - elif method == 'nuclei': - data['extracellular'] = data['overlaps_nucleus'] != 1 + elif method == "nuclei": + data["extracellular"] = data["overlaps_nucleus"] != 1 # Method: Based on cell assignment - elif method == 'cells': - data['extracellular'] = data['cell_id'] == unassigned_to_cell_tag + elif method == "cells": + data["extracellular"] = data["cell_id"] == unassigned_to_cell_tag # Unsupported method else: @@ -110,11 +110,12 @@ def define_extracellular( return sdata if copy else None -def compute_crosstab(data, xvar: str = '', yvar: str = ''): + +def compute_crosstab(data, xvar: str = "", yvar: str = ""): """Compute a crosstabulation (contingency table) of two categorical variables from the given DataFrame. - Parameters: - ----------- + Parameters + ---------- data : pandas.DataFrame The input DataFrame containing the data to be analyzed. xvar : str, optional @@ -122,8 +123,8 @@ def compute_crosstab(data, xvar: str = '', yvar: str = ''): yvar : str, optional The name of the column to use as the columns of the crosstab. Default is an empty string. - Returns: - -------- + Returns + ------- pandas.DataFrame A DataFrame representing the crosstab of the specified variables, with counts of occurrences for each combination of categories. """ diff --git a/src/troutpy/pp/format.py b/src/troutpy/pp/format.py index 7b07eff..dbae3c0 100644 --- a/src/troutpy/pp/format.py +++ b/src/troutpy/pp/format.py @@ -1,11 +1,8 @@ import os -import scanpy as sc + +import matplotlib.pyplot as plt import pandas as pd -import numpy as np -import matplotlib.pyplot as plt -import seaborn as sns -import os -import spatialdata as sd +import scanpy as sc def format_adata(input_path, outpath_dummy, xlimits, ylimits): @@ -21,7 +18,8 @@ def format_adata(input_path, outpath_dummy, xlimits, ylimits): xlimits (list or tuple of two ints): Spatial limits for the x-coordinate filtering [min_x, max_x]. ylimits (list or tuple of two ints): Spatial limits for the y-coordinate filtering [min_y, max_y]. - Raises: + Raises + ------ FileNotFoundError: If required input files are not found in the input_path. ValueError: If xlimits or ylimits are not properly defined. """ @@ -35,9 +33,9 @@ def format_adata(input_path, outpath_dummy, xlimits, ylimits): os.makedirs(outpath_dummy, exist_ok=True) # Define input file paths - cell_feature_matrix_path = os.path.join(input_path, 'cell_feature_matrix.h5') - cells_parquet_path = os.path.join(input_path, 'cells.parquet') - transcripts_parquet_path = os.path.join(input_path, 'transcripts.parquet') + cell_feature_matrix_path = os.path.join(input_path, "cell_feature_matrix.h5") + cells_parquet_path = os.path.join(input_path, "cells.parquet") + transcripts_parquet_path = os.path.join(input_path, "transcripts.parquet") # Check if input files exist for file_path in [cell_feature_matrix_path, cells_parquet_path, transcripts_parquet_path]: @@ -54,12 +52,12 @@ def format_adata(input_path, outpath_dummy, xlimits, ylimits): # Merge cell information into adata.obs print("Merging cell information into AnnData object...") - adata.obs['cell_id'] = adata.obs.index.astype(str) - adata.obs = pd.merge(adata.obs, cells, on='cell_id', how='left') + adata.obs["cell_id"] = adata.obs.index.astype(str) + adata.obs = pd.merge(adata.obs, cells, on="cell_id", how="left") # Add spatial coordinates to adata.obsm print("Adding spatial coordinates to AnnData object...") - adata.obsm['spatial'] = adata.obs[['x_centroid', 'y_centroid']].values + adata.obsm["spatial"] = adata.obs[["x_centroid", "y_centroid"]].values # Load transcripts data print("Loading transcripts data from parquet file...") @@ -68,41 +66,40 @@ def format_adata(input_path, outpath_dummy, xlimits, ylimits): # Apply spatial filters to AnnData print("Applying spatial filters to AnnData object...") spatial_filter = ( - (adata.obs['x_centroid'] > xlimits[0]) & - (adata.obs['x_centroid'] < xlimits[1]) & - (adata.obs['y_centroid'] > ylimits[0]) & - (adata.obs['y_centroid'] < ylimits[1]) + (adata.obs["x_centroid"] > xlimits[0]) + & (adata.obs["x_centroid"] < xlimits[1]) + & (adata.obs["y_centroid"] > ylimits[0]) + & (adata.obs["y_centroid"] < ylimits[1]) ) - adata.obs.index=adata.obs.index.astype(str) + adata.obs.index = adata.obs.index.astype(str) adata_filtered = adata[spatial_filter].copy() # Apply spatial filters to transcripts print("Applying spatial filters to transcripts data...") transcripts_filtered = transcripts[ - (transcripts['x_location'] > xlimits[0]) & - (transcripts['x_location'] < xlimits[1]) & - (transcripts['y_location'] > ylimits[0]) & - (transcripts['y_location'] < ylimits[1]) + (transcripts["x_location"] > xlimits[0]) + & (transcripts["x_location"] < xlimits[1]) + & (transcripts["y_location"] > ylimits[0]) + & (transcripts["y_location"] < ylimits[1]) ].copy() # Save the processed AnnData object - adata_output_path = os.path.join(outpath_dummy, 'adata_raw.h5ad') + adata_output_path = os.path.join(outpath_dummy, "adata_raw.h5ad") print(f"Saving processed AnnData to {adata_output_path}...") adata_filtered.write(adata_output_path) # Save the filtered transcripts - transcripts_output_path = os.path.join(outpath_dummy, 'transcripts.parquet') + transcripts_output_path = os.path.join(outpath_dummy, "transcripts.parquet") print(f"Saving filtered transcripts to {transcripts_output_path}...") transcripts_filtered.to_parquet(transcripts_output_path) # Optional: Plot spatial data print("Generating spatial plot...") - sc.pl.spatial(adata_filtered, color='transcript_counts', spot_size=50) + sc.pl.spatial(adata_filtered, color="transcript_counts", spot_size=50) print("Processing complete.") - #selected roi - #selected roi + # selected roi + # selected roi plt.figure() - plt.scatter(adata_filtered.obs['x_centroid'],adata_filtered.obs['y_centroid'],s=4,c='red') - plt.scatter(transcripts_filtered['x_location'],transcripts_filtered['y_location'],s=0.0001) - return adata_filtered,transcripts_filtered - + plt.scatter(adata_filtered.obs["x_centroid"], adata_filtered.obs["y_centroid"], s=4, c="red") + plt.scatter(transcripts_filtered["x_location"], transcripts_filtered["y_location"], s=0.0001) + return adata_filtered, transcripts_filtered diff --git a/src/troutpy/read/__init__.py b/src/troutpy/read/__init__.py index bc315ce..f761658 100644 --- a/src/troutpy/read/__init__.py +++ b/src/troutpy/read/__init__.py @@ -1 +1 @@ -#from .read import * \ No newline at end of file +# from .read import * diff --git a/src/troutpy/tl/NMF.py b/src/troutpy/tl/NMF.py index 5f9312d..635523a 100644 --- a/src/troutpy/tl/NMF.py +++ b/src/troutpy/tl/NMF.py @@ -1,157 +1,159 @@ +import os + import numpy as np import pandas as pd -import sys -import pandas as pd -import numpy as np import scanpy as sc -import os from sklearn.decomposition import NMF -import matplotlib.pyplot as plt -import numpy as np -import seaborn as sns -def apply_nmf_to_adata(adata, n_components=20, subsample_percentage=1.0,save=False,output_path:str='',random_state=None): + +def apply_nmf_to_adata( + adata, n_components=20, subsample_percentage=1.0, save=False, output_path: str = "", random_state=None +): """ Applies Non-Negative Matrix Factorization (NMF) to an AnnData object. - This function performs NMF on the expression matrix (`adata.X`) to extract - a reduced number of latent factors that describe the gene expression profiles - of cells. The number of factors is specified by `n_components`. Optionally, + This function performs NMF on the expression matrix (`adata.X`) to extract + a reduced number of latent factors that describe the gene expression profiles + of cells. The number of factors is specified by `n_components`. Optionally, the data can be subsampled before applying NMF. """ - # Extract the cell count matrix (X) from AnnData object # Assuming that adata.X contains the raw counts for cells sc.pp.subsample(adata, subsample_percentage) counts = adata.X.copy() - + # Perform NMF with 20 factors - nmf_model = NMF(n_components=n_components, init='random', random_state=42) + nmf_model = NMF(n_components=n_components, init="random", random_state=42) W = nmf_model.fit_transform(counts) # Cell factors H = nmf_model.components_ # Gene loadings - + # Add NMF results to the AnnData object - adata.obsm['W_nmf'] = W # Add the cell factors to the AnnData object - adata.uns['H_nmf'] = H + adata.obsm["W_nmf"] = W # Add the cell factors to the AnnData object + adata.uns["H_nmf"] = H if save: - H=pd.DataFrame(adata.uns['H_nmf'],columns=adata.var.index) - H.to_parquet(os.path.join(output_path,'factor_loadings_H_per_gene.parquet')) - W=pd.DataFrame(adata.obsm['W_nmf'],index=adata.obs.index) - W.to_parquet(os.path.join(output_path,'factor_scores_W_per_cell.parquet')) + H = pd.DataFrame(adata.uns["H_nmf"], columns=adata.var.index) + H.to_parquet(os.path.join(output_path, "factor_loadings_H_per_gene.parquet")) + W = pd.DataFrame(adata.obsm["W_nmf"], index=adata.obs.index) + W.to_parquet(os.path.join(output_path, "factor_scores_W_per_cell.parquet")) return adata -def nmf( - sdata, layer='extracellular_transcripts_enriched', - feature_key='feature_name', bin_key='bin_id', - density_table_key='segmentation_free_table', - n_components=20, subsample_percentage=0.1, - random_state=None,all=False): +def nmf( + sdata, + layer="extracellular_transcripts_enriched", + feature_key="feature_name", + bin_key="bin_id", + density_table_key="segmentation_free_table", + n_components=20, + subsample_percentage=0.1, + random_state=None, + all=False, +): """ Applies Non-negative Matrix Factorization (NMF) on filtered data based on feature_name and bin_id. - Parameters: + Parameters ---------- sdata : spatial data object Input spatial data containing transcript and bin data. - + layer : str, optional Layer name of the data that contains extracellular transcripts (default: 'extracellular_transcripts_enriched'). - + feature_key : str, optional Column name for the transcript feature (default: 'feature_name'). - + bin_key : str, optional Column name for bin IDs (default: 'bin_id'). - + density_table_key : str, optional Key to retrieve the density table from sdata (default: 'segmentation_free_table'). - + n_components : int, optional Number of components for NMF (default: 20). - + subsample_percentage : float, optional Percentage of data to use for NMF (default: 0.1). - + random_state : int, optional Random state for NMF initialization for reproducibility (default: None). - Returns: + Returns ------- sdata : Updated spatial data object with NMF components stored. """ - if all==False: - # Extract the DataFrame with feature_name and bin_id - df = sdata.points[layer][[feature_key, bin_key]].compute() - # Filter the density table to include only the relevant bin_ids and feature_names - filtered_bin_ids = df[bin_key].astype(int).astype(str).unique() - filtered_feature_name_ids = df[feature_key].astype(str).unique() - # Filter adata_density to only include the bins and features present in df - adata_density_raw = sdata[density_table_key] - adata_density = adata_density_raw[adata_density_raw.obs.index.astype(str).isin(filtered_bin_ids),:] - adata_density = adata_density[:, adata_density.var.index.astype(str).isin(filtered_feature_name_ids)] - # Retrieve the segmentation-free density table + if all == False: + # Extract the DataFrame with feature_name and bin_id + df = sdata.points[layer][[feature_key, bin_key]].compute() + # Filter the density table to include only the relevant bin_ids and feature_names + filtered_bin_ids = df[bin_key].astype(int).astype(str).unique() + filtered_feature_name_ids = df[feature_key].astype(str).unique() + # Filter adata_density to only include the bins and features present in df + adata_density_raw = sdata[density_table_key] + adata_density = adata_density_raw[adata_density_raw.obs.index.astype(str).isin(filtered_bin_ids), :] + adata_density = adata_density[:, adata_density.var.index.astype(str).isin(filtered_feature_name_ids)] + # Retrieve the segmentation-free density table else: adata_density = sdata[density_table_key] # Apply NMF to filtered data adata_nmf = apply_nmf_to_adata( - adata_density, - n_components=n_components, - subsample_percentage=subsample_percentage, - random_state=random_state - ) # This function adds adata.obsm['W_nmf'] and adata.uns['H_nmf'] - + adata_density, n_components=n_components, subsample_percentage=subsample_percentage, random_state=random_state + ) # This function adds adata.obsm['W_nmf'] and adata.uns['H_nmf'] + # Store the NMF results in the spatial data - sdata['nmf_data'] = adata_nmf - + sdata["nmf_data"] = adata_nmf + return sdata -def apply_exrna_factors_to_cells(sdata, layer_factors='nmf_data'): + +def apply_exrna_factors_to_cells(sdata, layer_factors="nmf_data"): """Applies extracellular RNA (exRNA) factor loadings to cellular annotation data based on NMF factors. This function extracts extracellular RNA data and associated NMF factor loadings, intersects the gene annotations between the extracellular data and the cellular data, and applies the NMF factors to annotate the cellular data with exRNA-related factors. - Parameters: + Parameters + ---------- sdata (AnnData): The AnnData object containing both extracellular and cellular data. layer_factors (str, optional): The key in `sdata` that contains the extracellular RNA data with NMF factors. Default is 'nmf_data'. - - Returns: + + Returns + ------- AnnData: The updated `sdata` object with annotated cellular data that includes the applied exRNA factors as new columns. - - Notes: + + Notes + ----- The function assumes that the extracellular RNA data is stored in `sdata[layer_factors]` and that the NMF factor loadings are stored in the `uns` attribute of the extracellular dataset as 'H_nmf'. The factor scores are added to the `obs` attribute of the cellular data. """ - # Extract extracellular data and cellular annotations adata_extracellular_with_nmf = sdata[layer_factors] - adata_annotated_cellular = sdata['table'] - + adata_annotated_cellular = sdata["table"] + # Retrieve NMF factor loadings (H matrix) from extracellular data - H = adata_extracellular_with_nmf.uns['H_nmf'] - + H = adata_extracellular_with_nmf.uns["H_nmf"] + # Get gene names from both datasets genes_spots2region = adata_extracellular_with_nmf.var_names genes_annotated = adata_annotated_cellular.var_names - + # Get the intersection of genes between the extracellular and cellular datasets common_genes = genes_annotated.intersection(genes_spots2region) - + # Filter both datasets to retain only the common genes adata_annotated_cellular = adata_annotated_cellular[:, common_genes] H_filtered = H[:, np.isin(genes_spots2region, common_genes)] # Filtered NMF factor loadings for common genes - + # Apply NMF factors to the annotated cellular dataset # Calculate the W matrix by multiplying the cellular data (X) with the filtered NMF loadings (H) W_annotated = adata_annotated_cellular.X @ H_filtered.T - + # Store the factors in the 'obsm' attribute of the AnnData object - adata_annotated_cellular.obsm['factors'] = pd.DataFrame(W_annotated, index=adata_annotated_cellular.obs.index) - + adata_annotated_cellular.obsm["factors"] = pd.DataFrame(W_annotated, index=adata_annotated_cellular.obs.index) + # Add each factor as a new column in the 'obs' attribute of the cellular dataset for factor in range(W_annotated.shape[1]): - adata_annotated_cellular.obs[f'NMF_factor_{factor + 1}'] = W_annotated[:, factor] - + adata_annotated_cellular.obs[f"NMF_factor_{factor + 1}"] = W_annotated[:, factor] + # Update the 'table' in the sdata object with the annotated cellular data - sdata['table'] = adata_annotated_cellular - + sdata["table"] = adata_annotated_cellular + return sdata diff --git a/src/troutpy/tl/__init__.py b/src/troutpy/tl/__init__.py index 03323b8..fd56952 100644 --- a/src/troutpy/tl/__init__.py +++ b/src/troutpy/tl/__init__.py @@ -1,7 +1,19 @@ -from .source_cell import create_xrna_metadata,compute_source_cells,distance_to_source_cell,compute_distant_cells_prop,get_proportion_expressed_per_cell_type -from .target_cell import calculate_target_cells,define_target_by_celltype from .estimate_density import colocalization_proportion -from .quantify_xrna import spatial_variability,create_xrna_metadata,quantify_overexpression,extracellular_enrichment,spatial_colocalization -from .interactions import get_number_of_communication_genes,get_gene_interaction_strength -from .NMF import apply_nmf_to_adata,nmf,apply_exrna_factors_to_cells -from .segmentation_free import segmentation_free_clustering \ No newline at end of file +from .interactions import get_gene_interaction_strength, get_number_of_communication_genes +from .NMF import apply_exrna_factors_to_cells, apply_nmf_to_adata, nmf +from .quantify_xrna import ( + create_xrna_metadata, + extracellular_enrichment, + quantify_overexpression, + spatial_colocalization, + spatial_variability, +) +from .segmentation_free import segmentation_free_clustering +from .source_cell import ( + compute_distant_cells_prop, + compute_source_cells, + create_xrna_metadata, + distance_to_source_cell, + get_proportion_expressed_per_cell_type, +) +from .target_cell import calculate_target_cells, define_target_by_celltype diff --git a/src/troutpy/tl/estimate_density.py b/src/troutpy/tl/estimate_density.py index ba27419..3951bee 100644 --- a/src/troutpy/tl/estimate_density.py +++ b/src/troutpy/tl/estimate_density.py @@ -1,58 +1,59 @@ import os + import numpy as np import pandas as pd + def colocalization_proportion( - sdata, - outpath, - threshold_colocalized=1, - filename='proportion_of_grouped_exRNA.parquet', save=True + sdata, outpath, threshold_colocalized=1, filename="proportion_of_grouped_exRNA.parquet", save=True ): """ Calculate the proportion of colocalized transcripts for each gene in the provided AnnData object. - Parameters: + Parameters + ---------- - sdata: AnnData object with `.X` matrix containing the density of transcripts per gene. - outpath: The directory path where the output file should be saved. - threshold_colocalized: The threshold for considering a transcript colocalized (default is 1). - filename: The name of the output file (default is 'proportion_of_grouped_exRNA.parquet'). - Returns: + Returns + ------- - coloc: DataFrame containing the proportion of colocalized transcripts for each gene. """ # Load relevant data - df = sdata.points['extracellular_transcripts_enriched'][['feature_name', 'bin_id']].compute() - adata_density_raw = sdata['segmentation_free_table'] - + df = sdata.points["extracellular_transcripts_enriched"][["feature_name", "bin_id"]].compute() + adata_density_raw = sdata["segmentation_free_table"] + # Filter adata_density to include only bin_ids present in df - filtered_bin_ids = df['bin_id'].astype(str).unique() - filtered_feature_name_ids = df['feature_name'].astype(str).unique() + filtered_bin_ids = df["bin_id"].astype(str).unique() + filtered_feature_name_ids = df["feature_name"].astype(str).unique() adata_density = adata_density_raw[adata_density_raw.obs.index.isin(filtered_bin_ids)] - adata_density=adata_density[:,adata_density.var.index.isin(filtered_feature_name_ids)] + adata_density = adata_density[:, adata_density.var.index.isin(filtered_feature_name_ids)] # Convert the sparse matrix to dense format (assuming the matrix is large, sparse ops can be done here) dense_matrix = adata_density.X.todense() - + # Calculate positive and colocalized counts for each gene positive_counts = np.sum(dense_matrix > 0, axis=0) # Count non-zero (positive) values per gene colocalized_counts = np.sum(dense_matrix > threshold_colocalized, axis=0) # Colocalized counts per gene - + # Calculate the proportion of colocalized transcripts proportions = np.divide(colocalized_counts, positive_counts, where=(positive_counts > 0)) # Avoid div by zero - + # Create the result DataFrame coloc = pd.DataFrame( data=proportions.A1, # Convert to a 1D array - index=adata_density.var.index, - columns=['proportion_of_colocalized'] + index=adata_density.var.index, + columns=["proportion_of_colocalized"], ) - + # Ensure the output directory exists os.makedirs(outpath, exist_ok=True) - + # Save the DataFrame as a Parquet file if save: filepath = os.path.join(outpath, filename) coloc.to_parquet(filepath) - + return coloc diff --git a/src/troutpy/tl/interactions.py b/src/troutpy/tl/interactions.py index 3c4cfb6..21831d0 100644 --- a/src/troutpy/tl/interactions.py +++ b/src/troutpy/tl/interactions.py @@ -1,87 +1,78 @@ -from typing import Optional -import scanpy as sc -import pandas as pd -import numpy as np -import anndata as ad -import seaborn as sns import matplotlib.pyplot as plt +import pandas as pd # function to compute the number of exchanged genes between any two cell types + def get_number_of_communication_genes( - source_proportions: pd.DataFrame, # gene by source cell type - target_proportions: pd.DataFrame, # gene by target cell type + source_proportions: pd.DataFrame, # gene by source cell type + target_proportions: pd.DataFrame, # gene by target cell type source_proportion_threshold: float = 0.2, - target_proportion_threshold: float = 0.2 - ) -> pd.DataFrame: + target_proportion_threshold: float = 0.2, +) -> pd.DataFrame: """Compute the number of exchanged genes between any two cell types Args: - source_proportions (pd.DataFrame): A data frame (Gene name x Cell Type) with - proportion of cells per cell type expressing corresponding gene - target_proportions : A data frame - (Gene name x Cell Type) with proportion of cells per cell type being the - physically clostest cell to transcripts of corresponding gene. + source_proportions (pd.DataFrame): A data frame (Gene name x Cell Type) with + proportion of cells per cell type expressing corresponding gene + target_proportions : A data frame + (Gene name x Cell Type) with proportion of cells per cell type being the + physically clostest cell to transcripts of corresponding gene. Defaults to 0.2. source_proportion_threshold (float, optional): The threshold to consider a cell type to be a significant source of a gene. Defaults to 0.2. target_proportion_threshold (float, optional): The threshold to consider a cell type to be a significant target of a gene. Defaults to 0.2. - Returns: + Returns + ------- pd.DataFrame: _description_ """ - # filter the source and target cell types by defining signficant proportions - source_binary = (source_proportions > source_proportion_threshold) - target_binary = (target_proportions > target_proportion_threshold) - + source_binary = source_proportions > source_proportion_threshold + target_binary = target_proportions > target_proportion_threshold + # prepare dataframe to store the number of exchanged genes - number_interactions_df = pd.DataFrame( - index=source_binary.columns, - columns=target_binary.columns - ) + number_interactions_df = pd.DataFrame(index=source_binary.columns, columns=target_binary.columns) - # loop through the source and target cell types to compute the number of + # loop through the source and target cell types to compute the number of # exchanged genes for col in source_binary.columns: sig_gene_source = source_binary.index[source_binary[col]] for col2 in target_binary.columns: sig_gene_target = target_binary.index[target_binary[col2]] - number_interactions_df.loc[col, col2] = len( - set(sig_gene_source).intersection(sig_gene_target) - ) + number_interactions_df.loc[col, col2] = len(set(sig_gene_source).intersection(sig_gene_target)) - number_interactions_df=number_interactions_df[number_interactions_df.index] - number_interactions_df.columns.name='Target cell type' - number_interactions_df.index.name='Source cell type' + number_interactions_df = number_interactions_df[number_interactions_df.index] + number_interactions_df.columns.name = "Target cell type" + number_interactions_df.index.name = "Source cell type" return number_interactions_df def get_gene_interaction_strength( source_proportions: pd.DataFrame, # gene by source cell type target_proportions: pd.DataFrame, # gene by target cell type - gene_symbol: str = '', # Gene of interest - return_interactions: bool = False, # Flag to return interaction matrix - save: bool = False, # Flag to save the plot - output_path: str = '', # Directory to save the plot - format: str = 'pdf' # Format to save the plot (e.g., pdf, png) + gene_symbol: str = "", # Gene of interest + return_interactions: bool = False, # Flag to return interaction matrix + save: bool = False, # Flag to save the plot + output_path: str = "", # Directory to save the plot + format: str = "pdf", # Format to save the plot (e.g., pdf, png) ) -> None: """ Computes and visualizes the interaction strength for a specific gene between source and target cell types. This function calculates the interaction strength between source and target cell types for a specified gene - by multiplying the proportions of the gene in the source and target cell types. The interaction matrix can + by multiplying the proportions of the gene in the source and target cell types. The interaction matrix can be visualized using a chord diagram, with the option to save the resulting plot. - Parameters: + Parameters ---------- source_proportions : pd.DataFrame - A DataFrame where rows represent genes and columns represent source cell types. Each value indicates + A DataFrame where rows represent genes and columns represent source cell types. Each value indicates the proportion of the gene in the respective source cell type. target_proportions : pd.DataFrame - A DataFrame where rows represent genes and columns represent target cell types. Each value indicates + A DataFrame where rows represent genes and columns represent target cell types. Each value indicates the proportion of the gene in the respective target cell type. gene_symbol : str, optional @@ -94,21 +85,21 @@ def get_gene_interaction_strength( If True, saves the chord diagram plot to the specified output path (default: False). output_path : str, optional - The directory path where the plot will be saved. If `save=True`, this path will be used to store the file + The directory path where the plot will be saved. If `save=True`, this path will be used to store the file (default: ''). A 'figures' subdirectory is created if it doesn't exist. format : str, optional The file format for saving the plot (e.g., 'pdf', 'png'). This is used only if `save=True` (default: 'pdf'). - Returns: + Returns ------- None or np.ndarray - If `return_interactions=True`, the function returns the interaction matrix as a NumPy array. Otherwise, + If `return_interactions=True`, the function returns the interaction matrix as a NumPy array. Otherwise, the function generates a chord diagram plot. - Notes: + Notes ----- - - The function computes the interaction matrix by multiplying the proportions of the gene in the source and + - The function computes the interaction matrix by multiplying the proportions of the gene in the source and target cell types. - The chord diagram visualizes the interaction strength between the cell types. - If `save=True`, the plot is saved in the specified format and location. @@ -117,11 +108,12 @@ def get_gene_interaction_strength( ------- To compute and visualize the interaction strength for a specific gene: - >>> get_gene_specific_interaction_strength(source_proportions, target_proportions, gene_symbol='MYC', save=True, output_path='results', format='png') + >>> get_gene_specific_interaction_strength( + ... source_proportions, target_proportions, gene_symbol="MYC", save=True, output_path="results", format="png" + ... ) This will save the plot as a PNG file in the 'results/figures' directory. """ - # Ensure the target proportions have the same cell type columns as the source proportions target_proportions = target_proportions[source_proportions.columns] @@ -132,7 +124,6 @@ def get_gene_interaction_strength( # Compute the interaction matrix (source proportions * target proportions) interactions = source_proportions_vals @ target_proportions_vals - # Define the colormap and create color mappings for each cell type cmap = plt.get_cmap("tab20") colors = [cmap(i) for i in range(interactions.shape[0])] @@ -143,10 +134,10 @@ def get_gene_interaction_strength( # Save the plot if the 'save' option is enabled if save: - figures_path = os.path.join(output_path, 'figures') + figures_path = os.path.join(output_path, "figures") os.makedirs(figures_path, exist_ok=True) # Create 'figures' directory if it doesn't exist - plt.savefig(os.path.join(figures_path, f'communication_profile_{gene_symbol}.{format}')) # Save the figure + plt.savefig(os.path.join(figures_path, f"communication_profile_{gene_symbol}.{format}")) # Save the figure # Show the plot plt.show() - return pd.DataFrame(interactions,index=source_proportions.columns,columns=target_proportions.columns) + return pd.DataFrame(interactions, index=source_proportions.columns, columns=target_proportions.columns) diff --git a/src/troutpy/tl/quantify_xrna.py b/src/troutpy/tl/quantify_xrna.py index 39bc1b2..45b6249 100644 --- a/src/troutpy/tl/quantify_xrna.py +++ b/src/troutpy/tl/quantify_xrna.py @@ -1,32 +1,32 @@ -import scanpy as sc -#import squidpy as sq -import pandas as pd -import matplotlib.pyplot as plt -import os -from spatialdata import SpatialData -import spatialdata as sd import numpy as np -from typing import List, Union, Tuple + +# import squidpy as sq +import pandas as pd import polars as pl +import scanpy as sc +import spatialdata as sd +import squidpy as sq from sainsc import LazyKDE +from spatialdata import SpatialData from tqdm import tqdm -import squidpy as sq + def spatial_variability( - sdata, - coords_keys=['x', 'y'], - gene_id_key='feature_name', - n_neighbors=10, - resolution=1000, - binsize=20, - n_threads=1, - spatial_autocorr_mode="moran",copy=False + sdata, + coords_keys=["x", "y"], + gene_id_key="feature_name", + n_neighbors=10, + resolution=1000, + binsize=20, + n_threads=1, + spatial_autocorr_mode="moran", + copy=False, ): """ Computes spatial variability of extracellular RNA using Moran's I. - Parameters: - ----------- + Parameters + ---------- sdata : SpatialData The spatial transcriptomics dataset in SpatialData format. coords_keys : list of str, optional @@ -44,14 +44,14 @@ def spatial_variability( spatial_autocorr_mode : str, optional The mode for spatial autocorrelation computation (default: "moran"). - Returns: - -------- + Returns + ------- pd.DataFrame A DataFrame containing Moran's I values for each gene, indexed by gene names. """ # Step 1: Extract and preprocess data - data = sdata.points['transcripts'][coords_keys + ['extracellular', gene_id_key]].compute() - data = data[data['extracellular'] == True] + data = sdata.points["transcripts"][coords_keys + ["extracellular", gene_id_key]].compute() + data = data[data["extracellular"] == True] data[gene_id_key] = data[gene_id_key].astype(str) # Rename columns for clarity @@ -77,9 +77,9 @@ def spatial_variability( # Step 4: Create AnnData object adata = sc.AnnData(allres) adata.var.index = embryo.counts.genes() - adata.obs['x'] = x_coords.flatten() - adata.obs['y'] = y_coords.flatten() - adata.obsm['spatial'] = np.array(adata.obs.loc[:, ['x', 'y']]) + adata.obs["x"] = x_coords.flatten() + adata.obs["y"] = y_coords.flatten() + adata.obsm["spatial"] = np.array(adata.obs.loc[:, ["x", "y"]]) # Step 5: Compute spatial neighbors and Moran's I sq.gr.spatial_neighbors(adata, n_neighs=n_neighbors) @@ -87,68 +87,65 @@ def spatial_variability( # Extract Moran's I values svg_df = pd.DataFrame(adata.uns["moranI"]) - svg_df.columns=[spatial_autocorr_mode+'_'+str(g) for g in svg_df.columns] + svg_df.columns = [spatial_autocorr_mode + "_" + str(g) for g in svg_df.columns] try: - sdata['xrna_metadata'] + sdata["xrna_metadata"] except KeyError: - create_xrna_metadata(sdata, points_layer='transcripts') + create_xrna_metadata(sdata, points_layer="transcripts") for column in svg_df.columns: - if column in sdata['xrna_metadata'].var.columns: - sdata['xrna_metadata'].var=sdata['xrna_metadata'].var.drop([column],axis=1) - - - sdata['xrna_metadata'].var = sdata['xrna_metadata'].var.join(svg_df) + if column in sdata["xrna_metadata"].var.columns: + sdata["xrna_metadata"].var = sdata["xrna_metadata"].var.drop([column], axis=1) + + sdata["xrna_metadata"].var = sdata["xrna_metadata"].var.join(svg_df) return sdata if copy else None + def create_xrna_metadata( - sdata: SpatialData, - points_layer: str = 'transcripts', - gene_key: str = 'feature_name', - copy: bool = False + sdata: SpatialData, points_layer: str = "transcripts", gene_key: str = "feature_name", copy: bool = False ) -> SpatialData | None: """ - Creates a new table within the SpatialData object that contains a 'gene' column + Creates a new table within the SpatialData object that contains a 'gene' column with the unique gene names extracted from the specified points layer. - Parameters: + Parameters ---------- sdata : SpatialData The SpatialData object to modify. - + points_layer : str, optional The name of the layer in `sdata.points` from which to extract gene names. Default is 'transcripts'. - + gene_key : str, optional The key in the `points_layer` dataframe that contains the gene names. Default is 'feature_name'. - + copy : bool, optional If `True`, returns a copy of the `SpatialData` object with the new table added. If `False`, modifies the original `SpatialData` object in place. Default is `False`. - Returns: + Returns ------- SpatialData | None If `copy` is `True`, returns a copy of the modified `SpatialData` object. Otherwise, returns `None`. - Raises: + Raises ------ ValueError If the specified points layer does not exist in `sdata.points`. If the `gene_key` column is not present in the specified points layer. - Examples: + Examples -------- Add a metadata table for genes in the 'transcripts' layer: - >>> create_xrna_metadata(sdata, points_layer='transcripts', gene_key='feature_name') + >>> create_xrna_metadata(sdata, points_layer="transcripts", gene_key="feature_name") Modify a custom SpatialData layer and return a copy: - >>> updated_sdata = create_xrna_metadata(sdata, points_layer='custom_layer', gene_key='gene_id', copy=True) + >>> updated_sdata = create_xrna_metadata(sdata, points_layer="custom_layer", gene_key="gene_id", copy=True) - Notes: + Notes ----- - The function uses `scanpy` to create an AnnData object and integrates it into the SpatialData table model. - The unique gene names are extracted from the specified points layer and stored in the `.var` of the AnnData object. @@ -156,38 +153,39 @@ def create_xrna_metadata( # Check if the specified points layer exists if points_layer not in sdata.points: raise ValueError(f"Points layer '{points_layer}' not found in sdata.points.") - + # Extract unique gene names from the specified points layer points_data = sdata.points[points_layer] if gene_key not in points_data.columns: raise ValueError(f"The specified points layer '{points_layer}' does not contain a '{gene_key}' column.") - + unique_genes = points_data[gene_key].compute().unique().astype(str) - + # Create a DataFrame for unique genes gene_metadata = pd.DataFrame(index=unique_genes) # Convert to AnnData and then to SpatialData table model exrna_adata = sc.AnnData(var=gene_metadata) metadata_table = sd.models.TableModel.parse(exrna_adata) - + # Add the new table to the SpatialData object - sdata.tables['xrna_metadata'] = metadata_table + sdata.tables["xrna_metadata"] = metadata_table print(f"Added 'xrna_metadata' table with {len(unique_genes)} unique genes to the SpatialData object.") - + # Return copy or modify in place return sdata if copy else None + def quantify_overexpression( sdata: pd.DataFrame, codeword_column: str, - control_codewords: Union[List[str], str], - gene_id_column: str='feature_name', - layer: str = 'transcripts', + control_codewords: list[str] | str, + gene_id_column: str = "feature_name", + layer: str = "transcripts", percentile_threshold: float = 100, - copy=False -) -> Tuple[pd.DataFrame, pd.DataFrame, float]: + copy=False, +) -> tuple[pd.DataFrame, pd.DataFrame, float]: """Compare counts per gene with counts per non-gene feature. We define a threshold as the 'percentile_threshold' counts of non-gene counts (e.g. 'percentile_threshold = 100' corresponds to the maximum number of counts observed in any non-gene feature). Any gene whose counts are above the threshold are considered overexpressed. Args: @@ -199,121 +197,126 @@ def quantify_overexpression( save (bool, optional): Whether to save outputs to file. Defaults to True. saving_path (str, optional): Path to directory that files should be saved in. Defaults to "". - Returns: + Returns + ------- Tuple[pd.DataFrame, pd.DataFrame, float]: A tuple containing the updated sdata, scores per gene DataFrame, and the calculated threshold. """ - # Compute the data from the Dask DataFrame - data = sdata.points[layer][['extracellular',codeword_column,gene_id_column]].compute() - data=data[data['extracellular']==True] + data = sdata.points[layer][["extracellular", codeword_column, gene_id_column]].compute() + data = data[data["extracellular"] == True] # Ensure control_codewords is a list if isinstance(control_codewords, str): control_codewords = [control_codewords] - assert isinstance(control_codewords, List), \ - f"control_codewords should be a list but has type: {type(control_codewords)}" - + assert isinstance( + control_codewords, list + ), f"control_codewords should be a list but has type: {type(control_codewords)}" + # Get counts per control feature - counts_per_nongene = data.loc[ - (data.loc[:, codeword_column].isin(control_codewords)), - gene_id_column - ].value_counts().to_frame().reset_index() + counts_per_nongene = ( + data.loc[(data.loc[:, codeword_column].isin(control_codewords)), gene_id_column] + .value_counts() + .to_frame() + .reset_index() + ) threshold = np.percentile(counts_per_nongene.loc[:, "count"].values, percentile_threshold) - + # create dict - gene2genestatus=dict(zip(data[gene_id_column],data[codeword_column].isin(control_codewords))) + gene2genestatus = dict(zip(data[gene_id_column], data[codeword_column].isin(control_codewords), strict=False)) # Get counts per gene scores_per_gene = data[gene_id_column].value_counts().to_frame() - scores_per_gene.columns = ['count'] - scores_per_gene['control_probe']=scores_per_gene.index.map(gene2genestatus) + scores_per_gene.columns = ["count"] + scores_per_gene["control_probe"] = scores_per_gene.index.map(gene2genestatus) scores_per_gene.loc[:, "logfoldratio_over_noise"] = np.log(scores_per_gene.loc[:, "count"] / threshold) try: - sdata['xrna_metadata'] + sdata["xrna_metadata"] except: - create_xrna_metadata(sdata, points_layer = 'transcripts') + create_xrna_metadata(sdata, points_layer="transcripts") - sdata['xrna_metadata'].var=sdata['xrna_metadata'].var.join(scores_per_gene) - sdata['xrna_metadata'].var['control_probe']=sdata['xrna_metadata'].var['control_probe'].fillna(False) + sdata["xrna_metadata"].var = sdata["xrna_metadata"].var.join(scores_per_gene) + sdata["xrna_metadata"].var["control_probe"] = sdata["xrna_metadata"].var["control_probe"].fillna(False) return sdata if copy else None -def extracellular_enrichment(sdata, gene_id_column: str = 'feature_name', copy: bool = False): + +def extracellular_enrichment(sdata, gene_id_column: str = "feature_name", copy: bool = False): """ Calculate the proportion of extracellular and intracellular transcripts for each gene and integrate results into the AnnData object. This function computes the proportion of transcripts classified as extracellular or intracellular for each gene and calculates additional metrics, including log fold change of extracellular to intracellular proportions. The results are integrated into the `sdata` object under the 'xrna_metadata' layer. - Parameters: - ----------- + Parameters + ---------- sdata : AnnData - An AnnData object containing spatial transcriptomics data. The `points` attribute should include a - 'transcripts' DataFrame with columns for gene IDs (specified by `gene_id_column`) and a boolean + An AnnData object containing spatial transcriptomics data. The `points` attribute should include a + 'transcripts' DataFrame with columns for gene IDs (specified by `gene_id_column`) and a boolean 'extracellular' column indicating whether each transcript is classified as extracellular. gene_id_column : str, optional The name of the column in the 'transcripts' DataFrame containing gene identifiers. Defaults to 'feature_name'. copy : bool, optional - Whether to return a modified copy of the input `sdata` object. If `False`, the input object is modified + Whether to return a modified copy of the input `sdata` object. If `False`, the input object is modified in place. Defaults to `False`. - Returns: - -------- + Returns + ------- AnnData or None - If `copy=True`, returns a modified copy of the input `sdata` object with updated metadata. Otherwise, + If `copy=True`, returns a modified copy of the input `sdata` object with updated metadata. Otherwise, modifies `sdata` in place and returns `None`. - Notes: - ------ + Notes + ----- - The function assumes that the `sdata` object has a 'points' layer containing a 'transcripts' DataFrame. - - If the 'xrna_metadata' attribute does not exist in `sdata`, it will be created using the `create_xrna_metadata` + - If the 'xrna_metadata' attribute does not exist in `sdata`, it will be created using the `create_xrna_metadata` function. Example: -------- - >>> updated_sdata = extracellular_enrichment(sdata, gene_id_column='gene_symbol', copy=True) - >>> print(updated_sdata['xrna_metadata'].var) + >>> updated_sdata = extracellular_enrichment(sdata, gene_id_column="gene_symbol", copy=True) + >>> print(updated_sdata["xrna_metadata"].var) """ # Extract and compute the required data - data = sdata.points['transcripts'][[gene_id_column, 'extracellular']].compute() - + data = sdata.points["transcripts"][[gene_id_column, "extracellular"]].compute() + # Create a crosstab to count occurrences of intracellular and extracellular transcripts - feature_inout = pd.crosstab(data[gene_id_column], data['extracellular']) + feature_inout = pd.crosstab(data[gene_id_column], data["extracellular"]) norm_counts = feature_inout.div(feature_inout.sum(axis=0), axis=1) - norm_counts['extracellular_foldratio'] = norm_counts[False] / norm_counts[True] - + norm_counts["extracellular_foldratio"] = norm_counts[False] / norm_counts[True] + extracellular_proportion = feature_inout.div(feature_inout.sum(axis=1), axis=0) - extracellular_proportion.columns = extracellular_proportion.columns.map({ - True: 'intracellular_proportion', False: 'extracellular_proportion' - }) - extracellular_proportion['logfoldratio_extracellular'] = np.log(norm_counts['extracellular_foldratio']) - + extracellular_proportion.columns = extracellular_proportion.columns.map( + {True: "intracellular_proportion", False: "extracellular_proportion"} + ) + extracellular_proportion["logfoldratio_extracellular"] = np.log(norm_counts["extracellular_foldratio"]) + # Ensure the 'xrna_metadata' attribute exists try: - sdata['xrna_metadata'] + sdata["xrna_metadata"] except KeyError: - create_xrna_metadata(sdata, points_layer='transcripts') - + create_xrna_metadata(sdata, points_layer="transcripts") + # Join the results to the metadata - sdata['xrna_metadata'].var = sdata['xrna_metadata'].var.join(extracellular_proportion) + sdata["xrna_metadata"].var = sdata["xrna_metadata"].var.join(extracellular_proportion) return sdata if copy else None + def spatial_colocalization( - sdata, - coords_keys=['x', 'y'], - gene_id_key='feature_name', - - resolution=1000, - binsize=20, - n_threads=1, - threshold_colocalized=1,copy=False + sdata, + coords_keys=["x", "y"], + gene_id_key="feature_name", + resolution=1000, + binsize=20, + n_threads=1, + threshold_colocalized=1, + copy=False, ): """ Computes spatial variability of extracellular RNA using Moran's I. - Parameters: - ----------- + Parameters + ---------- sdata : SpatialData The spatial transcriptomics dataset in SpatialData format. coords_keys : list of str, optional @@ -331,14 +334,14 @@ def spatial_colocalization( spatial_autocorr_mode : str, optional The mode for spatial autocorrelation computation (default: "moran"). - Returns: - -------- + Returns + ------- pd.DataFrame A DataFrame containing Moran's I values for each gene, indexed by gene names. """ # Step 1: Extract and preprocess data - data = sdata.points['transcripts'][coords_keys + ['extracellular', gene_id_key]].compute() - data = data[data['extracellular'] == True] + data = sdata.points["transcripts"][coords_keys + ["extracellular", gene_id_key]].compute() + data = data[data["extracellular"] == True] data[gene_id_key] = data[gene_id_key].astype(str) # Rename columns for clarity @@ -364,21 +367,21 @@ def spatial_colocalization( # Step 4: Create AnnData object adata = sc.AnnData(allres) adata.var.index = embryo.counts.genes() - adata.obs['x'] = x_coords.flatten() - adata.obs['y'] = y_coords.flatten() - adata.obsm['spatial'] = np.array(adata.obs.loc[:, ['x', 'y']]) + adata.obs["x"] = x_coords.flatten() + adata.obs["y"] = y_coords.flatten() + adata.obsm["spatial"] = np.array(adata.obs.loc[:, ["x", "y"]]) - threshold_colocalized=1 + threshold_colocalized = 1 # Calculate positive and colocalized counts for each gene positive_counts = np.sum(adata.X > 0, axis=0) # Count non-zero (positive) values per gene colocalized_counts = np.sum(adata.X > threshold_colocalized, axis=0) # Colocalized counts per gene # Calculate the proportion of colocalized transcripts proportions = np.divide(colocalized_counts, positive_counts, where=(positive_counts > 0)) # Avoid div by zero # Create the result DataFrame - coloc = pd.DataFrame(data=proportions,index=adata.var.index, columns=['proportion_of_colocalized']) + coloc = pd.DataFrame(data=proportions, index=adata.var.index, columns=["proportion_of_colocalized"]) for column in coloc.columns: - if column in sdata['xrna_metadata'].var.columns: - sdata['xrna_metadata'].var=sdata['xrna_metadata'].var.drop([column],axis=1) - sdata['xrna_metadata'].var = sdata['xrna_metadata'].var.join(coloc) + if column in sdata["xrna_metadata"].var.columns: + sdata["xrna_metadata"].var = sdata["xrna_metadata"].var.drop([column], axis=1) + sdata["xrna_metadata"].var = sdata["xrna_metadata"].var.join(coloc) return sdata if copy else None diff --git a/src/troutpy/tl/segmentation_free.py b/src/troutpy/tl/segmentation_free.py index bdd792f..b411c9a 100644 --- a/src/troutpy/tl/segmentation_free.py +++ b/src/troutpy/tl/segmentation_free.py @@ -1,24 +1,23 @@ -import spatialdata_io import spatialdata as sd from points2regions import Points2Regions -import pandas as pd def segmentation_free_clustering( - sdata, - params: dict = {}, - x: str = 'x', - y: str = 'y', - feature_name: str = 'feature_name', - method: str = 'points2regions', - transcript_id: str = 'transcript_id', - copy: bool = False + sdata, + params: dict = {}, + x: str = "x", + y: str = "y", + feature_name: str = "feature_name", + method: str = "points2regions", + transcript_id: str = "transcript_id", + copy: bool = False, ): """Perform segmentation-free clustering on transcriptomic spatial data. This function clusters transcriptomic data without relying on pre-defined cell or tissue segmentations.It supports multiple clustering methods, with Points2Regions being the default. - Parameters: + Parameters + ---------- sdata : SpatialData A spatial data object containing transcriptomic information. params : dict, optional (default: {}) @@ -42,41 +41,42 @@ def segmentation_free_clustering( copy : bool, optional (default: False) If True, returns a copy of the clustering results. If False, updates `sdata` in-place. - Returns: + Returns + ------- Optional[anndata.AnnData]: If `copy` is True, returns an AnnData object containing the clustering results. Otherwise, updates the `sdata` object in-place and returns None. """ # Reset transcript indexing if not unique - sdata.points['transcripts'] = sdata.points['transcripts'].reset_index(drop=True) + sdata.points["transcripts"] = sdata.points["transcripts"].reset_index(drop=True) # Prepare data for clustering - data = sdata.points['transcripts'][[x, y, feature_name, transcript_id]].compute() + data = sdata.points["transcripts"][[x, y, feature_name, transcript_id]].compute() - if method == 'points2regions': + if method == "points2regions": # Validate required parameters for Points2Regions - required_keys = ['num_clusters', 'pixel_width', 'pixel_smoothing'] + required_keys = ["num_clusters", "pixel_width", "pixel_smoothing"] for key in required_keys: if key not in params: raise ValueError(f"Missing required parameter for 'points2regions': '{key}'") # Initialize and fit Points2Regions clustering model p2r = Points2Regions( - data[[x, y]], - data[feature_name].astype(str), - pixel_width=params['pixel_width'], - pixel_smoothing=params['pixel_smoothing'] + data[[x, y]], + data[feature_name].astype(str), + pixel_width=params["pixel_width"], + pixel_smoothing=params["pixel_smoothing"], ) - p2r.fit(num_clusters=params['num_clusters']) + p2r.fit(num_clusters=params["num_clusters"]) # Retrieve clustering results adata = p2r._get_anndata() - transcript_id_to_bin = dict(zip(adata.uns['reads'].index, adata.uns['reads']['pixel_ind'])) - data_all = sdata.points['transcripts'].compute().reset_index(drop=True) - data_all['segmentation_free_clusters'] = p2r.predict(output='marker')#.astype('category') - data_all['bin_id'] = data_all.index.map(transcript_id_to_bin) + transcript_id_to_bin = dict(zip(adata.uns["reads"].index, adata.uns["reads"]["pixel_ind"], strict=False)) + data_all = sdata.points["transcripts"].compute().reset_index(drop=True) + data_all["segmentation_free_clusters"] = p2r.predict(output="marker") # .astype('category') + data_all["bin_id"] = data_all.index.map(transcript_id_to_bin) - elif method == 'sainsc': + elif method == "sainsc": # Placeholder for another clustering method raise NotImplementedError("The 'sainsc' method is not yet implemented.") @@ -84,7 +84,7 @@ def segmentation_free_clustering( raise ValueError(f"Unknown method: {method}. Supported methods are 'points2regions' and 'sainsc'.") # Update the sdata object - sdata.points['transcripts'] = sd.models.PointsModel.parse(data_all) - sdata['segmentation_free_table'] = adata + sdata.points["transcripts"] = sd.models.PointsModel.parse(data_all) + sdata["segmentation_free_table"] = adata return adata if copy else None diff --git a/src/troutpy/tl/source_cell.py b/src/troutpy/tl/source_cell.py index 7cc8045..c7a7e6d 100644 --- a/src/troutpy/tl/source_cell.py +++ b/src/troutpy/tl/source_cell.py @@ -1,62 +1,56 @@ -import anndata as ad -import seaborn as sns -import matplotlib.pyplot as plt import numpy as np -from tqdm import tqdm import pandas as pd -import os -from sklearn.neighbors import KDTree import spatialdata as sd +from sklearn.neighbors import KDTree from spatialdata import SpatialData +from tqdm import tqdm + def create_xrna_metadata( - sdata: SpatialData, - points_layer: str = 'transcripts', - gene_key: str = 'feature_name', - copy: bool = False + sdata: SpatialData, points_layer: str = "transcripts", gene_key: str = "feature_name", copy: bool = False ) -> SpatialData | None: """ - Creates a new table within the SpatialData object that contains a 'gene' column + Creates a new table within the SpatialData object that contains a 'gene' column with the unique gene names extracted from the specified points layer. - Parameters: + Parameters ---------- sdata : SpatialData The SpatialData object to modify. - + points_layer : str, optional The name of the layer in `sdata.points` from which to extract gene names. Default is 'transcripts'. - + gene_key : str, optional The key in the `points_layer` dataframe that contains the gene names. Default is 'feature_name'. - + copy : bool, optional If `True`, returns a copy of the `SpatialData` object with the new table added. If `False`, modifies the original `SpatialData` object in place. Default is `False`. - Returns: + Returns ------- SpatialData | None If `copy` is `True`, returns a copy of the modified `SpatialData` object. Otherwise, returns `None`. - Raises: + Raises ------ ValueError If the specified points layer does not exist in `sdata.points`. If the `gene_key` column is not present in the specified points layer. - Examples: + Examples -------- Add a metadata table for genes in the 'transcripts' layer: - >>> create_xrna_metadata(sdata, points_layer='transcripts', gene_key='feature_name') + >>> create_xrna_metadata(sdata, points_layer="transcripts", gene_key="feature_name") Modify a custom SpatialData layer and return a copy: - >>> updated_sdata = create_xrna_metadata(sdata, points_layer='custom_layer', gene_key='gene_id', copy=True) + >>> updated_sdata = create_xrna_metadata(sdata, points_layer="custom_layer", gene_key="gene_id", copy=True) - Notes: + Notes ----- - The function uses `scanpy` to create an AnnData object and integrates it into the SpatialData table model. - The unique gene names are extracted from the specified points layer and stored in the `.var` of the AnnData object. @@ -64,40 +58,35 @@ def create_xrna_metadata( # Check if the specified points layer exists if points_layer not in sdata.points: raise ValueError(f"Points layer '{points_layer}' not found in sdata.points.") - + # Extract unique gene names from the specified points layer points_data = sdata.points[points_layer] if gene_key not in points_data.columns: raise ValueError(f"The specified points layer '{points_layer}' does not contain a '{gene_key}' column.") - + unique_genes = points_data[gene_key].compute().unique().astype(str) - + # Create a DataFrame for unique genes gene_metadata = pd.DataFrame(index=unique_genes) # Convert to AnnData and then to SpatialData table model exrna_adata = sc.AnnData(var=gene_metadata) metadata_table = sd.models.TableModel.parse(exrna_adata) - + # Add the new table to the SpatialData object - sdata.tables['xrna_metadata'] = metadata_table + sdata.tables["xrna_metadata"] = metadata_table print(f"Added 'xrna_metadata' table with {len(unique_genes)} unique genes to the SpatialData object.") - + # Return copy or modify in place return sdata if copy else None -def compute_source_cells( - sdata, - expression_threshold=1, - gene_id_column='feature_name', - layer='transcripts', - copy=False -): + +def compute_source_cells(sdata, expression_threshold=1, gene_id_column="feature_name", layer="transcripts", copy=False): """ Compute the source of extracellular RNA by linking detected extracellular transcripts to specific cell types in the spatial data. - Parameters: + Parameters ---------- sdata : SpatialData object The input spatial data object containing spatial transcriptomics data. @@ -110,53 +99,51 @@ def compute_source_cells( copy : bool, optional, default=False If True, returns a modified copy of the spatial data object. Otherwise, modifies in place. - Returns: + Returns ------- sdata : SpatialData object or None - The modified spatial data object with added `source` metadata if `copy=True`. + The modified spatial data object with added `source` metadata if `copy=True`. Otherwise, modifies the input object in place and returns None. """ - # Create a copy of the table containing spatial transcriptomics data - adata = sdata['table'].copy() - adata.X = adata.layers['raw'] # Use the 'raw' layer for calculations + adata = sdata["table"].copy() + adata.X = adata.layers["raw"] # Use the 'raw' layer for calculations # Generate a binary matrix where values above the threshold are set to True adata_bin = adata.copy() adata_bin.X = adata_bin.X > expression_threshold # Compute the proportion of cells expressing each feature per cell type - proportions = get_proportion_expressed_per_cell_type( - adata_bin, - cell_type_key='cell type' - ) + proportions = get_proportion_expressed_per_cell_type(adata_bin, cell_type_key="cell type") # Ensure the necessary `xrna_metadata` is present in `sdata` - if 'xrna_metadata' not in sdata: - create_xrna_metadata(sdata, points_layer='transcripts') + if "xrna_metadata" not in sdata: + create_xrna_metadata(sdata, points_layer="transcripts") # Create an output DataFrame and store computed proportions - outtable = pd.DataFrame(index=sdata['xrna_metadata'].var.index) - sdata['xrna_metadata'].varm['source'] = outtable.join(proportions).to_numpy() + outtable = pd.DataFrame(index=sdata["xrna_metadata"].var.index) + sdata["xrna_metadata"].varm["source"] = outtable.join(proportions).to_numpy() # Return the modified SpatialData object or None based on the `copy` parameter return sdata.copy() if copy else None + def distance_to_source_cell( - sdata, - layer='transcripts', - xcoord='x', - ycoord='y', - xcellcoord='x_centroid', - ycellcoord='y_centroid', - gene_id_column='feature_name', - copy=False + sdata, + layer="transcripts", + xcoord="x", + ycoord="y", + xcellcoord="x_centroid", + ycellcoord="y_centroid", + gene_id_column="feature_name", + copy=False, ): """Calculates the distance between extracellular RNA transcripts and their closest source cells. This function computes the distance from each extracellular RNA transcript to the nearest source cell based on their spatial coordinates. The function uses a KDTree to efficiently find the closest cell to each transcript, storing the results in the `sdata` object. - Parameters: + Parameters + ---------- sdata (AnnData): The AnnData object containing both transcript and cellular data. layer (str, optional): The layer in `sdata` containing the transcript data. Default is 'transcripts'. xcoord (str, optional): The column name in the transcript data for the x-coordinate. Default is 'x'. @@ -166,25 +153,26 @@ def distance_to_source_cell( gene_id_column (str, optional): The column name for the gene identifier. Default is 'feature_name'. copy (bool, optional): Whether to return a copy of the `sdata` object with updated distances, or modify in place. Default is False. - Returns: + Returns + ------- AnnData or None: If `copy` is True, returns the updated `sdata` object. Otherwise, modifies `sdata` in place and returns None. - Notes: - The function assumes that the transcript data contains a column `transcript_id` and that the cellular data contains + Notes + ----- + The function assumes that the transcript data contains a column `transcript_id` and that the cellular data contains cell centroids for spatial coordinates. The KDTree algorithm is used to compute the closest cell for each transcript. - The resulting distances are stored in the `distance_to_source_cell` column of the `sdata` object's transcript layer, + The resulting distances are stored in the `distance_to_source_cell` column of the `sdata` object's transcript layer, and the closest source cell is stored in the `closest_source_cell` column. The median distance for each gene is also added to the `xrna_metadata` in the `var` attribute of `sdata`. """ - # Extract transcript and cellular data - adata_bin = sdata['table'].copy() - adata_bin.X = sdata['table'].layers['raw'] - adata_bin.obs['x_centroid'] = [sp[0] for sp in adata_bin.obsm['spatial']] - adata_bin.obs['y_centroid'] = [sp[1] for sp in adata_bin.obsm['spatial']] + adata_bin = sdata["table"].copy() + adata_bin.X = sdata["table"].layers["raw"] + adata_bin.obs["x_centroid"] = [sp[0] for sp in adata_bin.obsm["spatial"]] + adata_bin.obs["y_centroid"] = [sp[1] for sp in adata_bin.obsm["spatial"]] transcripts = sdata.points[layer].compute() - extracellular_transcripts = transcripts[transcripts['extracellular']] - + extracellular_transcripts = transcripts[transcripts["extracellular"]] + # Initialize lists to store results tranid = [] dist = [] @@ -194,41 +182,46 @@ def distance_to_source_cell( for gene_of_interest in tqdm(adata_bin.var_names): gene_idx = np.where(adata_bin.var_names == gene_of_interest)[0][0] adata_filtered = adata_bin[adata_bin.X[:, gene_idx] > 0] - extracellular_transcripts_filtered = extracellular_transcripts[extracellular_transcripts[gene_id_column] == gene_of_interest].copy() + extracellular_transcripts_filtered = extracellular_transcripts[ + extracellular_transcripts[gene_id_column] == gene_of_interest + ].copy() # Only proceed if there are positive cells for the gene of interest if (adata_filtered.n_obs > 0) & (extracellular_transcripts_filtered.shape[0] > 0): # Extract coordinates of cells and transcripts cell_coords = np.array([adata_filtered.obs[xcellcoord], adata_filtered.obs[ycellcoord]]).T - transcript_coords = np.array([extracellular_transcripts_filtered[xcoord], extracellular_transcripts_filtered[ycoord]]).T + transcript_coords = np.array( + [extracellular_transcripts_filtered[xcoord], extracellular_transcripts_filtered[ycoord]] + ).T # Compute KDTree for nearest cell tree = KDTree(cell_coords) distances, closest_cells_indices = tree.query(transcript_coords, k=1) # Append results to lists - tranid.extend(extracellular_transcripts_filtered['transcript_id']) + tranid.extend(extracellular_transcripts_filtered["transcript_id"]) dist.extend([d[0] for d in distances]) - cell_ids = adata_filtered.obs['cell_id'].values[closest_cells_indices.flatten()] + cell_ids = adata_filtered.obs["cell_id"].values[closest_cells_indices.flatten()] cellids.extend(c[0] for c in cell_ids.reshape(closest_cells_indices.shape)) # Create a dictionary to map transcript IDs to distances and cell IDs - id2dist = dict(zip(tranid, dist)) - id2closeid = dict(zip(tranid, cellids)) + id2dist = dict(zip(tranid, dist, strict=False)) + id2closeid = dict(zip(tranid, cellids, strict=False)) # Store the results in the DataFrame - transcripts['distance_to_source_cell'] = transcripts['transcript_id'].map(id2dist) - transcripts['closest_source_cell'] = transcripts['transcript_id'].map(id2closeid) + transcripts["distance_to_source_cell"] = transcripts["transcript_id"].map(id2dist) + transcripts["closest_source_cell"] = transcripts["transcript_id"].map(id2closeid) sdata.points[layer] = sd.models.PointsModel.parse(transcripts) # Add median distance_to_source_cell - dist_to_source = transcripts.loc[:, [gene_id_column, 'distance_to_source_cell']].groupby(gene_id_column).median() - dist_to_source.columns = ['median_distance_to_source_cell'] - sdata['xrna_metadata'].var = sdata['xrna_metadata'].var.join(dist_to_source) + dist_to_source = transcripts.loc[:, [gene_id_column, "distance_to_source_cell"]].groupby(gene_id_column).median() + dist_to_source.columns = ["median_distance_to_source_cell"] + sdata["xrna_metadata"].var = sdata["xrna_metadata"].var.join(dist_to_source) return sdata.copy() if copy else None -def compute_distant_cells_prop(sdata, layer='transcripts', gene_id_column='feature_name', threshold=30,copy=False): + +def compute_distant_cells_prop(sdata, layer="transcripts", gene_id_column="feature_name", threshold=30, copy=False): """ Compute the proportion of transcripts for each gene that are located beyond a specified distance from their closest source cell, and add the result to the metadata of the SpatialData object. @@ -256,30 +249,31 @@ def compute_distant_cells_prop(sdata, layer='transcripts', gene_id_column='featu Example ------- ``` - compute_source_cells_beyond_distance(sdata, layer='transcripts', threshold=30) + compute_source_cells_beyond_distance(sdata, layer="transcripts", threshold=30) ``` """ - # Extract transcript data data = sdata.points[layer].compute() - + # Calculate the proportions of distances above the threshold - proportions_above_threshold = (data.groupby(gene_id_column)['distance_to_source_cell'] - .apply(lambda x: (x > threshold).mean())) - + proportions_above_threshold = data.groupby(gene_id_column)["distance_to_source_cell"].apply( + lambda x: (x > threshold).mean() + ) + # Create a DataFrame and rename the column proportions_above_threshold = pd.DataFrame(proportions_above_threshold) - proportions_above_threshold.columns = [f'frac_beyond_{threshold}_from_source'] - + proportions_above_threshold.columns = [f"frac_beyond_{threshold}_from_source"] + # Join the computed proportions with the metadata for column in proportions_above_threshold.columns: - if column in sdata['xrna_metadata'].var.columns: - sdata['xrna_metadata'].var=sdata['xrna_metadata'].var.drop([column],axis=1) - sdata['xrna_metadata'].var = sdata['xrna_metadata'].var.join(proportions_above_threshold) + if column in sdata["xrna_metadata"].var.columns: + sdata["xrna_metadata"].var = sdata["xrna_metadata"].var.drop([column], axis=1) + sdata["xrna_metadata"].var = sdata["xrna_metadata"].var.join(proportions_above_threshold) return sdata.copy() if copy else None -def get_proportion_expressed_per_cell_type(adata,cell_type_key='cell type'): + +def get_proportion_expressed_per_cell_type(adata, cell_type_key="cell type"): cell_types = adata.obs[cell_type_key].unique().dropna() proportions = pd.DataFrame(index=adata.var_names, columns=cell_types) for cell_type in cell_types: diff --git a/src/troutpy/tl/target_cell.py b/src/troutpy/tl/target_cell.py index cac740c..3ec3a10 100644 --- a/src/troutpy/tl/target_cell.py +++ b/src/troutpy/tl/target_cell.py @@ -1,38 +1,27 @@ import numpy as np import pandas as pd -import os -import scanpy as sc -import seaborn as sns -import matplotlib.pyplot as plt -from tqdm import tqdm import spatialdata as sd - - - -from typing import Optional -import numpy as np from tqdm import tqdm -import scanpy as sc -import spatialdata as sd + def calculate_target_cells( sdata: sd.SpatialData, - layer: str = 'transcripts', - xcoord: str = 'x', - ycoord: str = 'y', - xcellcoord: str = 'x_centroid', - ycellcoord: str = 'y_centroid', - celltype_key: str = 'cell type', - gene_id_key:str='feature_name', - copy: bool = False -) -> Optional[sd.SpatialData]: + layer: str = "transcripts", + xcoord: str = "x", + ycoord: str = "y", + xcellcoord: str = "x_centroid", + ycellcoord: str = "y_centroid", + celltype_key: str = "cell type", + gene_id_key: str = "feature_name", + copy: bool = False, +) -> sd.SpatialData | None: """ Calculate the closest target cell for each transcript in a spatial omics dataset. This function identifies the nearest cell to each transcript based on spatial coordinates and annotates the transcript data with the ID, cell type, and distance to the closest cell. - Parameters: + Parameters ---------- sdata : sd.SpatialData SpatialData object containing spatial and transcript data. @@ -53,21 +42,21 @@ def calculate_target_cells( copy : bool, optional If True, returns a copy of the modified SpatialData object. Default is False. - Returns: + Returns ------- Optional[sd.SpatialData] Modified SpatialData object with updated transcript annotations if `copy=True`. Otherwise, updates are made in place, and None is returned. """ # Copy AnnData object from the SpatialData table - adata = sdata['table'].copy() + adata = sdata["table"].copy() # Use the 'raw' layer for transcript data - adata.X = sdata['table'].layers['raw'] + adata.X = sdata["table"].layers["raw"] # Extract x and y centroid coordinates from cell data - adata.obs[xcellcoord] = [sp[0] for sp in adata.obsm['spatial']] - adata.obs[ycellcoord] = [sp[1] for sp in adata.obsm['spatial']] + adata.obs[xcellcoord] = [sp[0] for sp in adata.obsm["spatial"]] + adata.obs[ycellcoord] = [sp[1] for sp in adata.obsm["spatial"]] # Extract transcript data from the specified layer transcripts = sdata.points[layer].compute() @@ -87,34 +76,39 @@ def calculate_target_cells( distances[i] = np.min(dist) # Annotate the transcript DataFrame with the closest cell information - transcripts['closest_target_cell'] = adata.obs.index[closest_cells].values - transcripts['closest_target_cell_type'] = adata.obs[celltype_key].values[closest_cells] - transcripts['distance_to_target_cell'] = distances + transcripts["closest_target_cell"] = adata.obs.index[closest_cells].values + transcripts["closest_target_cell_type"] = adata.obs[celltype_key].values[closest_cells] + transcripts["distance_to_target_cell"] = distances # Update the SpatialData object with the modified transcript data sdata.points[layer] = sd.models.PointsModel.parse(transcripts) - - extracellular_transcripts = transcripts[transcripts['extracellular']] + + extracellular_transcripts = transcripts[transcripts["extracellular"]] # Compute cross-tabulation between features and cell types (raw counts) - celltype_by_feature_raw = pd.crosstab(extracellular_transcripts[gene_id_key], extracellular_transcripts['closest_target_cell_type']) + celltype_by_feature_raw = pd.crosstab( + extracellular_transcripts[gene_id_key], extracellular_transcripts["closest_target_cell_type"] + ) # Normalize by the total number of each feature (row-wise normalization) celltype_by_feature = celltype_by_feature_raw.div(celltype_by_feature_raw.sum(axis=1), axis=0) # Create an output DataFrame and store computed proportions - outtable = pd.DataFrame(index=sdata['xrna_metadata'].var.index) - sdata['xrna_metadata'].varm['target'] = outtable.join(celltype_by_feature).to_numpy() + outtable = pd.DataFrame(index=sdata["xrna_metadata"].var.index) + sdata["xrna_metadata"].varm["target"] = outtable.join(celltype_by_feature).to_numpy() # Return a copy of the modified SpatialData object if requested return sdata.copy() if copy else None -def define_target_by_celltype(sdata, layer='transcripts', closest_celltype_key='closest_target_cell_type', feature_key='feature_name'): + +def define_target_by_celltype( + sdata, layer="transcripts", closest_celltype_key="closest_target_cell_type", feature_key="feature_name" +): """ Computes the proportion of features (e.g., transcripts) associated with each cell type in the spatial dataset. This function calculates a cross-tabulation between features (e.g., extracellular transcripts) and cell types, and then normalizes the result to provide the proportion of each feature associated with each cell type. - Parameters: + Parameters ---------- sdata : dict-like spatial data object (with 'points' key) A spatial data object that contains transcript and cell type information. The relevant data is accessed from the @@ -130,19 +124,18 @@ def define_target_by_celltype(sdata, layer='transcripts', closest_celltype_key=' feature_key : str, optional The column name representing the feature (e.g., transcript or gene) in the transcript data (default: 'feature_name'). - Returns: + Returns ------- pd.DataFrame - A pandas DataFrame where the rows represent features (e.g., transcripts), and the columns represent cell types. + A pandas DataFrame where the rows represent features (e.g., transcripts), and the columns represent cell types. Each entry in the DataFrame is the proportion of that feature associated with the respective cell type. - Notes: + Notes ----- - The function uses `pd.crosstab` to compute the raw count of each feature for each cell type. - The resulting counts are normalized by the total count of each feature (i.e., row-wise) to produce proportions. - Useful for analyzing which cell types are associated with specific features in a spatial omics dataset. """ - # Extract transcript data from the specified layer transcripts = sdata.points[layer][[feature_key, closest_celltype_key]].compute() # Compute cross-tabulation between features and cell types (raw counts)