diff --git a/docs/api/index.rst b/docs/api/index.rst index 903f4ea5a8..df6c166fdd 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -25,6 +25,7 @@ All methods and submodules are listed :ref:`here ` and default_geom physics mixing + viz/index unit_constant utilities @@ -38,5 +39,4 @@ All methods and submodules are listed :ref:`here ` and :caption: Advanced usage nodes - viz/index diff --git a/docs/api/viz/index.rst b/docs/api/viz/index.rst index 82f5010485..e8bc0c3019 100644 --- a/docs/api/viz/index.rst +++ b/docs/api/viz/index.rst @@ -4,12 +4,37 @@ Visualization ============= -.. currentmodule:: sisl.viz +.. module:: sisl.viz -Visualizations of `sisl` objects and data. +The visualization module contains tools to plot common visualizations, as well +as to create custom visualizations that support multiple plotting backends +automatically. +Plot classes +----------------- + +Plot classes are workflow classes that implement some specific plotting. + +.. autosummary:: + :toctree: generated/ + + Plot + BandsPlot + FatbandsPlot + GeometryPlot + SitesPlot + GridPlot + WavefunctionPlot + PdosPlot + +Utilities +--------- + +Utilities to build custom plots .. autosummary:: - :toctree: generated/ - :recursive: + :toctree: generated/ + get_figure + merge_plots + Figure diff --git a/docs/toolbox/siesta/generated/sisl_toolbox.siesta.atom.AtomInput.rst b/docs/toolbox/siesta/generated/sisl_toolbox.siesta.atom.AtomInput.rst new file mode 100644 index 0000000000..d2babd8802 --- /dev/null +++ b/docs/toolbox/siesta/generated/sisl_toolbox.siesta.atom.AtomInput.rst @@ -0,0 +1,50 @@ +sisl\_toolbox.siesta.atom.AtomInput +=================================== + +.. currentmodule:: sisl_toolbox.siesta.atom + +.. autoclass:: AtomInput + + + + .. rubric:: Methods + + .. autosummary:: + + + + ~AtomInput.ae + + + ~AtomInput.excite + + + ~AtomInput.from_input + + + ~AtomInput.from_yaml + + + ~AtomInput.pg + + + ~AtomInput.plot + + + ~AtomInput.write_all_electron + + + ~AtomInput.write_generation + + + ~AtomInput.write_header + + + ~AtomInput.write_test + + + + + + + \ No newline at end of file diff --git a/docs/toolbox/transiesta/generated/sisl_toolbox.transiesta.poisson.solve_poisson.rst b/docs/toolbox/transiesta/generated/sisl_toolbox.transiesta.poisson.solve_poisson.rst new file mode 100644 index 0000000000..1c7f3c0642 --- /dev/null +++ b/docs/toolbox/transiesta/generated/sisl_toolbox.transiesta.poisson.solve_poisson.rst @@ -0,0 +1,6 @@ +solve_poisson +============================================= + +.. currentmodule:: sisl_toolbox.transiesta.poisson + +.. autofunction:: solve_poisson \ No newline at end of file diff --git a/docs/visualization/viz_module/index.rst b/docs/visualization/viz_module/index.rst index 2c2f3ed271..bca26717d0 100644 --- a/docs/visualization/viz_module/index.rst +++ b/docs/visualization/viz_module/index.rst @@ -35,6 +35,7 @@ The following notebooks will help you develop a deeper understanding of what eac :name: viz-plotly-showcase-gallery showcase/GeometryPlot.ipynb + showcase/SitesPlot.ipynb showcase/GridPlot.ipynb showcase/BandsPlot.ipynb showcase/FatbandsPlot.ipynb diff --git a/docs/visualization/viz_module/showcase/FatbandsPlot.ipynb b/docs/visualization/viz_module/showcase/FatbandsPlot.ipynb index 772db951c9..14811a063f 100644 --- a/docs/visualization/viz_module/showcase/FatbandsPlot.ipynb +++ b/docs/visualization/viz_module/showcase/FatbandsPlot.ipynb @@ -119,7 +119,7 @@ "metadata": {}, "outputs": [], "source": [ - "fatbands.split_groups(on=\"species\")" + "fatbands.split_orbs(on=\"species\", name=\"$species\")" ] }, { diff --git a/docs/visualization/viz_module/showcase/SitesPlot.ipynb b/docs/visualization/viz_module/showcase/SitesPlot.ipynb new file mode 100644 index 0000000000..4ca58075eb --- /dev/null +++ b/docs/visualization/viz_module/showcase/SitesPlot.ipynb @@ -0,0 +1,160 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "tags": [ + "notebook-header" + ] + }, + "source": [ + "[![GitHub issues by-label](https://img.shields.io/github/issues-raw/pfebrer/sisl/SitesPlot?style=for-the-badge)](https://github.com/pfebrer/sisl/labels/SitesPlot)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " \n", + " \n", + "SitesPlot\n", + "=========\n", + "\n", + "The `SitesPlot` is simply an adaptation of `GeometryPlot`'s machinery to any class that can be represented as sites in space. The main difference is that it doesn't show bonds, and also inputs with the word `atoms` are renamed to `sites`. Therefore, see `GeometryPlot`'s showcase notebook to understand the full customization possibilities.\n", + "\n", + "We are just going to show how you can plot the k points of a `BrillouinZone` object with it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sisl\n", + "import sisl.viz\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's create a circle of K points:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sisl.geom.graphene()\n", + "\n", + "# Create the circle\n", + "bz = sisl.BrillouinZone.param_circle(\n", + " g,\n", + " kR=0.0085,\n", + " origin= [0.0, 0.0, 0.0],\n", + " normal= [0.0, 0.0, 1.0],\n", + " N_or_dk=25,\n", + " loop=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And then generate some fake vectorial data for it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data = np.zeros((len(bz), 3))\n", + "\n", + "data[:, 0] = - bz.k[:, 1]\n", + "data[:, 1] = bz.k[:, 0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And now plot the k points, showing the vectorial data as arrows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot k points as sites\n", + "bz.plot.sites(\n", + " axes=\"xy\", drawing_mode=\"line\", sites_style={\"color\": \"black\", \"size\": 2},\n", + " arrows={\"data\": data, \"color\": \"red\", \"width\": 3, \"name\": \"Force\"}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "-----\n", + "This next cell is just to create the thumbnail for the notebook in the docs " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "nbsphinx-thumbnail" + ] + }, + "outputs": [], + "source": [ + "thumbnail_plot = _\n", + "\n", + "if thumbnail_plot:\n", + " thumbnail_plot.show(\"png\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [ + "notebook-footer" + ] + }, + "source": [ + "-------------" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.15" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/sisl/nodes/node.py b/src/sisl/nodes/node.py index 468fc79b1f..c94635c937 100644 --- a/src/sisl/nodes/node.py +++ b/src/sisl/nodes/node.py @@ -76,7 +76,7 @@ class Node(NDArrayOperatorsMixin): _prev_evaluated_inputs: Dict[str, Any] # Current output value of the node - _output: Any + _output: Any = _blank # Nodes that are connected to this node's inputs _input_nodes: Dict[str, Node] @@ -116,7 +116,7 @@ def __call__(self, *args, **kwargs): def setup(self, *args, **kwargs): """Sets up the node based on its initial inputs.""" # Parse inputs into arguments. - bound_params = self.__class__.__signature__.bind_partial(*args, **kwargs) + bound_params = inspect.signature(self.function).bind_partial(*args, **kwargs) bound_params.apply_defaults() self._inputs = bound_params.arguments @@ -192,7 +192,6 @@ def __init_subclass__(cls): if parameter.kind == parameter.VAR_KEYWORD: cls._kwargs_inputs_key = key - cls.__init__.__signature__ = init_sig cls.__signature__ = no_self_sig return super().__init_subclass__() @@ -401,7 +400,7 @@ def get_tree(self): @property def default_inputs(self): - params = self.__class__.__signature__.bind_partial() + params = inspect.signature(self.function).bind_partial() params.apply_defaults() return params.arguments @@ -486,7 +485,7 @@ def update_inputs(self, **inputs): explicit_kwargs = inputs.pop(self._kwargs_inputs_key, None) # Parse the inputs. We do this to separate the kwargs from the rest of the inputs. - bound = self.__class__.__signature__.bind_partial(**inputs) + bound = inspect.signature(self.function).bind_partial(**inputs) inputs = bound.arguments # Now that we have parsed the inputs, put back the args key (if any). diff --git a/src/sisl/nodes/workflow.py b/src/sisl/nodes/workflow.py index 9f33af9d8f..77d7353544 100644 --- a/src/sisl/nodes/workflow.py +++ b/src/sisl/nodes/workflow.py @@ -737,6 +737,7 @@ def function(*args, **kwargs): ) def setup(self, *args, **kwargs): + self.nodes = self.dryrun_nodes super().setup(*args, **kwargs) self.nodes = self.dryrun_nodes.copy(inputs=self._inputs) diff --git a/src/sisl/viz/__init__.py b/src/sisl/viz/__init__.py index 239076d25b..81b391cbc3 100644 --- a/src/sisl/viz/__init__.py +++ b/src/sisl/viz/__init__.py @@ -5,13 +5,6 @@ Visualization utilities ======================= -Various visualization modules are described here. - - -Plotly -====== - -The plotly backend. """ import os diff --git a/src/sisl/viz/_plotables.py b/src/sisl/viz/_plotables.py index 816076608d..7dbd38e362 100644 --- a/src/sisl/viz/_plotables.py +++ b/src/sisl/viz/_plotables.py @@ -182,7 +182,11 @@ def register_plotable(plotable, plot_cls=None, setting_key=None, plotting_func=N # Register the function in the plot_handler plot_handler.register(name, plot_dispatch, default=default, **kwargs) -def register_data_source(data_source_cls, plot_cls, setting_key, name=None, default: Sequence[Type] = [], plot_handler_attr='plot', **kwargs): +def register_data_source( + data_source_cls, plot_cls, setting_key, name=None, default: Sequence[Type] = [], plot_handler_attr='plot', + data_source_init_kwargs: dict = {}, + **kwargs +): plot_cls_params = { name: param.replace(kind=inspect.Parameter.KEYWORD_ONLY) @@ -195,6 +199,15 @@ def register_data_source(data_source_cls, plot_cls, setting_key, name=None, defa signature = inspect.signature(func) + register_this = True + for k in data_source_init_kwargs.keys(): + if k not in signature.parameters: + register_this = False + break + + if not register_this: + continue + new_parameters = [] data_args = [] replaced_data_args = {} @@ -241,6 +254,10 @@ def _plot(obj, *args, __params_info=params_info, __signature=signature, **kwargs data_kwargs[data_key] = bound.arguments.pop(k) except Exception as e: raise TypeError(f"Error while parsing arguments to create the {data_source_cls.__name__}") + + for k, v in data_source_init_kwargs.items(): + if k not in data_kwargs: + data_kwargs[k] = v data = data_source_cls.new(obj, *args, **data_kwargs) diff --git a/src/sisl/viz/_plotables_register.py b/src/sisl/viz/_plotables_register.py index 067e7ae1c0..d66e150971 100644 --- a/src/sisl/viz/_plotables_register.py +++ b/src/sisl/viz/_plotables_register.py @@ -31,7 +31,7 @@ register_data_source(PDOSData, PdosPlot, "pdos_data", default=[siesta.pdosSileSiesta]) register_data_source(BandsData, BandsPlot, "bands_data", default=[siesta.bandsSileSiesta]) -register_data_source(FatBandsData, FatbandsPlot, "bands_data") +register_data_source(BandsData, FatbandsPlot, "bands_data", data_source_init_kwargs={"extra_vars": ("norm2", )}) register_data_source(EigenstateData, WavefunctionPlot, "eigenstate", default=[sisl.EigenstateElectron]) # ----------------------------------------------------- diff --git a/src/sisl/viz/splot.py b/src/sisl/viz/_splot.py similarity index 100% rename from src/sisl/viz/splot.py rename to src/sisl/viz/_splot.py diff --git a/src/sisl/viz/data/__init__.py b/src/sisl/viz/data/__init__.py index 1eb5fc62bc..38d64da2b2 100644 --- a/src/sisl/viz/data/__init__.py +++ b/src/sisl/viz/data/__init__.py @@ -1,4 +1,4 @@ -from .bands import BandsData, FatBandsData +from .bands import BandsData from .data import Data from .eigenstate import EigenstateData from .pdos import PDOSData diff --git a/src/sisl/viz/data/bands.py b/src/sisl/viz/data/bands.py index d10b273ae3..ad261ffb48 100644 --- a/src/sisl/viz/data/bands.py +++ b/src/sisl/viz/data/bands.py @@ -16,7 +16,7 @@ from .._single_dispatch import singledispatchmethod from ..data_sources import FileDataSIESTA, HamiltonianDataSource -from .xarray import OrbitalData, XarrayData +from .xarray import XarrayData try: import pathos @@ -113,20 +113,21 @@ def toy_example(cls, spin: Union[str, int, Spin] = "", n_states: int = 20, nk: i y = np.outer(x ** 2, random_polinomials[..., 0]) + np.outer(x, random_polinomials[..., 1]) + random_polinomials[..., 2].ravel() y = y.reshape(nk, *polynoms_shape) - if spin.is_polarized: - y = y.transpose(1, 0, 2) if spin.is_polarized: # Make sure that the top of the valence band and bottom of the conduction band - # are always spin 0 (to facilitate computation of the gap). - if y[0, ..., n_bands // 2 - 1].max() > y[1, ..., n_bands // 2 - 1].max(): - y[:, :, n_bands // 2 - 1] = y[::-1, :, n_bands // 2 - 1] - if y[0, ..., n_bands // 2].min() < y[1, ..., n_bands // 2].min(): - y[:, :, n_bands // 2] = y[::-1, :, n_bands // 2] + # are the same spin (to facilitate computation of the gap). + VB_spin = y[..., :n_bands // 2].argmin() // (nk * n_bands) + CB_spin = y[..., n_bands // 2:].argmax() // (nk * n_bands) + + if VB_spin != CB_spin: + y[..., n_bands // 2:] = np.flip(y[..., n_bands // 2:], axis=0) + + y = y.transpose(1, 0, 2) # Compute gap limits - top_VB = y[..., n_bands // 2 - 1].max() - bottom_CB = y[..., n_bands // 2].min() + top_VB = y[..., :n_bands // 2 ].max() + bottom_CB = y[..., n_bands // 2:].min() # Correct the gap if some specific value was requested generated_gap = bottom_CB - top_VB @@ -226,6 +227,12 @@ def from_dataset(cls, bands_data: xr.Dataset): **old_attrs, "spin": spin } + if "geometry" not in bands_data.attrs: + if "parent" in bands_data.attrs: + parent = bands_data.attrs["parent"] + if hasattr(parent, "geometry"): + bands_data.attrs['geometry'] = parent.geometry + return cls(bands_data) @new.register @@ -473,7 +480,7 @@ def from_wfsx(cls, wfsx_file: wfsxSileSiesta, fdf: str, extra_vars=(), need_H=Fa } ).assign_coords(coords_values) - bands_data.attrs = {"parent": bz, "spin": spin} + bands_data.attrs = {"parent": bz, "spin": spin, "geometry": geometry} return cls.new(bands_data) @@ -584,69 +591,29 @@ def bands_wrapper(eigenstate, spin_index): return bands_wrapper, all_vars, coords_values -def _weights_from_eigenstate(eigenstate, spin, spin_index): +def _norm2_from_eigenstate(eigenstate, spin, spin_index): - weights = eigenstate.norm2(sum=False) + norm2 = eigenstate.norm2(sum=False) if not spin.is_diagonal: # If it is a non-colinear or spin orbit calculation, we have two weights for each # orbital (one for each spin component of the state), so we just pair them together # and sum their contributions to get the weight of the orbital. - weights = weights.reshape(len(weights), -1, 2).sum(2) + norm2 = norm2.reshape(len(norm2), -1, 2).sum(2) - return weights.real + return norm2.real def _spin_moment_getter(eigenstate, spin, spin_index): return eigenstate.spin_moment().real _KNOWN_EIGENSTATE_VARS = { - "weight": { + "norm2": { "coords": ("band", "orb"), - "name": "weight", - "getter": _weights_from_eigenstate + "name": "norm2", + "getter": _norm2_from_eigenstate }, "spin_moment": { "coords": ("axis", "band"), "coords_values": dict(axis=["x", "y", "z"]), "name": "spin_moments", "getter": _spin_moment_getter } } - -class FatBandsData(BandsData, OrbitalData): - """A BandsData subclass that adds the possibility to calculate fatbands. - - It is a thin wrapper around BandsData that adds the possibility to - calculate fatbands.""" - - @singledispatchmethod - @classmethod - def new(cls, bands_data: xr.Dataset): - if "geometry" not in bands_data.attrs: - if "parent" in bands_data.attrs: - bands_data.attrs["geometry"] = bands_data.attrs["parent"].geometry - - return super().from_dataset(bands_data) - - @new.register - @classmethod - def from_path(cls, path: Path, *args, **kwargs): - """Creates a sile from the path and tries to read the PDOS from it.""" - return cls.new(sisl.get_sile(path), *args, **kwargs) - - @new.register - @classmethod - def from_string(cls, string: str, *args, **kwargs): - """Assumes the string is a path to a file""" - return cls.new(Path(string), *args, **kwargs) - - - @new.register - @classmethod - def from_hamiltonian(cls, band_structure: sisl.BandStructure, H: Union[sisl.Hamiltonian, None] = None, extra_vars: Sequence[Union[Dict, str]] = ()): - extra_vars = (*extra_vars, "weight") - return super().from_hamiltonian(band_structure=band_structure, H=H, extra_vars=extra_vars) - - @new.register - @classmethod - def from_wfsx(cls, wfsx_file: wfsxSileSiesta, fdf: str, extra_vars: Sequence[Union[Dict, str]] = (), need_H: bool = False): - extra_vars = (*extra_vars, "weight") - return super().from_wfsx(wfsx_file=wfsx_file, fdf=fdf, extra_vars=extra_vars, need_H=need_H) diff --git a/src/sisl/viz/figure/figure.py b/src/sisl/viz/figure/figure.py index bcbb9e9612..bf2d9a32fb 100644 --- a/src/sisl/viz/figure/figure.py +++ b/src/sisl/viz/figure/figure.py @@ -194,28 +194,6 @@ def _iter_subplots(self, plot_actions): def _iter_animation(self, plot_actions): raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a _iter_animation method.") - - # def _ipython_display_(self, return_figWidget=False, **kwargs): - # """ Handles all things needed to display the plot in a jupyter notebook. - - # Parameters - # ------ - # return_figureWidget: bool, optional - # if the plot is displayed in a jupyter notebook, whether you want to - # get the figure widget as a return so that you can act on it. - # """ - # from IPython.display import display - - # widget = self.get_ipywidget() - - # # Else, show without shortcut support - # display(widget) - - # if return_figWidget: - # return widget - - def get_ipywidget(self): - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a get_ipywidget method.") def clear(self): """Clears the figure so that we can draw again.""" @@ -239,7 +217,7 @@ def init_coloraxis(self, name, cmin=None, cmax=None, cmid=None, colorscale=None, } def draw_line(self, x, y, name=None, line={}, marker={}, text=None, row=None, col=None, **kwargs): - """Should draw a line satisfying the specifications + """Draws a line satisfying the specifications Parameters ----------- @@ -270,7 +248,7 @@ def draw_line(self, x, y, name=None, line={}, marker={}, text=None, row=None, co raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_line method.") def draw_multicolor_line(self, *args, line={}, row=None, col=None, **kwargs): - """By default, multicoloured lines are drawn simply by drawing the points.""" + """By default, multicoloured lines are drawn simply by drawing scatter points.""" marker = { **kwargs.pop('marker', {}), 'color': line.get('color'), @@ -281,7 +259,7 @@ def draw_multicolor_line(self, *args, line={}, row=None, col=None, **kwargs): self.draw_multicolor_scatter(*args, marker=marker, row=row, col=col, **kwargs) def draw_multisize_line(self, *args, line={}, row=None, col=None, **kwargs): - """By default, multisized lines are drawn simple by drawing .""" + """By default, multisized lines are drawn simple by drawing scatter points.""" marker = { **kwargs.pop('marker', {}), 'color': line.get('color'), @@ -310,6 +288,9 @@ def draw_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=Non contains the text asigned to each marker. On plotly this is seen on hover, other options could be annotating. However, it is not necessary that this argument is supported. + dependent_axis: str, optional + The axis that contains the dependent variable. This is important because + the area is drawn in parallel to that axis. row: int, optional If the figure contains subplots, the row where to draw. col: int, optional @@ -321,7 +302,7 @@ def draw_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=Non raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_area_line method.") def draw_multicolor_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=None, row=None, col=None, **kwargs): - """Same as draw line, but to draw a line with an area. This is for example used to draw fatbands. + """Draw a line with an area with multiple colours. Parameters ----------- @@ -339,6 +320,9 @@ def draw_multicolor_area_line(self, x, y, name=None, line={}, text=None, depende contains the text asigned to each marker. On plotly this is seen on hover, other options could be annotating. However, it is not necessary that this argument is supported. + dependent_axis: str, optional + The axis that contains the dependent variable. This is important because + the area is drawn in parallel to that axis. row: int, optional If the figure contains subplots, the row where to draw. col: int, optional @@ -350,7 +334,9 @@ def draw_multicolor_area_line(self, x, y, name=None, line={}, text=None, depende raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_multicolor_area_line method.") def draw_multisize_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=None, row=None, col=None, **kwargs): - """Same as draw line, but to draw a line with an area. This is for example used to draw fatbands. + """Draw a line with an area with multiple colours. + + This is already usually supported by the normal draw_area_line. Parameters ----------- @@ -368,6 +354,9 @@ def draw_multisize_area_line(self, x, y, name=None, line={}, text=None, dependen contains the text asigned to each marker. On plotly this is seen on hover, other options could be annotating. However, it is not necessary that this argument is supported. + dependent_axis: str, optional + The axis that contains the dependent variable. This is important because + the area is drawn in parallel to that axis. row: int, optional If the figure contains subplots, the row where to draw. col: int, optional @@ -380,7 +369,7 @@ def draw_multisize_area_line(self, x, y, name=None, line={}, text=None, dependen return self.draw_area_line(x, y, name=name, line=line, text=text, dependent_axis=dependent_axis, row=row, col=col, **kwargs) def draw_scatter(self, x, y, name=None, marker={}, text=None, row=None, col=None, **kwargs): - """Should draw a scatter satisfying the specifications + """Draws a scatter satisfying the specifications Parameters ----------- @@ -409,10 +398,18 @@ def draw_scatter(self, x, y, name=None, marker={}, text=None, row=None, col=None raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_scatter method.") def draw_multicolor_scatter(self, *args, **kwargs): + """Draws a multicoloured scatter. + + Usually the normal scatter can already support this. + """ # Usually, multicoloured scatter plots are already supported. return self.draw_scatter(*args, **kwargs) def draw_multisize_scatter(self, *args, **kwargs): + """Draws a multisized scatter. + + Usually the normal scatter can already support this. + """ # Usually, multisized scatter plots are already supported. return self.draw_scatter(*args, **kwargs) @@ -432,6 +429,8 @@ def draw_arrows(self, x, y, dxy, arrowhead_scale=0.2, arrowhead_angle=20, scale: scale: float, optional multiplying factor to display the arrows. It does not affect the underlying data, therefore if the data is somehow displayed it should be without the scale factor. + annotate: + whether to annotate the arrows with the vector they represent. row: int, optional If the figure contains subplots, the row where to draw. col: int, optional @@ -480,7 +479,7 @@ def draw_arrows(self, x, y, dxy, arrowhead_scale=0.2, arrowhead_angle=20, scale: return self.draw_line(arrows[:, 0], arrows[:, 1], hovertext=list(hovertext), row=row, col=col, **kwargs) def draw_line_3D(self, x, y, z, name=None, line={}, marker={}, text=None, row=None, col=None, **kwargs): - """Should draw a 3D line satisfying the specifications + """Draws a 3D line satisfying the specifications. Parameters ----------- @@ -513,13 +512,15 @@ def draw_line_3D(self, x, y, z, name=None, line={}, marker={}, text=None, row=No raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_line_3D method.") def draw_multicolor_line_3D(self, *args, **kwargs): + """Draws a multicoloured 3D line.""" self.draw_line_3D(*args, **kwargs) def draw_multisize_line_3D(self, *args, **kwargs): + """Draws a multisized 3D line.""" self.draw_line_3D(*args, **kwargs) def draw_scatter_3D(self, x, y, z, name=None, marker={}, text=None, row=None, col=None, **kwargs): - """Should draw a 3D scatter satisfying the specifications + """Draws a 3D scatter satisfying the specifications Parameters ----------- @@ -549,17 +550,30 @@ def draw_scatter_3D(self, x, y, z, name=None, marker={}, text=None, row=None, co raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_scatter_3D method.") def draw_multicolor_scatter_3D(self, *args, **kwargs): + """Draws a multicoloured 3D scatter. + + Usually the normal 3D scatter can already support this. + """ # Usually, multicoloured scatter plots are already supported. return self.draw_scatter_3D(*args, **kwargs) def draw_multisize_scatter_3D(self, *args, **kwargs): + """Draws a multisized 3D scatter. + + Usually the normal 3D scatter can already support this. + """ # Usually, multisized scatter plots are already supported. return self.draw_scatter_3D(*args, **kwargs) def draw_balls_3D(self, x, y, z, name=None, markers={}, row=None, col=None, **kwargs): + """Draws points as 3D spheres.""" return NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_balls_3D method.") def draw_multicolor_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, **kwargs): + """Draws points as 3D spheres with different colours. + + If marker_color is an array of numbers, a coloraxis is created and values are converted to rgb. + """ kwargs['marker'] = marker.copy() @@ -572,16 +586,24 @@ def draw_multicolor_balls_3D(self, x, y, z, name=None, marker={}, row=None, col= return self.draw_balls_3D(x, y, z, name=name, row=row, col=col, **kwargs) def draw_multisize_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, **kwargs): + """Draws points as 3D spheres with different sizes. + + Usually supported by the normal draw_balls_3D + """ return self.draw_balls_3D(x, y, z, name=name, row=row, col=col, **kwargs) def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_scale=0.3, arrowhead_angle=15, scale: float = 1, row=None, col=None, **kwargs): - """Draws multiple arrows using the generic draw_line method. + """Draws multiple 3D arrows using the generic draw_line_3D method. Parameters ----------- - xy: np.ndarray of shape (n_arrows, 2) - the positions where the atoms start. - dxy: np.ndarray of shape (n_arrows, 2) + x: np.ndarray of shape (n_arrows, ) + the X coordinates of the arrow's origin. + y: np.ndarray of shape (n_arrows, ) + the Y coordinates of the arrow's origin. + z: np.ndarray of shape (n_arrows, ) + the Z coordinates of the arrow's origin. + dxyz: np.ndarray of shape (n_arrows, 2) the arrow vector. arrow_head_scale: float, optional how big is the arrow head in comparison to the arrow vector. @@ -649,9 +671,11 @@ def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_scale=0.3, arrowhead_angle=15, return self.draw_line_3D(arrows[:, 0], arrows[:, 1], arrows[:, 2], row=row, col=col, **kwargs) def draw_heatmap(self, values, x=None, y=None, name=None, zsmooth=False, coloraxis=None, row=None, col=None): + """Draws a heatmap following the specifications.""" raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_heatmap method.") def draw_mesh_3D(self, vertices, faces, color=None, opacity=None, name=None, row=None, col=None, **kwargs): + """Draws a 3D mesh following the specifications.""" raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_mesh_3D method.") def set_axis(self, **kwargs): @@ -672,8 +696,7 @@ def to(self, key: str): Parameters ----------- key: str - the format to convert to. This can be any of the formats supported by - the `plotly.io` module. + the backend to convert to. """ return BACKENDS[key](self.plot_actions) diff --git a/src/sisl/viz/figure/plotly.py b/src/sisl/viz/figure/plotly.py index 9ad879c489..264aa8a2c1 100644 --- a/src/sisl/viz/figure/plotly.py +++ b/src/sisl/viz/figure/plotly.py @@ -253,18 +253,6 @@ def _iter_animation(self, plot_actions): self.update_layout(sliders=[slider], updatemenus=updatemenus) - def _figure_animate_method(self, children, frame_names): - """ - In the animate method, we explicitly define frames, And the transition from one to the other - will be animated - """ - # Here are some things that were settings - - - - - return steps, updatemenus - def __getattr__(self, key): if key != "figure": return getattr(self.figure, key) @@ -293,266 +281,6 @@ def clear(self, frames=True, layout=False): return self - def get_ipywidget(self): - return go.FigureWidget(self.figure, ) - - def _update_ipywidget(self, fig_widget): - """ Updates a figure widget so that it is in sync with this plot's data - - Parameters - ---------- - fig_widget: plotly.graph_objs.FigureWidget - The figure widget that we need to extend. - """ - fig_widget.data = [] - fig_widget.add_traces(self.data) - fig_widget.layout = self.layout - fig_widget.update(frames=self.frames) - - #------------------------------------------- - # PLOT MANIPULATION METHODS - #------------------------------------------- - - def group_legend(self, by=None, names=None, show_all=False, extra_updates=None, **kwargs): - """ Joins plot traces in groups in the legend - - As the result of this method, plot traces end up with a legendgroup attribute. - You can use that for selecting traces in further processing of your plot. - - This also provides the ability to toggle the whole group from the legend, which is nice. - - Parameters - --------- - by: str or function, optional - it defines what are the criteria to group the traces. - - If it's a string: - It is the name of the trace attribute. Remember that plotly allows you to - lookup for nested attributes using underscores. E.g: "line_color" gets {line: {color: THIS VALUE}} - If it's a function: - It will recieve each trace and needs to decide which group to put it in by returning the group value. - Note that the value will also be used as the group name if `names` is not provided, so you can save yourself - some code and directly return the group's name. - If not provided: - All traces will be put in the same group - names: array-like, dict or function, optional - it defines what the names of the generated groups will be. - - If it's an array: - When a new group is found, the name will be taken from this array (order can be very arbitrary) - If it's a dict: - When a new group is found, the value of the group will be used as a key to get the name from this dictionary. - If the key is not found, the name will just be the value. - E.g.: If grouping by `line_color` and `blue` is found, the name will be `names.get('blue', 'blue')` - If it's a function: - It will recieve the group value and the trace and needs to return the name of the TRACE. - NOTE: If `show_all` is set to `True` all traces will appear in the legend, so it would be nice - to give them different names. Otherwise, you can just return the group's name. - If you provided a grouping function and `show_all` is False you don't need this, as you can return - directly the group name from there. - If not provided: - the values will be used as names. - show_all: boolean, optional - whether all the items of the group should be displayed in the legend. - If `False`, only one item per group will be displayed. - If `True`, all the items of the group will be displayed. - extra_updates: dict, optional - A dict stating extra updates that you want to do for each group. - - E.g.: `{"blue": {"line_width": 4}}` - - would also convert the lines with a group VALUE (not name) of "blue" to a width of 4. - - This is just for convenience so that you can run other methods after this one. - Note that you can always do something like this by doing - - ``` - plot.update_traces( - selector={"line_width": "blue"}, # Selects the traces that you will update - line_width=4, - ) - ``` - - If you use a function to return the group values, there is probably no point on using this - argument. Since you recieve the trace, you can run `trace.update(...)` inside your function. - **kwargs: - like extra_updates but they are passed to all groups without distinction - """ - unique_values = [] - - # Normalize the "by" parameter to a function - if by is None: - if show_all: - name = names[0] if names is not None else "Group" - self.figure.update_traces(showlegend=True, legendgroup=name, name=name) - return self - else: - func = lambda trace: 0 - if isinstance(by, str): - def func(trace): - try: - return trace[by] - except Exception: - return None - else: - func = by - - # Normalize also the names parameter to a function - if names is None: - def get_name(val, trace): - return str(val) if not show_all else f'{val}: {trace.name}' - elif callable(names): - get_name = names - elif isinstance(names, dict): - def get_name(val, trace): - name = names.get(val, val) - return str(name) if not show_all else f'{name}: {trace.name}' - else: - def get_name(val, trace): - name = names[len(unique_values) - 1] - return str(name) if not show_all else f'{name}: {trace.name}' - - # And finally normalize the extra updates - if extra_updates is None: - get_extra_updates = lambda *args, **kwargs: {} - elif isinstance(extra_updates, dict): - get_extra_updates = lambda val, trace: extra_updates.get(val, {}) - elif callable(extra_updates): - get_extra_updates = extra_updates - - # Build the function that will apply the change - def check_and_apply(trace): - - val = func(trace) - - if isinstance(val, np.ndarray): - val = val.tolist() - if isinstance(val, list): - val = ", ".join([str(item) for item in val]) - - if val in unique_values: - showlegend = show_all - else: - unique_values.append(val) - showlegend = True - - customdata = trace.customdata if trace.customdata is not None else [{}] - - trace.update( - showlegend=showlegend, - legendgroup=str(val), - name=get_name(val, trace=trace), - customdata=[{**customdata[0], "name": trace.name}, *customdata[1:]], - **get_extra_updates(val, trace=trace), - **kwargs - ) - - # And finally apply all the changes - self.figure.for_each_trace( - lambda trace: check_and_apply(trace) - ) - - return self - - def ungroup_legend(self): - """ Ungroups traces if a legend contains groups """ - self.figure.for_each_trace( - lambda trace: trace.update( - legendgroup=None, - showlegend=True, - name=trace.customdata[0]["name"] - ) - ) - - return self - - def normalize(self, min_val=0, max_val=1, axis="y", **kwargs): - """ Normalizes traces to a given range along an axis - - Parameters - ----------- - min_val: float, optional - The lower bound of the range. - max_val: float, optional - The upper part of the range - axis: {"x", "y", "z"}, optional - The axis along which we want to normalize. - **kwargs: - keyword arguments that are passed directly to plotly's Figure `for_each_trace` - method. You can check its documentation. One important thing is that you can pass a - 'selector', which will choose if the trace is updated or not. - """ - from ..plotutils import normalize_trace - - self.for_each_trace(partial(normalize_trace, min_val=min_val, max_val=max_val, axis=axis), **kwargs) - - return self - - def swap_axes(self, ax1='x', ax2='y', **kwargs): - """ Swaps two axes in the plot - - Parameters - ----------- - ax1, ax2: str, {'x', 'x*', 'y', 'y*', 'z', 'z*'} - The names of the axes that you want to swap. - **kwargs: - keyword arguments that are passed directly to plotly's Figure `for_each_trace` - method. You can check its documentation. One important thing is that you can pass a - 'selector', which will choose if the trace is updated or not. - """ - from ..plotutils import swap_trace_axes - - # Swap the traces - self.for_each_trace(partial(swap_trace_axes, ax1=ax1, ax2=ax2), **kwargs) - - # Try to also swap the axes - try: - self.update_layout({ - f'{ax1}axis': self.layout[f'{ax2}axis'].to_plotly_json(), - f'{ax2}axis': self.layout[f'{ax1}axis'].to_plotly_json(), - }, overwrite=True) - except: - pass - - return self - - def shift(self, shift, axis="y", **kwargs): - """ Shifts the traces of the plot by a given value in the given axis - - Parameters - ----------- - shift: float or array-like - If it's a float, it will be a solid shift (i.e. all points moved equally). - If it's an array, an element-wise sum will be performed - axis: {"x","y","z"}, optional - The axis along which we want to shift the traces. - **kwargs: - keyword arguments that are passed directly to plotly's Figure `for_each_trace` - method. You can check its documentation. One important thing is that you can pass a - 'selector', which will choose if the trace is updated or not. - """ - from ..plotutils import shift_trace - - self.for_each_trace(partial(shift_trace, shift=shift, axis=axis), **kwargs) - - return self - - # ----------------------------- - # SOME OTHER METHODS - # ----------------------------- - - def to_chart_studio(self, *args, **kwargs): - """ Sends the plot to chart studio if it is possible - - For it to work, the user should have their credentials correctly set up. - - It is a shortcut for chart_studio.plotly.plot(self.figure, ...etc) so you can pass any extra arguments as if - you were using `py.plot` - """ - import chart_studio.plotly as py - - return py.plot(self.figure, *args, **kwargs) - # -------------------------------- # METHODS TO STANDARIZE BACKENDS # -------------------------------- @@ -652,9 +380,9 @@ def draw_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=Non "x": x, "y": y, "line": {"width": 0, "color": line.get('color')}, - #"showlegend": is_group_first and i_chunk == 0, "name": name, "legendgroup": name, + "showlegend": kwargs.pop("showlegend", None), "fill": "toself" }, row=row, col=col) diff --git a/src/sisl/viz/plot.py b/src/sisl/viz/plot.py index 77c939f19f..ddfc50329d 100644 --- a/src/sisl/viz/plot.py +++ b/src/sisl/viz/plot.py @@ -1,8 +1,8 @@ from sisl.messages import deprecate from sisl.nodes import Workflow - class Plot(Workflow): + """Base class for all plots""" def __getattr__(self, key): if key != "nodes": diff --git a/src/sisl/viz/plots/__init__.py b/src/sisl/viz/plots/__init__.py index e6b3100c54..88f2eaf27f 100644 --- a/src/sisl/viz/plots/__init__.py +++ b/src/sisl/viz/plots/__init__.py @@ -1,8 +1,5 @@ """ -Plots -===== - -Module containing all implemented plots, both in a functional form and as Workflows. +Module containing all sisl-provided plots, both in a functional form and as Workflows. """ from .bands import BandsPlot, FatbandsPlot, bands_plot, fatbands_plot diff --git a/src/sisl/viz/plots/bands.py b/src/sisl/viz/plots/bands.py index c0643e65f5..77b8bbd053 100644 --- a/src/sisl/viz/plots/bands.py +++ b/src/sisl/viz/plots/bands.py @@ -1,10 +1,10 @@ -from typing import Literal, Optional, Tuple +from typing import Literal, Optional, Tuple, Sequence, Dict import numpy as np -from sisl.viz.types import OrbitalQueries +from sisl.viz.types import OrbitalQueries, StyleSpec -from ..data.bands import BandsData, FatBandsData +from ..data.bands import BandsData from ..figure import Figure, get_figure from ..plot import Plot from ..plotters.plot_actions import combined @@ -19,11 +19,60 @@ def bands_plot(bands_data: BandsData, - Erange: Optional[Tuple[float, float]] = None, E0=0, E_axis: Literal["x", "y"] = "y", bands_range=None, spin=None, - bands_style={'color': 'black', 'width': 1, "opacity": 1}, spindown_style={"color": "blue", "width": 1}, colorscale=None, - gap=False, gap_tol=0.01, gap_color="red", gap_marker={"size": 7}, direct_gaps_only=False, custom_gaps=[], - line_mode: Literal["line", "scatter", "area_line"] = "line", backend: str = "plotly" + Erange: Optional[Tuple[float, float]] = None, E0: float = 0., E_axis: Literal["x", "y"] = "y", + bands_range: Optional[Tuple[int, int]] = None, spin: Optional[Literal[0, 1]] = None, + bands_style: StyleSpec = {'color': 'black', 'width': 1, "opacity": 1}, + spindown_style: StyleSpec = {"color": "blue", "width": 1}, + colorscale: Optional[str] = None, + gap: bool = False, gap_tol: float = 0.01, gap_color: str = "red", gap_marker: dict = {"size": 7}, direct_gaps_only: bool = False, + custom_gaps: Sequence[Dict] = [], + line_mode: Literal["line", "scatter", "area_line"] = "line", + backend: str = "plotly" ) -> Figure: + """Plots band structure energies, with plentiful of customization options. + + Parameters + ---------- + bands_data: + The object containing the data to plot. + Erange: + The energy range to plot. + If None, the range is determined by ``bands_range``. + E0: + The energy reference. + E_axis: + Axis to plot the energies. + bands_range: + The bands to plot. Only used if ``Erange`` is None. + If None, the 15 bands above and below the Fermi level are plotted. + spin: + Which spin channel to display. Only meaningful for spin-polarized calculations. + If None and the calculation is spin polarized, both are plotted. + bands_style: + Styling attributes for bands. + spindown_style: + Styling attributes for the spin down bands (if present). Any missing attribute + will be taken from ``bands_style``. + colorscale: + Colorscale to use for the bands in case the color attribute is an array of values. + If None, the default colorscale is used for each backend. + gap: + Whether to display the gap. + gap_tol: + Tolerance in k for determining whether two gaps are the same. + gap_color: + Color of the gap. + gap_marker: + Marker styles for the gap (as plotly marker's styles). + direct_gaps_only: + Whether to only display direct gaps. + custom_gaps: + List of custom gaps to display. See the showcase notebooks for examples. + line_mode: + The method used to draw the band lines. + backend: + The backend to use to generate the figure. + """ bands_data = accept_data(bands_data, cls=BandsData, check=True) @@ -49,21 +98,98 @@ def bands_plot(bands_data: BandsData, return get_figure(backend=backend, plot_actions=all_plottings) -def default_random_color(x): +def _default_random_color(x): return x.get("color") or random_color() + +def _group_traces(actions): + + seen_groups = [] + + new_actions = [] + for action in actions: + if action["method"].startswith("draw_"): + group = action["kwargs"].get("name") + action = action.copy() + action['kwargs']['legendgroup'] = group + + if group in seen_groups: + action["kwargs"]["showlegend"] = False + else: + seen_groups.append(group) + + new_actions.append(action) + + return new_actions + + # I keep the fatbands plot here so that one can see how similar they are. # I am yet to find a nice solution for extending workflows. -def fatbands_plot(bands_data: FatBandsData, - Erange: Optional[Tuple[float, float]] = None, E0=0, E_axis: Literal["x", "y"] = "y", bands_range=None, spin=None, - bands_style={'color': 'black', 'width': 1, "opacity": 1}, spindown_style={"color": "blue", "width": 1}, - gap=False, gap_tol=0.01, gap_color="red", gap_marker={"size": 7}, direct_gaps_only=False, custom_gaps=[], - bands_mode: Literal["line", "scatter", "area_line"] = "line", +def fatbands_plot(bands_data: BandsData, + Erange: Optional[Tuple[float, float]] = None, E0: float = 0., E_axis: Literal["x", "y"] = "y", + bands_range: Optional[Tuple[int, int]] = None, spin: Optional[Literal[0, 1]] = None, + bands_style: StyleSpec = {'color': 'black', 'width': 1, "opacity": 1}, + spindown_style: StyleSpec = {"color": "blue", "width": 1}, + gap: bool = False, gap_tol: float = 0.01, gap_color: str = "red", gap_marker: dict = {"size": 7}, direct_gaps_only: bool = False, + custom_gaps: Sequence[Dict] = [], + bands_mode: Literal["line", "scatter", "area_line"] = "line", # Fatbands inputs - groups: OrbitalQueries = [], fatbands_mode: Literal["line", "scatter", "area_line"] = "area_line", - fatbands_scale: float = 1., backend: str = "plotly" + groups: OrbitalQueries = [], + fatbands_var: str = "norm2", + fatbands_mode: Literal["line", "scatter", "area_line"] = "area_line", + fatbands_scale: float = 1., + backend: str = "plotly" ) -> Figure: - + """Plots band structure energies showing the contribution of orbitals to each state. + + Parameters + ---------- + bands_data: + The object containing the data to plot. + Erange: + The energy range to plot. + If None, the range is determined by ``bands_range``. + E0: + The energy reference. + E_axis: + Axis to plot the energies. + bands_range: + The bands to plot. Only used if ``Erange`` is None. + If None, the 15 bands above and below the Fermi level are plotted. + spin: + Which spin channel to display. Only meaningful for spin-polarized calculations. + If None and the calculation is spin polarized, both are plotted. + bands_style: + Styling attributes for bands. + spindown_style: + Styling attributes for the spin down bands (if present). Any missing attribute + will be taken from ``bands_style``. + gap: + Whether to display the gap. + gap_tol: + Tolerance in k for determining whether two gaps are the same. + gap_color: + Color of the gap. + gap_marker: + Marker styles for the gap (as plotly marker's styles). + direct_gaps_only: + Whether to only display direct gaps. + custom_gaps: + List of custom gaps to display. See the showcase notebooks for examples. + bands_mode: + The method used to draw the band lines. + groups: + Orbital groups to plots. See showcase notebook for examples. + fatbands_var: + The variable to use from bands_data to determine the width of the fatbands. + This variable must have as coordinates (k, band, orb, [spin]). + fatbands_mode: + The method used to draw the fatbands. + fatbands_scale: + Factor that scales the size of all fatbands. + backend: + The backend to use to generate the figure. + """ bands_data = accept_data(bands_data, cls=BandsData, check=True) # Filter the bands @@ -76,7 +202,7 @@ def fatbands_plot(bands_data: FatBandsData, orbital_manager = get_orbital_queries_manager( bands_data, key_gens={ - "color": default_random_color, + "color": _default_random_color, } ) fatbands_data = reduce_orbital_data( @@ -84,7 +210,7 @@ def fatbands_plot(bands_data: FatBandsData, group_vars=('color', 'dash'), groups_dim="group", drop_empty=True, spin_reduce=np.sum, ) - scaled_fatbands_data = scale_variable(fatbands_data, var="weight", scale=fatbands_scale, default_value=1, allow_not_present=True) + scaled_fatbands_data = scale_variable(fatbands_data, var=fatbands_var, scale=fatbands_scale, default_value=1, allow_not_present=True) # Determine what goes on each axis x = matches(E_axis, "x", ret_true="E", ret_false="k") @@ -93,7 +219,11 @@ def fatbands_plot(bands_data: FatBandsData, sanitized_fatbands_mode = matches(groups, [], ret_true="none", ret_false=fatbands_mode) # Get the actions to plot lines - fatbands_plottings = draw_xarray_xy(data=scaled_fatbands_data, x=x, y=y, color="color", width="weight", what=sanitized_fatbands_mode, dependent_axis=E_axis) + fatbands_plottings = draw_xarray_xy( + data=scaled_fatbands_data, x=x, y=y, color="color", width=fatbands_var, what=sanitized_fatbands_mode, dependent_axis=E_axis, + name="group" + ) + grouped_fatbands_plottings = _group_traces(fatbands_plottings) bands_plottings = draw_xarray_xy(data=styled_bands, x=x, y=y, set_axrange=True, what=bands_mode, dependent_axis=E_axis) # Gap calculation @@ -101,7 +231,7 @@ def fatbands_plot(bands_data: FatBandsData, # Plot it if the user has asked for it. gaps_plottings = draw_gaps(bands_data, gap, gap_info, gap_tol, gap_color, gap_marker, direct_gaps_only, custom_gaps, E_axis=E_axis) - all_plottings = combined(fatbands_plottings, bands_plottings, gaps_plottings, composite_method=None) + all_plottings = combined(grouped_fatbands_plottings, bands_plottings, gaps_plottings, composite_method=None) return get_figure(backend=backend, plot_actions=all_plottings) diff --git a/src/sisl/viz/plots/geometry.py b/src/sisl/viz/plots/geometry.py index 8100834daf..9a56467f50 100644 --- a/src/sisl/viz/plots/geometry.py +++ b/src/sisl/viz/plots/geometry.py @@ -18,6 +18,7 @@ add_xyz_to_dataset, bonds_to_lines, sites_obj_to_geometry, + get_sites_units, find_all_bonds, parse_atoms_style, sanitize_arrows, @@ -252,8 +253,8 @@ def sites_plot( Parameters ---------- - bz: - The brillouin zone object containing the k points to plot. + sites_obj: + The object to be converted to sites. axes: The axes to project the sites to. sites: @@ -300,7 +301,8 @@ def sites_plot( filtered_sites = select(sites_dataset, "atom", sanitized_sites) tiled_sites = tile_data_sc(filtered_sites, nsc=nsc) sc_sites = stack_sc_data(tiled_sites, newname="sc_atom", dims=["atom"]) - projected_sites = project_to_axes(sc_sites, axes=axes, sort_by_depth=True, dataaxis_1d=dataaxis_1d) + sites_units = get_sites_units(sites_obj) + projected_sites = project_to_axes(sc_sites, axes=axes, sort_by_depth=True, dataaxis_1d=dataaxis_1d, cartesian_units=sites_units) sites_scale = _sanitize_scale(sites_scale, ndim, sites_ndim_scale) final_sites = scale_variable(projected_sites, "size", scale=sites_scale) diff --git a/src/sisl/viz/plots/grid.py b/src/sisl/viz/plots/grid.py index a14cb525c5..c522eb65f2 100644 --- a/src/sisl/viz/plots/grid.py +++ b/src/sisl/viz/plots/grid.py @@ -33,7 +33,7 @@ from .geometry import geometry_plot -def get_structure_plottings(plot_geom, geometry, axes, nsc, geom_kwargs={},): +def _get_structure_plottings(plot_geom, geometry, axes, nsc, geom_kwargs={},): if plot_geom: geom_kwargs = ChainMap(geom_kwargs, {"axes": axes, "geometry": geometry, "nsc": nsc, "show_cell": False}) plot_actions = geometry_plot(**geom_kwargs).plot_actions @@ -42,14 +42,90 @@ def get_structure_plottings(plot_geom, geometry, axes, nsc, geom_kwargs={},): return plot_actions -def grid_plot(grid: Grid = None, axes: Axes = ["z"], represent: Literal["real", "imag", "mod", "phase", "deg_phase", "rad_phase"] = "real", - transforms: Sequence[Union[str, Callable]] = (), reduce_method: Literal["average", "sum"] = "average", boundary_mode: str = "grid-wrap", - nsc: Sequence[int] = [1,1,1], interp: Sequence[int] = [1, 1, 1], isos: Sequence[dict] = [], smooth: bool = False, - colorscale: Optional[str] = None, crange: Optional[Tuple[float, float]] = None, cmid: Optional[float] = None, - show_cell: Literal["box", "axes", False] = "box", cell_style: dict = {}, - x_range: Optional[Sequence[float]] = None, y_range: Optional[Sequence[float]] = None, z_range: Optional[Sequence[float]] = None, - plot_geom: bool = False, geom_kwargs: dict = {}, backend: str = "plotly" +def grid_plot( + grid: Optional[Grid] = None, + axes: Axes = ["z"], + represent: Literal["real", "imag", "mod", "phase", "deg_phase", "rad_phase"] = "real", + transforms: Sequence[Union[str, Callable]] = (), + reduce_method: Literal["average", "sum"] = "average", + boundary_mode: str = "grid-wrap", + nsc: Tuple[int, int, int] = (1, 1, 1), + interp: Tuple[int, int, int] = (1, 1, 1), + isos: Sequence[dict] = [], + smooth: bool = False, + colorscale: Optional[str] = None, + crange: Optional[Tuple[float, float]] = None, + cmid: Optional[float] = None, + show_cell: Literal["box", "axes", False] = "box", + cell_style: dict = {}, + x_range: Optional[Sequence[float]] = None, + y_range: Optional[Sequence[float]] = None, + z_range: Optional[Sequence[float]] = None, + plot_geom: bool = False, + geom_kwargs: dict = {}, + backend: str = "plotly" ) -> Figure: + """Plots a grid, with plentiful of customization options. + + Parameters + ---------- + grid: + The grid to plot. + axes: + The axes to project the grid to. + represent: + The representation of the grid to plot. + transforms: + List of transforms to apply to the grid before plotting. + reduce_method: + The method used to reduce the grid axes that are not displayed. + boundary_mode: + The method used to deal with the boundary conditions. + Only used if the grid is to be orthogonalized. + See scipy docs for more info on the possible values. + nsc: + The number of unit cells to display in each direction. + interp: + The interpolation factor to use for each axis to make the grid smoother. + isos: + List of isosurfaces or isocontours to plot. See the showcase notebooks for examples. + smooth: + Whether to ask the plotting backend to make an attempt at smoothing the grid display. + colorscale: + Colorscale to use for the grid display in the 2D representation. + If None, the default colorscale is used for each backend. + crange: + Min and max values for the colorscale. + cmid: + The value at which the colorscale is centered. + show_cell: + Method used to display the unit cell. If False, the cell is not displayed. + cell_style: + Style specification for the cell. See the showcase notebooks for examples. + x_range: + The range of the x axis to take into account. + Even if the X axis is not displayed! This is important because the reducing + operation will only be applied on this range. + y_range: + The range of the y axis to take into account. + Even if the Y axis is not displayed! This is important because the reducing + operation will only be applied on this range. + z_range: + The range of the z axis to take into account. + Even if the Z axis is not displayed! This is important because the reducing + operation will only be applied on this range. + plot_geom: + Whether to plot the associated geometry (if any). + geom_kwargs: + Keyword arguments to pass to the geometry plot of the associated geometry. + backend: + The backend to use to generate the figure. + + See also + ---------- + scipy.ndimage.affine_transform : method used to orthogonalize the grid if needed. + """ + axes = sanitize_axes(axes) @@ -82,23 +158,109 @@ def grid_plot(grid: Grid = None, axes: Axes = ["z"], represent: Literal["real", ) # And maybe plot the strucuture - geom_plottings = get_structure_plottings(plot_geom=plot_geom, geometry=geometry, geom_kwargs=geom_kwargs, axes=axes, nsc=nsc) + geom_plottings = _get_structure_plottings(plot_geom=plot_geom, geometry=geometry, geom_kwargs=geom_kwargs, axes=axes, nsc=nsc) all_plottings = combined(grid_plottings, cell_plottings, geom_plottings, composite_method=None) return get_figure(backend=backend, plot_actions=all_plottings) -def wavefunction_plot(eigenstate: EigenstateData, i: int = 0, geometry: Optional[Geometry] = None, grid_prec: float = 0.2, - +def wavefunction_plot( + eigenstate: EigenstateData, + i: int = 0, + geometry: Optional[Geometry] = None, + grid_prec: float = 0.2, # All grid inputs. - grid: Optional[Grid] = None, axes: Axes = ["x", "y", "z"], represent: Literal["real", "imag", "mod", "phase", "deg_phase", "rad_phase"] = "real", - transforms: Sequence[Union[str, Callable]] = (), reduce_method: Literal["average", "sum"] = "average", boundary_mode: str = "grid-wrap", - nsc: Sequence[int] = [1,1,1], interp: Sequence[int] = [1, 1, 1], isos: Sequence[dict] = [], smooth: bool = False, - colorscale: Optional[str] = None, crange: Optional[Tuple[float, float]] = None, cmid: Optional[float] = None, - show_cell: Literal["box", "axes", False] = "box", cell_style: dict = {}, - x_range: Optional[Sequence[float]] = None, y_range: Optional[Sequence[float]] = None, z_range: Optional[Sequence[float]] = None, - plot_geom: bool = True, geom_kwargs: dict = {}, backend: str = "plotly" + grid: Optional[Grid] = None, + axes: Axes = ["z"], + represent: Literal["real", "imag", "mod", "phase", "deg_phase", "rad_phase"] = "real", + transforms: Sequence[Union[str, Callable]] = (), + reduce_method: Literal["average", "sum"] = "average", + boundary_mode: str = "grid-wrap", + nsc: Tuple[int, int, int] = (1, 1, 1), + interp: Tuple[int, int, int] = (1, 1, 1), + isos: Sequence[dict] = [], + smooth: bool = False, + colorscale: Optional[str] = None, + crange: Optional[Tuple[float, float]] = None, + cmid: Optional[float] = None, + show_cell: Literal["box", "axes", False] = "box", + cell_style: dict = {}, + x_range: Optional[Sequence[float]] = None, + y_range: Optional[Sequence[float]] = None, + z_range: Optional[Sequence[float]] = None, + plot_geom: bool = False, + geom_kwargs: dict = {}, + backend: str = "plotly" ) -> Figure: + """Plots a wavefunction in real space. + + Parameters + ---------- + eigenstate: + The eigenstate object containing information about eigenstates. + i: + The index of the eigenstate to plot. + geometry: + Geometry to use to project the eigenstate to real space. + If None, the geometry associated with the eigenstate is used. + grid_prec: + The precision of the grid where the wavefunction is projected. + grid: + The grid to plot. + axes: + The axes to project the grid to. + represent: + The representation of the grid to plot. + transforms: + List of transforms to apply to the grid before plotting. + reduce_method: + The method used to reduce the grid axes that are not displayed. + boundary_mode: + The method used to deal with the boundary conditions. + Only used if the grid is to be orthogonalized. + See scipy docs for more info on the possible values. + nsc: + The number of unit cells to display in each direction. + interp: + The interpolation factor to use for each axis to make the grid smoother. + isos: + List of isosurfaces or isocontours to plot. See the showcase notebooks for examples. + smooth: + Whether to ask the plotting backend to make an attempt at smoothing the grid display. + colorscale: + Colorscale to use for the grid display in the 2D representation. + If None, the default colorscale is used for each backend. + crange: + Min and max values for the colorscale. + cmid: + The value at which the colorscale is centered. + show_cell: + Method used to display the unit cell. If False, the cell is not displayed. + cell_style: + Style specification for the cell. See the showcase notebooks for examples. + x_range: + The range of the x axis to take into account. + Even if the X axis is not displayed! This is important because the reducing + operation will only be applied on this range. + y_range: + The range of the y axis to take into account. + Even if the Y axis is not displayed! This is important because the reducing + operation will only be applied on this range. + z_range: + The range of the z axis to take into account. + Even if the Z axis is not displayed! This is important because the reducing + operation will only be applied on this range. + plot_geom: + Whether to plot the associated geometry (if any). + geom_kwargs: + Keyword arguments to pass to the geometry plot of the associated geometry. + backend: + The backend to use to generate the figure. + + See also + ---------- + scipy.ndimage.affine_transform : method used to orthogonalize the grid if needed. + """ # Create a grid with the wavefunction in it. i_eigenstate = get_eigenstate(eigenstate, i) @@ -138,7 +300,7 @@ def wavefunction_plot(eigenstate: EigenstateData, i: int = 0, geometry: Optional ) # And maybe plot the strucuture - geom_plottings = get_structure_plottings(plot_geom=plot_geom, geometry=tiled_geometry, geom_kwargs=geom_kwargs, axes=axes, nsc=grid_nsc) + geom_plottings = _get_structure_plottings(plot_geom=plot_geom, geometry=tiled_geometry, geom_kwargs=geom_kwargs, axes=axes, nsc=grid_nsc) all_plottings = combined(grid_plottings, cell_plottings, geom_plottings, composite_method=None) diff --git a/src/sisl/viz/plots/orbital_groups_plot.py b/src/sisl/viz/plots/orbital_groups_plot.py index 631fb5927b..4f26884409 100644 --- a/src/sisl/viz/plots/orbital_groups_plot.py +++ b/src/sisl/viz/plots/orbital_groups_plot.py @@ -171,7 +171,6 @@ def split_groups(self, *i_or_names, on="species", only=None, exclude=None, remov groups = [] for req in reqs: - new_groups = queries_manager._split_query( req, on=on, only=only, exclude=exclude, ignore_constraints=ignore_constraints, **kwargs @@ -186,3 +185,32 @@ def split_groups(self, *i_or_names, on="species", only=None, exclude=None, remov groups = [*old_groups, *groups] return self.update_inputs(**{self._orbital_groups_input_key: groups}) + + def split_orbs(self, on="species", only=None, exclude=None, clean=True, **kwargs): + """ + Splits the orbitals into different groups. + + Parameters + -------- + on: str, {"species", "atoms", "Z", "orbitals", "n", "l", "m", "zeta", "spin"}, or list of str + the parameter to split along. + Note that you can combine parameters with a "+" to split along multiple parameters + at the same time. You can get the same effect also by passing a list. + only: array-like, optional + if desired, the only values that should be plotted out of + all of the values that come from the splitting. + exclude: array-like, optional + values that should not be plotted + clean: boolean, optional + whether the plot should be cleaned before drawing. + If False, all the requests that come from the method will + be drawn on top of what is already there. + **kwargs: + keyword arguments that go directly to each request. + + This is useful to add extra filters. For example: + `plot.split_orbs(on="orbitals", species=["C"])` + will split on the different orbitals but will take + only those that belong to carbon atoms. + """ + return self.split_groups(on=on, only=only, exclude=exclude, clean=clean, **kwargs) diff --git a/src/sisl/viz/plots/pdos.py b/src/sisl/viz/plots/pdos.py index e3725618f7..2d6a8b90d4 100644 --- a/src/sisl/viz/plots/pdos.py +++ b/src/sisl/viz/plots/pdos.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Literal, Optional, Sequence +from typing import Any, Literal, Optional, Sequence, Tuple import numpy as np @@ -18,11 +18,35 @@ def pdos_plot( - pdos_data: PDOSData, groups: Sequence[OrbitalStyleQuery]=[{"name": "DOS"}], - Erange=[-2, 2], E_axis: Literal["x", "y"] = "x", line_mode: Literal["line", "scatter", "area_line"] = "line", - line_scale: float = 1., backend: str = "plotly", + pdos_data: PDOSData, + groups: Sequence[OrbitalStyleQuery]=[{"name": "DOS"}], + Erange: Tuple[float, float] = (-2, 2), + E_axis: Literal["x", "y"] = "x", + line_mode: Literal["line", "scatter", "area_line"] = "line", + line_scale: float = 1., + backend: str = "plotly", ) -> Figure: - """Plots PDOS""" + """Plot the projected density of states. + + Parameters + ---------- + pdos_data: + The object containing the raw PDOS data (individual PDOS for each orbital/spin). + groups: + List of orbital specifications to filter and accumulate the PDOS. + The contribution of each group will be displayed in a different line. + See showcase notebook for examples. + Erange: + The energy range to plot. + E_axis: + Axis to project the energies. + line_mode: + Mode used to draw the PDOS lines. + line_scale: + Scaling factor for the width of all lines. + backend: + The backend to generate the figure. + """ pdos_data = accept_data(pdos_data, cls=PDOSData, check=True) E_PDOS = filter_energy_range(pdos_data, Erange=Erange, E0=0) diff --git a/src/sisl/viz/plotters/__init__.py b/src/sisl/viz/plotters/__init__.py index 77c130adba..9a5ac8cc0b 100644 --- a/src/sisl/viz/plotters/__init__.py +++ b/src/sisl/viz/plotters/__init__.py @@ -1 +1,3 @@ +"""Functions that generate plot actions to be passed to figures.""" + from . import plot_actions \ No newline at end of file diff --git a/src/sisl/viz/plotters/plot_actions.py b/src/sisl/viz/plotters/plot_actions.py index 83ab91be52..69f224f60a 100644 --- a/src/sisl/viz/plotters/plot_actions.py +++ b/src/sisl/viz/plotters/plot_actions.py @@ -1,3 +1,5 @@ +"""Contains all the individual actions that can be performed on a figure.""" + import functools import inspect import sys @@ -22,7 +24,8 @@ def a(*args, __method_name__=function.__name__, **kwargs): return dict(method=__method_name__, args=args, kwargs=kwargs) a.__signature__ = sig.replace(parameters=list(sig.parameters.values())[1:]) - + a.__module__ = module + setattr(module, name, a) _register_actions(Figure) diff --git a/src/sisl/viz/plotters/xarray.py b/src/sisl/viz/plotters/xarray.py index 677e9d5c1d..c5922eba4f 100644 --- a/src/sisl/viz/plotters/xarray.py +++ b/src/sisl/viz/plotters/xarray.py @@ -216,12 +216,18 @@ def drawing_function(*args, **kwargs): fixed_coords_values = {k: arr.values for k, arr in fixed_coords.items()} single_line = len(data.iterate_dim) == 1 - name_prefix = f"{name}_" if name and not single_line else name + if name in data.iterate_dim.coords: + name_prefix = "" + else: + name_prefix = f"{name}_" if name and not single_line else name # Now just iterate over each line and plot it. for values, *styles in iterator: + names = values.iterate_dim.values[()] - if single_line and not isinstance(names[0], str): + if name in values.iterate_dim.coords: + line_name = f"{name_prefix}{values.iterate_dim.coords[name].values[()]}" + elif single_line and not isinstance(names[0], str): line_name = name_prefix elif len(names) == 1: line_name = f"{name_prefix}{names[0]}" diff --git a/src/sisl/viz/plotutils.py b/src/sisl/viz/plotutils.py index 7070e28168..e66bc3b0d6 100644 --- a/src/sisl/viz/plotutils.py +++ b/src/sisl/viz/plotutils.py @@ -109,16 +109,14 @@ def get_plot_classes(): list all the plot classes that the module is aware of. """ - from . import Animation, MultiplePlot, Plot, SubPlots + from . import Plot def get_all_subclasses(cls): all_subclasses = [] for Subclass in cls.__subclasses__(): - - if Subclass not in [MultiplePlot, Animation, SubPlots] and not getattr(Subclass, 'is_only_base', False): - all_subclasses.append(Subclass) + all_subclasses.append(Subclass) all_subclasses.extend(get_all_subclasses(Subclass)) @@ -179,122 +177,6 @@ def get_plotable_variables(variables): return plotables -def get_configurable_docstring(cls): - """ Builds the docstring for a class that inherits from Configurable - - Parameters - ----------- - cls: - the class you want the docstring for - - Returns - ----------- - str: - the docs with the settings added. - """ - import re - - if isinstance(cls, type): - params = cls._parameters - doc = cls.__doc__ - if doc is None: - doc = "" - else: - # It's really an instance, not the class - params = cls.params - doc = "" - - configurable_settings = "\n".join([param._get_docstring() for param in params]) - - html_cleaner = re.compile('<.*?>') - configurable_settings = re.sub(html_cleaner, '', configurable_settings) - - if "Parameters\n--" not in doc: - doc += f'\n\nParameters\n-----------\n{configurable_settings}' - else: - doc += f'\n{configurable_settings}' - - return doc - - -def get_configurable_kwargs(cls_or_inst, fake_default): - """ Builds a string to help you define all the kwargs coming from the settings. - - The main point is to avoid wasting time writing all the kwargs manually, and - at the same time makes it easy to keep it consistent with the defaults. - - This may be useful, for example, for the __init__ method of plots. - - Parameters - ------------ - cls_or_inst: - the class (or instance) you want the kwargs for. - fake_default: str - only floats, ints, bools and strings can be parsed safely into strings and then into values again. - For this reason, the rest of the settings will just be given a fake default that you need to handle. - - Returns - ----------- - str: - the string containing the described kwargs. - """ - # TODO why not just repr(val)? that seems to be the same in all cases? - def get_string(val): - if isinstance(val, (float, int, bool)) or val is None: - return val - elif isinstance(val, str): - return val.__repr__() - else: - return fake_default.__repr__() - - if isinstance(cls_or_inst, type): - params = cls_or_inst._parameters - return ", ".join([f'{param.key}={get_string(param.default)}' for param in params]) - - # It's really an instance, not the class - # In this case, the defaults for the method will be the current values. - params = cls_or_inst.params - return ", ".join([f'{param.key}={get_string(cls_or_inst.settings[param.key])}' for param in params]) - - -def get_configurable_kwargs_to_pass(cls): - """ Builds a string to help you pass kwargs that you got from the function `get_configurable_kwargs`. - - E.g.: If `get_configurable_kwargs` gives you 'param1=None, param2="nothing"' - `get_configurable_kwargs_to_pass` will give you param1=param1, param2=param2 - - Parameters - ------------ - cls: - the class you want the kwargs for - - Returns - ----------- - str: - the string containing the described kwargs. - """ - if isinstance(cls, type): - params = cls._parameters - else: - # It's really an instance, not the class - params = cls.params - - return ", ".join([f'{param.key}={param.key}' for param in params]) - - -def get_session_classes(): - """ Returns the available session classes - - Returns - -------- - dict - keys are the name of the class and values are the class itself. - """ - from .session import Session - - return {sbcls.__name__: sbcls for sbcls in Session.__subclasses__()} - - def get_avail_presets(): """ Gets the names of the currently available presets. @@ -412,96 +294,10 @@ def dictOfLists2listOfDicts(dictOfLists): return [dict(zip(dictOfLists, t)) for t in zip(*dictOfLists.values())] -def call_method_if_present(obj, method_name, *args, **kwargs): - """ Calls a method of the object if it is present. - - If the method is not there, it just does nothing. - - Parameters - ----------- - method_name: str - the name of the method that you want to call. - *args and **kwargs: - arguments passed to the method call. - """ - - method = getattr(obj, method_name, None) - if callable(method): - return method(*args, **kwargs) - - -def copy_params(params, only=(), exclude=()): - """ Function that returns a copy of the provided plot parameters. - - Arguments - ---------- - params: tuple - The parameters that have to be copied. This will come presumably from the "_parameters" variable of some plot class. - only: array-like - Use this if you only want a certain set of parameters. Pass the wanted keys as a list. - exclude: array-like - Use this if there are some parameters that you don't want. Pass the unwanted keys as a list. - This argument will not be used if "only" is present. - - Returns - ---------- - copiedParams: tuple - The params that the user asked for. They are not linked to the input params, so they can be modified independently. - """ - if only: - return tuple(param for param in deepcopy(params) if param.key in only) - return tuple(param for param in deepcopy(params) if param.key not in exclude) - - -def copy_dict(dictInst, only=(), exclude=()): - """ Function that returns a copy of a dict. This function is thought to be used for the settings dictionary, for example. - - Arguments - ---------- - dictInst: dict - The dictionary that needs to be copied. - only: array-like - Use this if you only want a certain set of values. Pass the wanted keys as a list. - exclude: array-like - Use this if there are some values that you don't want. Pass the unwanted keys as a list. - This argument will not be used if "only" is present. - - Returns - ---------- - copiedDict: dict - The dictionary that the user asked for. It is not linked to the input dict, so it can be modified independently. - """ - if only: - return {k: v for k, v in deepcopy(dictInst).iteritems() if k in only} - return {k: v for k, v in deepcopy(dictInst).iteritems() if k not in exclude} - #------------------------------------- # Filesystem #------------------------------------- - -def load(path): - """ - Loads a previously saved python object using pickle. To be used for plots, sessions, etc... - - Arguments - ---------- - path: str - The path to the saved object. - - Returns - ---------- - loadedObj: object - The object that was saved. - """ - import dill - - with open(path, 'rb') as handle: - loadedObj = dill.load(handle) - - return loadedObj - - def find_files(root_dir=Path("."), search_string = "*", depth = [0, 0], sort = True, sort_func = None, case_insensitive=False): """ Function that finds files (or directories) according to some conditions. @@ -597,338 +393,6 @@ def find_plotable_siles(dir_path=None, depth=0): return files -#------------------------------------- -# Multiprocessing -#------------------------------------- - -_MAX_NPROCS = get_environ_variable("SISL_VIZ_NUM_PROCS") - - -def _apply_method(args_tuple): - """ Apply a method to an object. This function is meant for multiprocessing """ - - method, obj, args, kwargs = args_tuple - - if args is None: - args = [] - - method(obj, *args, **kwargs) - - return obj - - -def _init_single_plot(args_tuple): - """ Initialize a single plot. This function is meant to be used in multiprocessing, when multiple plots need to be initialized """ - - PlotClass, args, kwargs = args_tuple - - return PlotClass(**kwargs) - - -def run_multiple(func, *args, argsList = None, kwargsList = None, messageFn = None, serial = False): - """ - Makes use of the pathos.multiprocessing module to run a function simultanously multiple times. - This is meant mainly to update multiple plots at the same time, which can accelerate significantly the process of visualizing data. - - All arguments passed to the function, except func, can be passed as specified in the arguments section of this documentation - or as a list containing multiple instances of them. - If a list is passed, each time the function needs to be run it will take the next item of the list. - If a single item is passed instead, this item will be repeated for each function run. - However, at least one argument must be a list, so that the number of times that the function has to be ran is defined. - - Arguments - ---------- - func: function - The function to be executed. It has to be prepared to recieve the arguments as they are provided to it (zipped). - - See the applyMethod() function as an example. - *args: - Contains all the arguments that are specific to the individual function that we want to run. - See each function separately to understand what you need to pass (you may not need this parameter). - argsList: array-like - An array of arguments that have to be passed to the executed function. - - Can also be a list of arrays (see this function's description). - - WARNING: Currently it only works properly for a list of arrays. Didn't fix this because the lack of interest - of argsList on Plot's methods (everything is passed as keyword arguments). - kwargsList: dict - A dictionary with the keyword arguments that have to be passed to the executed function. - - If the executed function is a Plot's method, these can be the settings, for example. - - Can also be a list of dicts (see this function's description). - - messageFn: function - Function that recieves the number of tasks and nodes and needs to return a string to display as a description of the progress bar. - serial: bool - If set to true, multiprocessing is not used. - - This seems to have little sense, but it is useful to switch easily between multiprocessing and serial with the same code. - - Returns - ---------- - results: list - A list with all the returned values or objects from each function execution. - This list is ordered, so results[0] is the result of executing the function with argsList[0] and kwargsList[0]. - """ - #Prepare the arguments to be passed to the initSinglePlot function - toZip = [*args, argsList, kwargsList] - for i, arg in enumerate(toZip): - if not isinstance(arg, (list, tuple, np.ndarray)): - toZip[i] = itertools.repeat(arg) - else: - nTasks = len(arg) - - # Run things in serial mode in case it is demanded or pathos is not available - serial = not pathos_avail or serial or _MAX_NPROCS == 1 or nTasks == 1 - if serial: - return [func(argsTuple) for argsTuple in zip(*toZip)] - - #Create a pool with the appropiate number of processes - pool = Pool(min(nTasks, _MAX_NPROCS)) - #Define the plots array to store all the plots that we initialize - results = [None]*nTasks - - #Initialize the pool iterator and the progress bar that controls it - imap = pool.imap(func, zip(*toZip)) - if tqdm_avail: - imap = tqdm.tqdm(imap, total = nTasks) - - #Set a description for the progress bar - if not callable(messageFn): - message = "Updating {} plots in {} processes".format(nTasks, pool.nodes) - else: - message = messageFn(nTasks, pool.nodes) - - imap.set_description(message) - - #Run the processes and store each result in the plots array - for i, res in enumerate(imap): - results[i] = res - - pool.close() - pool.join() - pool.clear() - - return results - - -def init_multiple_plots(PlotClass, argsList = None, kwargsList = None, **kwargs): - """ Initializes a set of plots in multiple processes simultanously making use of runMultiple() - - All arguments passed to the function, can be passed as specified in the arguments section of this documentation - or as a list containing multiple instances of them. - If a list is passed, each time the function needs to be run it will take the next item of the list. - If a single item is passed instead, this item will be repeated for each function run. - However, at least one argument must be a list, so that the number of times that the function has to be ran is defined. - - Arguments - ---------- - PlotClass: child class of sisl.viz.plotly.Plot - The plot class that must be initialized - - Can also be a list of classes (see this function's description). - argsList: array-like - An array of arguments that have to be passed to the executed function. - - Can also be a list of arrays (see this function's description). - - WARNING: Currently it only works properly for a list of arrays. Didn't fix this because the lack of interest - of argsList on Plot's methods (everything is passed as keyword arguments). - kwargsList: dict - A dictionary with the keyword arguments that have to be passed to the executed function. - - If the executed function is a Plot's method, these can be the settings, for example. - - Can also be a list of dicts (see this function's description). - - Returns - ---------- - plots: list - A list with all the initialized plots. - This list is ordered, so plots[0] is the plot initialized with argsList[0] and kwargsList[0]. - """ - - return run_multiple(_init_single_plot, PlotClass, argsList = argsList, kwargsList = kwargsList, **kwargs) - - -def apply_method_on_multiple_objs(method, objs, argsList = None, kwargsList = None, **kwargs): - """ Applies a given method to the objects provided on multiple processes simultanously making use of the runMultiple() function. - - This is useful in principle for any kind of object and any method, but has been tested only on plots. - - All arguments passed to the function, except method, can be passed as specified in the arguments section of this documentation - or as a list containing multiple instances of them. - If a list is passed, each time the function needs to be run it will take the next item of the list. - If a single item is passed instead, this item will be repeated for each function run. - However, at least one argument must be a list, so that the number of times that the function has to be ran is defined. - - Arguments - ---------- - method: func - The method to be executed. - objs: object - The object to which we need to apply the method (e.g. a plot) - - Can also be a list of objects (see this function's description). - argsList: array-like - An array of arguments that have to be passed to the executed function. - - Can also be a list of arrays (see this function's description). - - WARNING: Currently it only works properly for a list of arrays. Didn't fix this because the lack of interest - of argsList on Plot's methods (everything is passed as keyword arguments). - kwargsList: dict - A dictionary with the keyword arguments that have to be passed to the executed function. - - If the executed function is a Plot's method, these can be the settings, for example. - - Can also be a list of dicts (see this function's description). - - Returns - ---------- - plots: list - A list with all the initialized plots. - This list is ordered, so plots[0] is the plot initialized with argsList[0] and kwargsList[0]. - """ - - return run_multiple(_apply_method, method, objs, argsList = argsList, kwargsList = kwargsList, **kwargs) - - -def repeat_if_children(method): - """ Decorator that will force a method to be run on all the plot's children in case there are any """ - - def apply_to_all_plots(obj, *args, children_sel=None, **kwargs): - - if hasattr(obj, "children"): - - kwargs_list = kwargs.get("kwargs_list", kwargs) - - if isinstance(children_sel, int): - children_sel = [children_sel] - - # Get all the child plots that we are going to modify - children = obj.children - if children_sel is not None: - children = np.array(children)[children_sel].tolist() - else: - children_sel = range(len(children)) - - new_children = apply_method_on_multiple_objs(method, children, kwargsList=kwargs_list, serial=True) - - # Set the new plots. We need to do this because apply_method_on_multiple_objs - # can use multiprocessing, and therefore will not modify the plot in place. - for i, new_child in zip(children_sel, new_children): - obj.children[i] = new_child - - obj.get_figure() - - else: - - return method(obj, *args, **kwargs) - - return apply_to_all_plots - -#------------------------------------- -# Fun stuff -#------------------------------------- - -# TODO these would be ideal to put in the sisl configdir so users can -# alter the commands used ;) -# However, not really needed now. - - -def trigger_notification(title, message, sound="Submarine"): - """ Triggers a notification. - - Will not do anything in Windows (oops!) - - Parameters - ----------- - title: str - message: str - sound: str - """ - - if sys.platform == 'linux': - os.system(f"""notify-send "{title}" "{message}" """) - elif sys.platform == 'darwin': - sound_string = f'sound name "{sound}"' if sound else '' - os.system(f"""osascript -e 'display notification "{message}" with title "{title}" {sound_string}' """) - else: - info(f"sisl cannot issue notifications through the operating system ({sys.platform})") - - -def spoken_message(message): - """ Trigger a spoken message. - - In linux espeak must be installed (sudo apt-get install espeak) - - Will not do anything in Windows (oops!) - - Parameters - ----------- - title: str - message: str - sound: str - """ - - if sys.platform == 'linux': - os.system(f"""espeak -s 150 "{message}" 2>/dev/null""") - elif sys.platform == 'darwin': - os.system(f"""osascript -e 'say "{message}"' """) - else: - info(f"sisl cannot issue notifications through the operating system ({sys.platform})") - -#------------------------------------- -# Plot manipulation -#------------------------------------- - - -def shift_trace(trace, shift, axis="y"): - """ Shifts a trace by a given value in the given axis. - - Parameters - ----------- - shift: float or array-like - If it's a float, it will be a solid shift (i.e. all points moved equally). - If it's an array, an element-wise sum will be performed - axis: {"x","y","z"}, optional - The axis along which we want to shift the traces. - """ - trace[axis] = np.array(trace[axis]) + shift - - -def normalize_trace(trace, min_val=0, max_val=1, axis='y'): - """ Normalizes a trace to a given range along an axis. - - Parameters - ----------- - min_val: float, optional - The lower bound of the range. - max_val: float, optional - The upper part of the range - axis: {"x", "y", "z"}, optional - The axis along which we want to normalize. - """ - t = np.array(trace[axis]) - tmin = t.min() - trace[axis] = (t - tmin) / (t.max() - tmin) * (max_val - min_val) + min_val - - -def swap_trace_axes(trace, ax1='x', ax2='y'): - """ Swaps two axes of a trace. - - Parameters - ----------- - ax1, ax2: str, {'x', 'x*', 'y', 'y*', 'z', 'z*'} - The names of the axes that you want to swap. - """ - ax1_data = trace[ax1] - trace[ax1] = trace[ax2] - trace[ax2] = ax1_data - #------------------------------------- # Colors diff --git a/src/sisl/viz/processors/axes.py b/src/sisl/viz/processors/axes.py index fa1b09c815..e37e6797d0 100644 --- a/src/sisl/viz/processors/axes.py +++ b/src/sisl/viz/processors/axes.py @@ -36,7 +36,7 @@ def sanitize_axes(val: Union[str, Sequence[Union[str, int, np.ndarray]]]) -> Lis val = re.findall("[+-]?[xyzabc012]", val) return [sanitize_axis(ax) for ax in val] -def get_ax_title(ax: Union[Axis, Callable]) -> str: +def get_ax_title(ax: Union[Axis, Callable], cartesian_units: str = "Ang") -> str: """Generates the title for a given axis""" if hasattr(ax, "__name__"): title = ax.__name__ @@ -45,7 +45,7 @@ def get_ax_title(ax: Union[Axis, Callable]) -> str: elif not isinstance(ax, str): title = "" elif re.match("[+-]?[xXyYzZ]", ax): - title = f'{ax.upper()} axis [Ang]' + title = f'{ax.upper()} axis [{cartesian_units}]' elif re.match("[+-]?[aAbBcC]", ax): title = f'{ax.upper()} lattice vector' else: diff --git a/src/sisl/viz/processors/coords.py b/src/sisl/viz/processors/coords.py index f164bc5683..3c77e1f206 100644 --- a/src/sisl/viz/processors/coords.py +++ b/src/sisl/viz/processors/coords.py @@ -210,8 +210,12 @@ def projected_3D_data(coords_data: CoordsDataset) -> CoordsDataset: return coords_data -def project_to_axes(coords_data: CoordsDataset, axes: Axes, - dataaxis_1d: Optional[Union[npt.ArrayLike, Callable]] = None, sort_by_depth: bool = False) -> CoordsDataset: +def project_to_axes( + coords_data: CoordsDataset, axes: Axes, + dataaxis_1d: Optional[Union[npt.ArrayLike, Callable]] = None, + sort_by_depth: bool = False, + cartesian_units: str = "Ang" +) -> CoordsDataset: ndim = len(axes) if ndim == 3: xaxis, yaxis, zaxis = axes @@ -228,7 +232,7 @@ def project_to_axes(coords_data: CoordsDataset, axes: Axes, for ax, plot_ax in zip(axes, plot_axes): coords_data[plot_ax].attrs["axis"] = { - "title": get_ax_title(ax), + "title": get_ax_title(ax, cartesian_units=cartesian_units), } coords_data.attrs['ndim'] = ndim diff --git a/src/sisl/viz/processors/geometry.py b/src/sisl/viz/processors/geometry.py index e1c5f9f58a..4fd2df088d 100644 --- a/src/sisl/viz/processors/geometry.py +++ b/src/sisl/viz/processors/geometry.py @@ -490,5 +490,12 @@ def sites_obj_to_geometry(sites_obj: BrillouinZone): return Geometry(sites_obj.k.dot(sites_obj.rcell), lattice=sites_obj.rcell) else: raise ValueError(f"Cannot convert {sites_obj.__class__.__name__} to a geometry.") + +def get_sites_units(sites_obj: BrillouinZone): + """Units of space for an object that is to be converted into a geometry""" + if isinstance(sites_obj, BrillouinZone): + return "1/Ang" + else: + return "" diff --git a/src/sisl/viz/processors/orbital.py b/src/sisl/viz/processors/orbital.py index 406587a995..18495856c0 100644 --- a/src/sisl/viz/processors/orbital.py +++ b/src/sisl/viz/processors/orbital.py @@ -361,7 +361,7 @@ def _split_query(self, query, on, only=None, exclude=None, query_gen=None, ignor if ignore_constraints is False: ignore_constraints = () - constraints = {key: val for key, val in constraints.items() if key not in ignore_constraints} + constraints = {key: val for key, val in constraints.items() if key not in ignore_constraints and val is not None} # Knowing what are our constraints (which may be none), get the available options values = self.get_options("+".join(on), **constraints) diff --git a/src/sisl/viz/processors/tests/test_bands.py b/src/sisl/viz/processors/tests/test_bands.py index ea80f29ecf..ccf7f60b94 100644 --- a/src/sisl/viz/processors/tests/test_bands.py +++ b/src/sisl/viz/processors/tests/test_bands.py @@ -75,7 +75,7 @@ def test_calculate_gap(bands_data, gap): VB = len(bands_data.band) // 2 - 1 assert isinstance(gap_info['bands'], tuple) and len(gap_info['bands']) == 2 - assert gap_info['bands'] == (VB, VB + 1) + assert gap_info['bands'][0] < gap_info['bands'][1] assert isinstance(gap_info['spin'], tuple) and len(gap_info['spin']) == 2 if not spin.is_polarized: