Skip to content

Commit

Permalink
Added sites plot.
Browse files Browse the repository at this point in the history
E.g. to visualize k points of a BZ object.
  • Loading branch information
pfebrer committed Sep 28, 2023
1 parent 41b6e49 commit 2a713cb
Show file tree
Hide file tree
Showing 5 changed files with 258 additions and 40 deletions.
42 changes: 28 additions & 14 deletions src/sisl/viz/_plotables.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
This file provides tools to handle plotability of objects
"""
from collections import ChainMap
import inspect
from typing import Sequence, Type

Expand All @@ -15,11 +16,21 @@
class ClassPlotHandler(ClassDispatcher):
"""Handles all plotting possibilities for a class"""

def __init__(self, *args, **kwargs):
def __init__(self, cls, *args, inherited_handlers = (), **kwargs):
self._cls = cls
if not "instance_dispatcher" in kwargs:
kwargs["instance_dispatcher"] = ObjectPlotHandler
kwargs["type_dispatcher"] = None
super().__init__(*args, **kwargs)
super().__init__(*args, inherited_handlers=inherited_handlers, **kwargs)

self._dispatchs = ChainMap(self._dispatchs, *[handler._dispatchs for handler in inherited_handlers])

def set_default(self, key: str):
"""Sets the default plotting function for the class."""
if key not in self._dispatchs:
raise KeyError(f"Cannot set {key} as default since it is not registered.")
self._default = key



class ObjectPlotHandler(ObjectDispatcher):
Expand All @@ -37,7 +48,9 @@ def __init__(self, *args, **kwargs):
def __call__(self, *args, **kwargs):
"""If the plot handler is called, we will run the default plotting function
unless the keyword method has been passed."""
return getattr(self, kwargs.pop("method", self._default) or self._default)(*args, **kwargs)
if self._default is None:
raise TypeError(f"No default plotting function has been defined for {self._obj.__class__.__name__}.")
return getattr(self, self._default)(*args, **kwargs)


class PlotDispatch(AbstractDispatch):
Expand Down Expand Up @@ -146,17 +159,22 @@ def register_plotable(plotable, plot_cls=None, setting_key=None, plotting_func=N
name = plot_cls.plot_class_key()

# Check if we already have a plot_handler
plot_handler = plotable.__dict__.get(plot_handler_attr, None)
plot_handler = getattr(plotable, plot_handler_attr, None)

# If it's the first time that the class is being registered,
# let's give the class a plot handler
if not isinstance(plot_handler, ClassPlotHandler):
if not isinstance(plot_handler, ClassPlotHandler) or plot_handler._cls is not plotable:

if isinstance(plot_handler, ClassPlotHandler):
inherited_handlers = [plot_handler]
else:
inherited_handlers = []

# If the user is passing an instance, we get the class
if not isinstance(plotable, type):
plotable = type(plotable)

setattr(plotable, plot_handler_attr, ClassPlotHandler(plot_handler_attr))
setattr(plotable, plot_handler_attr, ClassPlotHandler(plotable, plot_handler_attr, inherited_handlers=inherited_handlers))

plot_handler = getattr(plotable, plot_handler_attr)

Expand All @@ -166,9 +184,6 @@ def register_plotable(plotable, plot_cls=None, setting_key=None, plotting_func=N

def register_data_source(data_source_cls, plot_cls, setting_key, name=None, default: Sequence[Type] = [], plot_handler_attr='plot', **kwargs):

signatures = {}
params_infos = {}

plot_cls_params = {
name: param.replace(kind=inspect.Parameter.KEYWORD_ONLY)
for name, param in inspect.signature(plot_cls).parameters.items() if name != setting_key
Expand Down Expand Up @@ -200,17 +215,16 @@ def register_data_source(data_source_cls, plot_cls, setting_key, name=None, defa

signature = signature.replace(parameters=new_parameters)

signatures[plotable] = signature
params_infos[plotable] = {
params_info = {
"data_args": data_args,
"replaced_data_args": replaced_data_args,
"data_var_kwarg": data_var_kwarg,
"plot_var_kwarg": new_parameters[-1].name if new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD else None
}

def _plot(obj, *args, **kwargs):
sig = signatures[type(obj)]
params_info = params_infos[type(obj)]
def _plot(obj, *args, __params_info=params_info, __signature=signature, **kwargs):
sig = __signature
params_info = __params_info

bound = sig.bind_partial(**kwargs)

Expand Down
7 changes: 6 additions & 1 deletion src/sisl/viz/_plotables_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# the data source can digest

register_data_source(PDOSData, PdosPlot, "pdos_data", default=[siesta.pdosSileSiesta])
register_data_source(BandsData, BandsPlot, "bands_data", default=[siesta.bandsSileSiesta, sisl.BandStructure])
register_data_source(BandsData, BandsPlot, "bands_data", default=[siesta.bandsSileSiesta])
register_data_source(FatBandsData, FatbandsPlot, "bands_data")
register_data_source(EigenstateData, WavefunctionPlot, "eigenstate", default=[sisl.EigenstateElectron])

Expand All @@ -55,3 +55,8 @@

# # Grid
register(sisl.Grid, GridPlot, 'grid', default=True)

# Brilloiun zone
register(sisl.BrillouinZone, SitesPlot, 'sites_obj')

sisl.BandStructure.plot.set_default("bands")
9 changes: 8 additions & 1 deletion src/sisl/viz/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
"""
Plots
=====
Module containing all implemented plots, both in a functional form and as Workflows.
"""

from .bands import BandsPlot, FatbandsPlot, bands_plot, fatbands_plot
from .geometry import GeometryPlot, geometry_plot
from .geometry import GeometryPlot, geometry_plot, sites_plot, SitesPlot
from .grid import GridPlot, WavefunctionPlot, grid_plot, wavefunction_plot
from .merged import merge_plots
from .pdos import PdosPlot, pdos_plot
Loading

0 comments on commit 2a713cb

Please sign in to comment.