diff --git a/src/pycea/pl/plot_tree.py b/src/pycea/pl/plot_tree.py index 6f66320..606d897 100644 --- a/src/pycea/pl/plot_tree.py +++ b/src/pycea/pl/plot_tree.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from collections.abc import Mapping, Sequence import cycler @@ -72,6 +73,12 @@ def branches( ------- ax - The axes that the plot was drawn on. """ # noqa: D205 + # Setup + if not ax: + ax = plt.gca() + if (ax.name == "polar" and not polar) or (ax.name != "polar" and polar): + warnings.warn("Polar setting of axes does not match requested type. Creating new axes.", stacklevel=2) + fig, ax = plt.subplots(subplot_kw={"projection": "polar"} if polar else None) kwargs = kwargs if kwargs else {} if not key: key = next(iter(tdata.obst.keys())) @@ -116,18 +123,20 @@ def branches( else: raise ValueError("Invalid linewidth value. Must be int, float, or an str specifying an attribute of the edges.") # Plot - if not ax: - subplot_kw = {"projection": "polar"} if polar else None - fig, ax = plt.subplots(subplot_kw=subplot_kw) - elif (ax.name == "polar") != polar: - raise ValueError("Provided axis does not match the requested 'polar' setting.") ax.add_collection(LineCollection(zorder=1, **kwargs)) - # Configure plot - lat_lim = (-0.2, depth) - lon_lim = (0, 2 * np.pi) - ax.set_xlim(lon_lim if polar else lat_lim) - ax.set_ylim(lat_lim if polar else lon_lim) - ax.axis("off") + if polar: + ax.set_ylim((-depth * 0.05, depth * 1.05)) + ax.spines["polar"].set_visible(False) + else: + ax.set_ylim((-0.03 * np.pi, 2.03 * np.pi)) + ax.set_xlim((-depth * 0.05, depth * 1.05)) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.spines["left"].set_visible(False) + ax.spines["bottom"].set_visible(False) + ax.tick_params(length=0) + ax.set_xticks([]) + ax.set_yticks([]) ax._attrs = { "node_coords": node_coords, "leaves": leaves, @@ -388,19 +397,23 @@ def annotation( # Plot if attrs["polar"]: ax.pcolormesh(lons, lats, rgb_array.swapaxes(0, 1), zorder=2, **kwargs) - ax.set_ylim(-0.2, end_lat) + ax.set_ylim(-attrs["depth"] * 0.05, end_lat) else: ax.pcolormesh(lats, lons, rgb_array, zorder=2, **kwargs) - ax.set_xlim(-0.2, end_lat) - labels_lats = np.linspace(start_lat, end_lat, len(labels) + 1) - labels_lats = labels_lats + (end_lat - start_lat) / (len(labels) * 2) - for idx, label in enumerate(labels): - if is_array and len(labels) == 1: - ax.text(labels_lats[idx], -0.1, label, ha="center", va="top") - ax.set_ylim(-0.5, 2 * np.pi) - else: - ax.text(labels_lats[idx], -0.1, label, ha="center", va="top", rotation=90) - ax.set_ylim(-1, 2 * np.pi) + ax.set_xlim(-attrs["depth"] * 0.05, end_lat) + # Add labels + if labels and len(labels) > 0: + labels_lats = np.linspace(start_lat, end_lat, len(labels) + 1) + labels_lats = labels_lats + (end_lat - start_lat) / (len(labels) * 2) + existing_ticks = ax.get_xticks() + existing_labels = [label.get_text() for label in ax.get_xticklabels()] + ax.set_xticks(np.append(existing_ticks, labels_lats[:-1])) + ax.set_xticklabels(existing_labels + labels) + for label in ax.get_xticklabels()[len(existing_ticks) :]: + if is_array and len(labels) == 1: + label.set_rotation(0) + else: + label.set_rotation(90) ax._attrs.update({"offset": end_lat}) return ax diff --git a/tests/test_plot_tree.py b/tests/test_plot_tree.py index c1c9b1d..fc5c2fd 100755 --- a/tests/test_plot_tree.py +++ b/tests/test_plot_tree.py @@ -9,7 +9,7 @@ def test_polar_with_clades(tdata): - fig, ax = plt.subplots(dpi=600, subplot_kw={"polar": True}) + fig, ax = plt.subplots(dpi=300, subplot_kw={"polar": True}) pycea.pl.branches(tdata, key="tree", polar=True, color="clade", palette="Set1", na_color="black", ax=ax) pycea.pl.nodes(tdata, color="clade", palette="Set1", style="clade", ax=ax) pycea.pl.annotation(tdata, keys="clade", ax=ax) @@ -18,19 +18,18 @@ def test_polar_with_clades(tdata): def test_angled_numeric_annotations(tdata): - fig, ax = plt.subplots(dpi=600) pycea.pl.branches( - tdata, key="tree", polar=False, color="length", cmap="hsv", linewidth="length", angled_branches=True, ax=ax + tdata, key="tree", polar=False, color="length", cmap="hsv", linewidth="length", angled_branches=True ) - pycea.pl.nodes(tdata, nodes="all", color="time", style="s", size=20, ax=ax) - pycea.pl.annotation(tdata, keys=["x", "y"], cmap="magma", width=0.1, gap=0.05, ax=ax) - pycea.pl.annotation(tdata, keys=["0", "1", "2", "3", "4", "5"], label="genes", ax=ax) + pycea.pl.nodes(tdata, nodes="all", color="time", style="s", size=20) + pycea.pl.annotation(tdata, keys=["x", "y"], cmap="magma", width=0.1, gap=0.05) + pycea.pl.annotation(tdata, keys=["0", "1", "2", "3", "4", "5"], label="genes") plt.savefig(plot_path / "angled_numeric.png") plt.close() def test_matrix_annotation(tdata): - fig, ax = plt.subplots(dpi=600) + fig, ax = plt.subplots(dpi=300) pycea.pl.tree( tdata, key="tree", @@ -44,19 +43,19 @@ def test_matrix_annotation(tdata): plt.close() -def test_branches_invalid_input(tdata): +def test_branches_bad_input(tdata): fig, ax = plt.subplots() with pytest.raises(ValueError): pycea.pl.branches(tdata, key="tree", color=["bad"] * 5) with pytest.raises(ValueError): pycea.pl.branches(tdata, key="tree", linewidth=["bad"] * 5) - # Can't plot polar with non-polar axis - with pytest.raises(ValueError): + # Warns about polar + with pytest.warns(match="Polar"): pycea.pl.branches(tdata, key="tree", polar=True, ax=ax) plt.close() -def test_annotation_invalid_input(tdata): +def test_annotation_bad_input(tdata): # Need to plot branches first fig, ax = plt.subplots() with pytest.raises(ValueError):