From 2a713cbab6dd608efc5625afcb3c024a5a1db1c6 Mon Sep 17 00:00:00 2001 From: Pol Febrer Date: Wed, 27 Sep 2023 20:35:52 +0200 Subject: [PATCH] Added sites plot. E.g. to visualize k points of a BZ object. --- src/sisl/viz/_plotables.py | 42 ++++-- src/sisl/viz/_plotables_register.py | 7 +- src/sisl/viz/plots/__init__.py | 9 +- src/sisl/viz/plots/geometry.py | 219 +++++++++++++++++++++++++--- src/sisl/viz/processors/geometry.py | 21 ++- 5 files changed, 258 insertions(+), 40 deletions(-) diff --git a/src/sisl/viz/_plotables.py b/src/sisl/viz/_plotables.py index 65547c41e..816076608 100644 --- a/src/sisl/viz/_plotables.py +++ b/src/sisl/viz/_plotables.py @@ -4,6 +4,7 @@ """ This file provides tools to handle plotability of objects """ +from collections import ChainMap import inspect from typing import Sequence, Type @@ -15,11 +16,21 @@ class ClassPlotHandler(ClassDispatcher): """Handles all plotting possibilities for a class""" - def __init__(self, *args, **kwargs): + def __init__(self, cls, *args, inherited_handlers = (), **kwargs): + self._cls = cls if not "instance_dispatcher" in kwargs: kwargs["instance_dispatcher"] = ObjectPlotHandler kwargs["type_dispatcher"] = None - super().__init__(*args, **kwargs) + super().__init__(*args, inherited_handlers=inherited_handlers, **kwargs) + + self._dispatchs = ChainMap(self._dispatchs, *[handler._dispatchs for handler in inherited_handlers]) + + def set_default(self, key: str): + """Sets the default plotting function for the class.""" + if key not in self._dispatchs: + raise KeyError(f"Cannot set {key} as default since it is not registered.") + self._default = key + class ObjectPlotHandler(ObjectDispatcher): @@ -37,7 +48,9 @@ def __init__(self, *args, **kwargs): def __call__(self, *args, **kwargs): """If the plot handler is called, we will run the default plotting function unless the keyword method has been passed.""" - return getattr(self, kwargs.pop("method", self._default) or self._default)(*args, **kwargs) + if self._default is None: + raise TypeError(f"No default plotting function has been defined for {self._obj.__class__.__name__}.") + return getattr(self, self._default)(*args, **kwargs) class PlotDispatch(AbstractDispatch): @@ -146,17 +159,22 @@ def register_plotable(plotable, plot_cls=None, setting_key=None, plotting_func=N name = plot_cls.plot_class_key() # Check if we already have a plot_handler - plot_handler = plotable.__dict__.get(plot_handler_attr, None) + plot_handler = getattr(plotable, plot_handler_attr, None) # If it's the first time that the class is being registered, # let's give the class a plot handler - if not isinstance(plot_handler, ClassPlotHandler): + if not isinstance(plot_handler, ClassPlotHandler) or plot_handler._cls is not plotable: + + if isinstance(plot_handler, ClassPlotHandler): + inherited_handlers = [plot_handler] + else: + inherited_handlers = [] # If the user is passing an instance, we get the class if not isinstance(plotable, type): plotable = type(plotable) - setattr(plotable, plot_handler_attr, ClassPlotHandler(plot_handler_attr)) + setattr(plotable, plot_handler_attr, ClassPlotHandler(plotable, plot_handler_attr, inherited_handlers=inherited_handlers)) plot_handler = getattr(plotable, plot_handler_attr) @@ -166,9 +184,6 @@ def register_plotable(plotable, plot_cls=None, setting_key=None, plotting_func=N def register_data_source(data_source_cls, plot_cls, setting_key, name=None, default: Sequence[Type] = [], plot_handler_attr='plot', **kwargs): - signatures = {} - params_infos = {} - plot_cls_params = { name: param.replace(kind=inspect.Parameter.KEYWORD_ONLY) for name, param in inspect.signature(plot_cls).parameters.items() if name != setting_key @@ -200,17 +215,16 @@ def register_data_source(data_source_cls, plot_cls, setting_key, name=None, defa signature = signature.replace(parameters=new_parameters) - signatures[plotable] = signature - params_infos[plotable] = { + params_info = { "data_args": data_args, "replaced_data_args": replaced_data_args, "data_var_kwarg": data_var_kwarg, "plot_var_kwarg": new_parameters[-1].name if new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD else None } - def _plot(obj, *args, **kwargs): - sig = signatures[type(obj)] - params_info = params_infos[type(obj)] + def _plot(obj, *args, __params_info=params_info, __signature=signature, **kwargs): + sig = __signature + params_info = __params_info bound = sig.bind_partial(**kwargs) diff --git a/src/sisl/viz/_plotables_register.py b/src/sisl/viz/_plotables_register.py index 5e2bfb542..067e7ae1c 100644 --- a/src/sisl/viz/_plotables_register.py +++ b/src/sisl/viz/_plotables_register.py @@ -30,7 +30,7 @@ # the data source can digest register_data_source(PDOSData, PdosPlot, "pdos_data", default=[siesta.pdosSileSiesta]) -register_data_source(BandsData, BandsPlot, "bands_data", default=[siesta.bandsSileSiesta, sisl.BandStructure]) +register_data_source(BandsData, BandsPlot, "bands_data", default=[siesta.bandsSileSiesta]) register_data_source(FatBandsData, FatbandsPlot, "bands_data") register_data_source(EigenstateData, WavefunctionPlot, "eigenstate", default=[sisl.EigenstateElectron]) @@ -55,3 +55,8 @@ # # Grid register(sisl.Grid, GridPlot, 'grid', default=True) + +# Brilloiun zone +register(sisl.BrillouinZone, SitesPlot, 'sites_obj') + +sisl.BandStructure.plot.set_default("bands") diff --git a/src/sisl/viz/plots/__init__.py b/src/sisl/viz/plots/__init__.py index f0c8a82a0..e6b3100c5 100644 --- a/src/sisl/viz/plots/__init__.py +++ b/src/sisl/viz/plots/__init__.py @@ -1,5 +1,12 @@ +""" +Plots +===== + +Module containing all implemented plots, both in a functional form and as Workflows. +""" + from .bands import BandsPlot, FatbandsPlot, bands_plot, fatbands_plot -from .geometry import GeometryPlot, geometry_plot +from .geometry import GeometryPlot, geometry_plot, sites_plot, SitesPlot from .grid import GridPlot, WavefunctionPlot, grid_plot, wavefunction_plot from .merged import merge_plots from .pdos import PdosPlot, pdos_plot diff --git a/src/sisl/viz/plots/geometry.py b/src/sisl/viz/plots/geometry.py index 05c3baea8..8100834da 100644 --- a/src/sisl/viz/plots/geometry.py +++ b/src/sisl/viz/plots/geometry.py @@ -1,12 +1,12 @@ -from typing import Literal, Optional, Sequence, Tuple, Union +from typing import Literal, Optional, Sequence, Tuple, Callable, Union, TypeVar import numpy as np -from sisl import Geometry +from sisl import Geometry, BrillouinZone from sisl.typing import AtomsArgument from sisl.viz.figure import Figure, get_figure from sisl.viz.plotters import plot_actions as plot_actions -from sisl.viz.types import AtomArrowSpec, AtomsStyleSpec, Axes, GeometryLike +from sisl.viz.types import AtomArrowSpec, AtomsStyleSpec, Axes, StyleSpec from ..plot import Plot from ..plotters.cell import cell_plot_actions, get_ndim, get_z @@ -17,6 +17,7 @@ add_xyz_to_bonds_dataset, add_xyz_to_dataset, bonds_to_lines, + sites_obj_to_geometry, find_all_bonds, parse_atoms_style, sanitize_arrows, @@ -26,7 +27,7 @@ style_bonds, tile_data_sc, ) -from ..processors.logic import switch +from ..processors.logic import switch, matches from ..processors.xarray import scale_variable, select @@ -40,7 +41,7 @@ def _get_atom_mode(drawing_mode, ndim): return drawing_mode -def get_arrow_plottings(atoms_data, arrows, nsc=[1,1,1]): +def _get_arrow_plottings(atoms_data, arrows, nsc=[1,1,1]): reps = np.prod(nsc) actions = [] @@ -70,22 +71,84 @@ def get_arrow_plottings(atoms_data, arrows, nsc=[1,1,1]): return actions -def hide_1D(show: Union[bool, str], ndim: int): - return ndim > 1 and show - -def sanitize_scale(scale: float, ndim: int, ndim_scale: Tuple[float, float, float] = (16, 16, 1)): +def _sanitize_scale(scale: float, ndim: int, ndim_scale: Tuple[float, float, float] = (16, 16, 1)): return ndim_scale[ndim-1] * scale -def geometry_plot(geometry: Geometry, axes: Axes = ["x", "y", "z"], atoms: AtomsArgument = None, - atoms_style: Sequence[AtomsStyleSpec] = [], atoms_scale: float = 1., atoms_colorscale: Optional[str] = None, +def geometry_plot(geometry: Geometry, + axes: Axes = ["x", "y", "z"], + atoms: AtomsArgument = None, + atoms_style: Sequence[AtomsStyleSpec] = [], + atoms_scale: float = 1., + atoms_colorscale: Optional[str] = None, drawing_mode: Literal["scatter", "balls", None] = None, - bind_bonds_to_ats: bool = True, points_per_bond: int = 20, bonds_style={}, bonds_scale: float = 1., bonds_colorscale: Optional[str] = None, - show_atoms: bool = True, show_bonds: bool = True, show_cell: Literal["box", "axes", False] = "box", - cell_style={}, nsc: Sequence[int] = [1,1,1], + bind_bonds_to_ats: bool = True, + points_per_bond: int = 20, + bonds_style: StyleSpec = {}, + bonds_scale: float = 1., + bonds_colorscale: Optional[str] = None, + show_atoms: bool = True, + show_bonds: bool = True, + show_cell: Literal["box", "axes", False] = "box", + cell_style: StyleSpec = {}, + nsc: Tuple[int, int, int] = (1, 1, 1), atoms_ndim_scale: Tuple[float, float, float] = (16, 16, 1), bonds_ndim_scale: Tuple[float, float, float] = (1, 1, 10), - dataaxis_1d=None, arrows: Sequence[AtomArrowSpec] = (), backend="plotly", + dataaxis_1d: Optional[Union[np.ndarray, Callable]] = None, + arrows: Sequence[AtomArrowSpec] = (), + backend="plotly", ) -> Figure: + """Plots a geometry structure, with plentiful of customization options. + + Parameters + ---------- + geometry: + The geometry to plot. + axes: + The axes to project the geometry to. + atoms: + The atoms to plot. If None, all atoms are plotted. + atoms_style: + List of style specifications for the atoms. See the showcase notebooks for examples. + atoms_scale: + Scaling factor for the size of all atoms. + atoms_colorscale: + Colorscale to use for the atoms in case the color attribute is an array of values. + If None, the default colorscale is used for each backend. + drawing_mode: + The method used to draw the atoms. + bind_bonds_to_ats: + Whether to display only bonds between atoms that are being displayed. + points_per_bond: + When the points are drawn using points instead of lines (e.g. in some frameworks + to draw multicolor bonds), the number of points used per bond. + bonds_style: + Style specification for the bonds. See the showcase notebooks for examples. + bonds_scale: + Scaling factor for the width of all bonds. + bonds_colorscale: + Colorscale to use for the bonds in case the color attribute is an array of values. + If None, the default colorscale is used for each backend. + show_atoms: + Whether to display the atoms. + show_bonds: + Whether to display the bonds. + show_cell: + Mode to display the cell. If False, the cell is not displayed. + cell_style: + Style specification for the cell. See the showcase notebooks for examples. + nsc: + Number of unit cells to display in each direction. + atoms_ndim_scale: + Scaling factor for the size of the atoms for different dimensionalities (1D, 2D, 3D). + bonds_ndim_scale: + Scaling factor for the width of the bonds for different dimensionalities (1D, 2D, 3D). + dataaxis_1d: + Only meaningful for 1D plots. The data to plot on the Y axis. + arrows: + List of arrow specifications to display. See the showcase notebooks for examples. + backend: + The backend to use to generate the figure. + """ # INPUTS ARE NOT GETTING PARSED BECAUSE WORKFLOWS RUN GET ON FINAL NODE # SO PARSING IS DELEGATED TO NODES. @@ -106,7 +169,7 @@ def geometry_plot(geometry: Geometry, axes: Axes = ["x", "y", "z"], atoms: Atoms sc_atoms = stack_sc_data(tiled_atoms, newname="sc_atom", dims=["atom"]) projected_atoms = project_to_axes(sc_atoms, axes=axes, sort_by_depth=True, dataaxis_1d=dataaxis_1d) - atoms_scale = sanitize_scale(atoms_scale, ndim, atoms_ndim_scale) + atoms_scale = _sanitize_scale(atoms_scale, ndim, atoms_ndim_scale) final_atoms = scale_variable(projected_atoms, "size", scale=atoms_scale) atom_mode = _get_atom_mode(drawing_mode, ndim) atom_plottings = draw_xarray_xy( @@ -116,7 +179,7 @@ def geometry_plot(geometry: Geometry, axes: Axes = ["x", "y", "z"], atoms: Atoms # Here we start to process bonds bonds = find_all_bonds(geometry) - show_bonds = hide_1D(show_bonds, ndim) + show_bonds = matches(ndim, 1, False, show_bonds) styled_bonds = style_bonds(bonds, bonds_style) bonds_dataset = add_xyz_to_bonds_dataset(styled_bonds) bonds_filter = sanitize_bonds_selection(bonds_dataset, sanitized_atoms, bind_bonds_to_ats, show_bonds) @@ -126,12 +189,12 @@ def geometry_plot(geometry: Geometry, axes: Axes = ["x", "y", "z"], atoms: Atoms projected_bonds = project_to_axes(tiled_bonds, axes=axes) bond_lines = bonds_to_lines(projected_bonds, points_per_bond=points_per_bond) - bonds_scale = sanitize_scale(bonds_scale, ndim, bonds_ndim_scale) + bonds_scale = _sanitize_scale(bonds_scale, ndim, bonds_ndim_scale) final_bonds = scale_variable(bond_lines, "width", scale=bonds_scale) bond_plottings = draw_xarray_xy(data=final_bonds, x="x", y="y", z=z, set_axequal=True, name="Bonds", colorscale=bonds_colorscale) # And now the cell - show_cell = hide_1D(show_cell, ndim) + show_cell = matches(ndim, 1, False, show_cell) cell_plottings = cell_plot_actions( cell=geometry, show_cell=show_cell, cell_style=cell_style, axes=axes, dataaxis_1d=dataaxis_1d @@ -139,7 +202,7 @@ def geometry_plot(geometry: Geometry, axes: Axes = ["x", "y", "z"], atoms: Atoms # And the arrows arrow_data = sanitize_arrows(geometry, arrows, atoms=sanitized_atoms, ndim=ndim, axes=axes) - arrow_plottings = get_arrow_plottings(projected_atoms, arrow_data, nsc=nsc) + arrow_plottings = _get_arrow_plottings(projected_atoms, arrow_data, nsc=nsc) all_actions = plot_actions.combined(bond_plottings, atom_plottings, cell_plottings, arrow_plottings, composite_method=None) @@ -151,4 +214,118 @@ class GeometryPlot(Plot): @property def geometry(self): - return self.nodes.inputs['geometry']._output \ No newline at end of file + return self.nodes.inputs['geometry']._output + +_T = TypeVar("_T", list, tuple, dict) + +def _sites_specs_to_atoms_specs(sites_specs: _T) -> _T: + + if isinstance(sites_specs, dict): + if "sites" in sites_specs: + sites_specs = sites_specs.copy() + sites_specs['atoms'] = sites_specs.pop('sites') + return sites_specs + else: + return type(sites_specs)(_sites_specs_to_atoms_specs(style_spec) for style_spec in sites_specs) + +def sites_plot( + sites_obj: BrillouinZone, + axes: Axes = ["x", "y", "z"], + sites: AtomsArgument = None, + sites_style: Sequence[AtomsStyleSpec] = [], + sites_scale: float = 1., + sites_name: str = "Sites", + sites_colorscale: Optional[str] = None, + drawing_mode: Literal["scatter", "balls", "line", None] = None, + show_cell: Literal["box", "axes", False] = False, + cell_style: StyleSpec = {}, + nsc: Tuple[int, int, int] = (1, 1, 1), + sites_ndim_scale: Tuple[float, float, float] = (1, 1, 1), + dataaxis_1d: Optional[Union[np.ndarray, Callable]] = None, + arrows: Sequence[AtomArrowSpec] = (), + backend="plotly", +) -> Figure: + """Plots sites from an object that can be parsed into a geometry. + + The only differences between this plot and a geometry plot is the naming of the inputs + and the fact that there are no options to plot bonds. + + Parameters + ---------- + bz: + The brillouin zone object containing the k points to plot. + axes: + The axes to project the sites to. + sites: + The sites to plot. If None, all sites are plotted. + sites_style: + List of style specifications for the sites. See the showcase notebooks for examples. + sites_scale: + Scaling factor for the size of all sites. + sites_name: + Name to give to the trace that draws the sites. + sites_colorscale: + Colorscale to use for the sites in case the color attribute is an array of values. + If None, the default colorscale is used for each backend. + drawing_mode: + The method used to draw the sites. + show_cell: + Mode to display the reciprocal cell. If False, the cell is not displayed. + cell_style: + Style specification for the reciprocal cell. See the showcase notebooks for examples. + nsc: + Number of unit cells to display in each direction. + sites_ndim_scale: + Scaling factor for the size of the sites for different dimensionalities (1D, 2D, 3D). + dataaxis_1d: + Only meaningful for 1D plots. The data to plot on the Y axis. + arrows: + List of arrow specifications to display. See the showcase notebooks for examples. + backend: + The backend to use to generate the figure. + """ + + # INPUTS ARE NOT GETTING PARSED BECAUSE WORKFLOWS RUN GET ON FINAL NODE + # SO PARSING IS DELEGATED TO NODES. + axes = sanitize_axes(axes) + fake_geometry = sites_obj_to_geometry(sites_obj) + sanitized_sites = sanitize_atoms(fake_geometry, atoms=sites) + ndim = get_ndim(axes) + z = get_z(ndim) + + # Process sites + atoms_style = _sites_specs_to_atoms_specs(sites_style) + parsed_sites_style = parse_atoms_style(fake_geometry, atoms_style=atoms_style) + sites_dataset = add_xyz_to_dataset(parsed_sites_style) + filtered_sites = select(sites_dataset, "atom", sanitized_sites) + tiled_sites = tile_data_sc(filtered_sites, nsc=nsc) + sc_sites = stack_sc_data(tiled_sites, newname="sc_atom", dims=["atom"]) + projected_sites = project_to_axes(sc_sites, axes=axes, sort_by_depth=True, dataaxis_1d=dataaxis_1d) + + sites_scale = _sanitize_scale(sites_scale, ndim, sites_ndim_scale) + final_sites = scale_variable(projected_sites, "size", scale=sites_scale) + sites_mode = _get_atom_mode(drawing_mode, ndim) + site_plottings = draw_xarray_xy( + data=final_sites, x="x", y="y", z=z, width="size", what=sites_mode, colorscale=sites_colorscale, + set_axequal=True, name=sites_name, + ) + + # And now the cell + show_cell = matches(ndim, 1, False, show_cell) + cell_plottings = cell_plot_actions( + cell=fake_geometry, show_cell=show_cell, cell_style=cell_style, + axes=axes, dataaxis_1d=dataaxis_1d + ) + + # And the arrows + atom_arrows = _sites_specs_to_atoms_specs(arrows) + arrow_data = sanitize_arrows(fake_geometry, atom_arrows, atoms=sanitized_sites, ndim=ndim, axes=axes) + arrow_plottings = _get_arrow_plottings(projected_sites, arrow_data, nsc=nsc) + + all_actions = plot_actions.combined(site_plottings, cell_plottings, arrow_plottings, composite_method=None) + + return get_figure(backend=backend, plot_actions=all_actions) + +class SitesPlot(Plot): + + function = staticmethod(sites_plot) diff --git a/src/sisl/viz/processors/geometry.py b/src/sisl/viz/processors/geometry.py index 36c6e4086..e1c5f9f58 100644 --- a/src/sisl/viz/processors/geometry.py +++ b/src/sisl/viz/processors/geometry.py @@ -8,10 +8,8 @@ import numpy.typing as npt from xarray import Dataset -from sisl import Geometry, PeriodicTable -from sisl.lattice import Lattice, LatticeChild +from sisl import Geometry, PeriodicTable, BrillouinZone from sisl.messages import warn -from sisl.nodes import Node from sisl.typing import AtomsArgument from sisl.utils.mathematics import fnorm from sisl.viz.types import AtomArrowSpec @@ -476,4 +474,21 @@ def bonds_to_lines(bonds_data: BondsDataset, points_per_bond: int = 2) -> BondsD return bonds_data +def sites_obj_to_geometry(sites_obj: BrillouinZone): + """Converts anything that contains sites into a geometry. + + Possible conversions: + - BrillouinZone object to geometry, kpoints to atoms. + + Parameters + ----------- + sites_obj + the object to be converted. + """ + + if isinstance(sites_obj, BrillouinZone): + return Geometry(sites_obj.k.dot(sites_obj.rcell), lattice=sites_obj.rcell) + else: + raise ValueError(f"Cannot convert {sites_obj.__class__.__name__} to a geometry.") +