diff --git a/docs/api/index.rst b/docs/api/index.rst index 903f4ea5a8..df6c166fdd 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -25,6 +25,7 @@ All methods and submodules are listed :ref:`here ` and default_geom physics mixing + viz/index unit_constant utilities @@ -38,5 +39,4 @@ All methods and submodules are listed :ref:`here ` and :caption: Advanced usage nodes - viz/index diff --git a/docs/api/viz/index.rst b/docs/api/viz/index.rst index 82f5010485..e8bc0c3019 100644 --- a/docs/api/viz/index.rst +++ b/docs/api/viz/index.rst @@ -4,12 +4,37 @@ Visualization ============= -.. currentmodule:: sisl.viz +.. module:: sisl.viz -Visualizations of `sisl` objects and data. +The visualization module contains tools to plot common visualizations, as well +as to create custom visualizations that support multiple plotting backends +automatically. +Plot classes +----------------- + +Plot classes are workflow classes that implement some specific plotting. + +.. autosummary:: + :toctree: generated/ + + Plot + BandsPlot + FatbandsPlot + GeometryPlot + SitesPlot + GridPlot + WavefunctionPlot + PdosPlot + +Utilities +--------- + +Utilities to build custom plots .. 
autosummary:: - :toctree: generated/ - :recursive: + :toctree: generated/ + get_figure + merge_plots + Figure diff --git a/docs/tutorials/tutorial_es_1.ipynb b/docs/tutorials/tutorial_es_1.ipynb index f10e17fd44..46ec752227 100644 --- a/docs/tutorials/tutorial_es_1.ipynb +++ b/docs/tutorials/tutorial_es_1.ipynb @@ -9,6 +9,7 @@ "import numpy as np\n", "from sisl import *\n", "import sisl.viz\n", + "from sisl.viz import merge_plots\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] @@ -119,8 +120,8 @@ "system = graphene.remove(index)\n", "graphene.plot(axes=\"xy\", atoms_style=[\n", " {\"opacity\": 0.5}, # Default style for all atoms\n", - " {\"atoms\": indices, \"color\": \"black\", \"size\": 20, \"opacity\": 1}, # Styling for indices_close_to_center on top of defaults.\n", - " {\"atoms\": index, \"color\": \"red\", \"size\": 10, \"opacity\": 1} # Styling for center_atom_index on top of defaults.\n", + " {\"atoms\": indices, \"color\": \"black\", \"size\": 1.2, \"opacity\": 1}, # Styling for indices_close_to_center on top of defaults.\n", + " {\"atoms\": index, \"color\": \"red\", \"size\": 1, \"opacity\": 1} # Styling for center_atom_index on top of defaults.\n", "])" ] }, @@ -195,12 +196,13 @@ "es = H.eigenstate()\n", "# Reduce the contained eigenstates to only 3 states around the Fermi-level\n", "es_fermi = es.sub(range(len(H) // 2 - 1, len(H) // 2 + 2))\n", - "system.plot(\n", - " subplots=\"atoms_style\", cols=3,\n", - " axes=\"xy\", \n", - " atoms_style=[{\"size\": n * 300, \"color\": c}\n", - " for n, c in zip(es_fermi.norm2(sum=False), (\"red\", \"blue\", \"green\"))]\n", - ")" + "\n", + "plots = [\n", + " system.plot(axes=\"xy\", atoms_style=[{\"size\": n * 20, \"color\": c}]) \n", + " for n, c in zip(es_fermi.norm2(sum=False), (\"red\", \"blue\", \"green\"))\n", + "]\n", + "\n", + "merge_plots(*plots, composite_method=\"subplots\", cols=3)" ] }, { @@ -243,7 +245,7 @@ "E = np.linspace(-1, -.5, 100)\n", "dE = E[1] - E[0]\n", "PDOS = 
es.PDOS(E).sum((0, 2)) * dE # perform integration\n", - "system.plot(axes=\"xy\", atoms_style={\"size\": PDOS * 300})\n", + "system.plot(axes=\"xy\", atoms_style={\"size\": PDOS * 15})\n", "#plt.scatter(system.xyz[:, 0], system.xyz[:, 1], 500 * PDOS);\n", "#plt.scatter(xyz_remove[0], xyz_remove[1], c='k', marker='*'); # mark the removed atom" ] @@ -274,8 +276,8 @@ "source": [ "band = BandStructure(H, [[0, 0, 0], [0, 0.5, 0], \n", " [1/3, 2/3, 0], [0, 0, 0]], 400, \n", - " [r'$\\Gamma$', r'$M$', \n", - " r'$K$', r'$\\Gamma$'])" + " [r'Gamma', r'M', \n", + " r'K', r'Gamma'])" ] }, { @@ -479,7 +481,7 @@ "metadata": {}, "outputs": [], "source": [ - "grid.plot(axes=\"xy\", xaxis_range=(0, None))" + "grid.plot(axes=\"xy\")" ] }, { @@ -517,9 +519,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.9" + "version": "3.11.4" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/docs/tutorials/tutorial_siesta_1.ipynb b/docs/tutorials/tutorial_siesta_1.ipynb index 72190fd926..7bae56496f 100644 --- a/docs/tutorials/tutorial_siesta_1.ipynb +++ b/docs/tutorials/tutorial_siesta_1.ipynb @@ -11,6 +11,7 @@ "import numpy as np\n", "from sisl import *\n", "import sisl.viz\n", + "from sisl.viz import merge_plots\n", "from functools import partial\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" @@ -234,16 +235,17 @@ "# Find the index of the smallest positive eigenvalue\n", "idx_lumo = (es.eig > 0).nonzero()[0][0]\n", "es = es.sub([idx_lumo - 1, idx_lumo])\n", - "h2o.plot(\n", - " subplots=\"atoms_style\", cols=2,\n", - " axes=\"xy\", \n", - " atoms_style=[{\"size\": n * 30, \"color\": c}\n", - " for n, c in zip(h2o.apply(es.norm2(sum=False),\n", + "\n", + "plots = [\n", + " h2o.plot(axes=\"xy\", atoms_style={\"size\": n * 1.5, \"color\": c})\n", + " for n, c in zip(h2o.apply(es.norm2(sum=False),\n", " np.sum,\n", " mapper=partial(h2o.a2o, all=True),\n", " axis=1),\n", - " (\"red\", \"blue\", 
\"green\"))]\n", - ")" + " (\"red\", \"blue\", \"green\"))\n", + "]\n", + "\n", + "merge_plots(*plots, composite_method=\"subplots\", cols=2)" ] }, { @@ -312,6 +314,13 @@ "DM.density(diff)\n", "print('Real space integrated density difference: {:.3e}'.format(diff.grid.sum() * diff.dvolume))" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -330,9 +339,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.9" + "version": "3.8.15" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/docs/tutorials/tutorial_siesta_2.ipynb b/docs/tutorials/tutorial_siesta_2.ipynb index 658d90a746..5961b436a1 100644 --- a/docs/tutorials/tutorial_siesta_2.ipynb +++ b/docs/tutorials/tutorial_siesta_2.ipynb @@ -363,9 +363,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.9" + "version": "3.8.15" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/docs/visualization/viz_module/basic-tutorials/Demo.ipynb b/docs/visualization/viz_module/basic-tutorials/Demo.ipynb index a2c4036812..f8dcafe17d 100644 --- a/docs/visualization/viz_module/basic-tutorials/Demo.ipynb +++ b/docs/visualization/viz_module/basic-tutorials/Demo.ipynb @@ -33,8 +33,7 @@ "metadata": {}, "outputs": [], "source": [ - "import sisl.viz\n", - "from sisl.viz import Plot" + "import sisl.viz" ] }, { @@ -53,93 +52,31 @@ " \n", "Note\n", " \n", - "If you use sisl to run high performance calculations where you initialize sisl frequently it's better to have the autoloading turned off (default), as it might introduce an overhead of about a second.\n", + "If you use sisl to run high performance calculations where you initialize sisl frequently it's better to have the autoloading turned off (default), as it might introduce an overhead.\n", " \n", "\n", "\n", "Now that the framework has been loaded, we can 
start plotting!\n", "\n", - "## My first plots\n", + "## Your first plots\n", "\n", - "The most straightforward way to plot things in sisl is to call the `Plot` class, which you can import as shown in the next cell:" + "The most straightforward way to plot things in sisl is to call their `plot` method. For example if we have the path to a bands file we can call plot:" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "from sisl.viz import Plot\n", - "\n", - "Plot(siesta_files / \"SrTiO3.bands\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note how we just passed the path to our bands file and **sisl recognized what was the plot that we wanted to generate**.\n", - "\n", - "Let's try now passing a *.RHO* file to check the electronic density:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Plot(siesta_files / \"SrTiO3.RHO\", axes=\"xy\", nsc=[2,1,1], zsmooth='best')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You probably noticed that we used some extra arguments (`axes`, `nsc`, `zsmooth`) to get the exact plot that we wanted. These arguments are called **settings**. Settings define how the plot will process and show your plot. You can **provide settings on initialization or update them later**. \n", - "\n", - "`Plot()` returns a plot object. If you want to keep that plot object for later (to do any modification on it) you will have to, of course, store it in a variable. Let's do that:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "rho_plot = Plot(siesta_files / \"SrTiO3.RHO\", axes=\"xy\", nsc=[2,1,1], zsmooth=\"best\")" - ] - }, - { - "cell_type": "markdown", "metadata": {}, - "source": [ - "And now that we have it, let's try to get some help from it to understand the plot object better." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, "outputs": [], "source": [ - "print(rho_plot.__class__)\n", - "print(rho_plot.__doc__)" + "sisl.get_sile(siesta_files / \"SrTiO3.bands\").plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We can see two interesting things:\n", - "\n", - "- Our plot is a `GridPlot`, not simply a `Plot`. This means that **it knows you are dealing with a grid** and consequently it will help you by **providing useful methods and settings**.\n", - "- On the documentation, under `Parameters`, you can see the arguments that this plot understands. If you've guessed these are the so-called *settings*, then you've guessed right! A way to know the current settings of your plot is to check the `settings` attribute:\n" + "You can pass arguments to the plotting function:" ] }, { @@ -148,16 +85,15 @@ "metadata": {}, "outputs": [], "source": [ - "rho_plot.settings" + "rho_file = sisl.get_sile(siesta_files / \"SrTiO3.RHO\")\n", + "rho_file.plot(axes=\"xy\", nsc=[2,1,1], smooth=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The names might already give you a quick intuition of what each setting does, but for more detail you can go to the documentation. The **showcase notebooks** show examples and are designed to **help you understand what each setting does in a visual way**. It is always worth checking them out if you are dealing with a new plot type!\n", - "\n", - "One of the interesting methods that grid plots have is the `scan` method. Here we use it to do a simple scan of 15 steps with the default settings, but you can play with it:" + "Some objects can be plotted in different ways, and just calling `plot` will do it in the default way. You can however **choose which plot you want** from the available representations. 
For example, out of a PDOS file you can plot the PDOS (the default):" ] }, { @@ -166,49 +102,16 @@ "metadata": {}, "outputs": [], "source": [ - "rho_plot.scan(along=\"z\", num=15)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Plotable objects\n", + "pdos_file = sisl.get_sile(siesta_files / \"SrTiO3.PDOS\")\n", "\n", - "In this section we'd like to point out that using the `Plot` class is not the most convenient thing for day to day usage. Instead, everything that is *plotable* in `sisl` will receive a `plot` method that you can use. One example of a plotable object is the `bandsSileSiesta`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "bands_sile = sisl.get_sile(siesta_files/\"SrTiO3.bands\")\n", - "bands_sile.plot()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Although the `plot` attribute is not exactly a method, but **a manager that organizes all the plotting possibilities for an object**. If you call it, as we did, you get the default plot, but you can specify which plot type you want specifically:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "bands_sile.plot.bands(bands_color=\"red\")" + "pdos_file.plot(groups=[{\"species\": \"O\", \"name\": \"O\"}, {\"species\": \"Ti\", \"name\": \"Ti\"}])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "One can quickly check what are the options:" + "or the geometry (not the default, you need to specify it):" ] }, { @@ -217,23 +120,17 @@ "metadata": {}, "outputs": [], "source": [ - "dir(bands_sile.plot)" + "pdos_file.plot.geometry(atoms_scale=0.7)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "And you then see that there's the option to plot the fatbands from this object. We won't do it here because it needs the `.WFSX` file, which we don't have." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Updating your plots\n", + "Updating your plots\n", + "----------------\n", "\n", - "As we mentioned earlier, **plots have settings and they can be updated**. This stems from the fact that the framework is designed with GUIs in mind, where the user will have visual input fields that they may tweak to see how the plot changes. So you might do as if you were interacting from a GUI and update the settings:" + "When you call `.plot()`, you receive a `Plot` object:" ] }, { @@ -242,30 +139,15 @@ "metadata": {}, "outputs": [], "source": [ - "rho_plot.update_settings(z_range=[1, 3], axes=\"xyz\", isos=[{\"frac\": 0.05, \"color\":\"lightgreen\", \"opacity\": 0.3}])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The most important thing is that, by the time we do this update, the *.RHO* file could have changed its location or even disappeared and it wouldn't matter. When you update a setting, **the plot reruns only from the point where that setting is used**. This **avoids rerunning time-consuming initializations** like reading a very big file or diagonalizing a hamiltonian.\n", - "\n", - "However, this is not the only useful point. Since **plots are self-contained**, you can **share this plot with someone else and they will be able to tweak all the settings** that they wish if they don't involve reading data again. Isn't this nice? This brings us to the next section. " + "pdos_plot = pdos_file.plot()\n", + "type(pdos_plot)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Storing and loading plots" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "After a time-consuming calculation or data reading, you clearly want your results to be saved. Plots provide a `save` method:" + "`Plot` objects are a kind of `Workflow`. You can check the `sisl.nodes` documentation to understand what exactly this means. 
But long story short, this means that the computation is split in multiple nodes, as you can see in the following diagram:\n" ] }, { @@ -274,14 +156,14 @@ "metadata": {}, "outputs": [], "source": [ - "rho_plot.save(\"rho_plot.plot\")" + "pdos_plot.network.visualize(notebook=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "As discussed in the last paragraph of the previous section this **stores not only the current visualization**, but the full self contained plot that you can tweak as you wish when you load it again:" + "With that knowledge, when you update the inputs of a plot, only the necessary parts are recalculated. In that way, you may avoid repeating expensive calculations or reading to files that no longer exist. Inputs are updated with `update_inputs`:" ] }, { @@ -290,23 +172,14 @@ "metadata": {}, "outputs": [], "source": [ - "rho_plot_from_colleague = sisl.viz.load(\"rho_plot.plot\")" + "pdos_plot.update_inputs(Erange=[-3, 3])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "And do whatever you want with it:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "rho_plot_from_colleague.update_settings(axes=\"x\")" + "Some inputs are a bit cumbersome to write by hand, and therefore along your journey you'll find that plots have some helper methods to modify inputs much faster. For example, `PdosPlot` has the `split_DOS` method, which generates groups of orbitals for you." ] }, { @@ -315,16 +188,21 @@ "metadata": {}, "outputs": [], "source": [ - "# Let's clean the working directory.\n", - "import os\n", - "os.remove(\"rho_plot.plot\")" + "pdos_plot.split_DOS()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "You might ask yourself now what happens if you just want to store the representation, not the full self-contained plot. For this, we first need to discuss the next section." + "
\n", + " \n", + "Don't worry!\n", + " \n", + "Each plot class has its own dedicated notebook in the documentation to guide you through all the knobs that they have!\n", + " \n", + "
\n", + "\n" ] }, { @@ -333,9 +211,9 @@ "source": [ "## Different plotting backends\n", "\n", - "Hidden between all the settings, you can find a **very special setting**: `backend`.\n", + "Hidden between all the inputs, you can find a **very special input**: `backend`.\n", "\n", - "Initially, the visualization framework was written to plot things using `plotly`. However, we noticed that this might not be the appropiate choice for everyone. Therefore, we changed the design to make it very modular and **allow rendering the plot with any framework you like**. There's a dedicated notebook on how to register your own backends. Here however we just want to show you **how you can switch between the sisl-provided backends**. It is very simple:" + "This input allows you to choose the plotting backend used to display a plot. If you don't like the default one, just change it!" ] }, { @@ -344,34 +222,14 @@ "metadata": {}, "outputs": [], "source": [ - "rho_plot_from_colleague.update_settings(backend=\"matplotlib\", axes=\"x\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "rho_plot_from_colleague.update_settings(backend=\"plotly\", axes=\"xy\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "rho_plot_from_colleague.update_settings(backend=\"matplotlib\")" + "pdos_plot.update_inputs(backend=\"matplotlib\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Note that you can always know what backends are available for the plot by checking the options of the `backend` parameter:" + "Let's go back to the default one, `plotly`." 
] }, { @@ -380,18 +238,16 @@ "metadata": {}, "outputs": [], "source": [ - "rho_plot_from_colleague.get_param(\"backend\").options" + "pdos_plot.update_inputs(backend=\"plotly\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Note that the options here will only show the backends that have been loaded. `sisl` only **loads backends if the required python packages are present**. Currently, `sisl` provides backends for three frameworks: `plotly`, `matplotlib` and `blender`.\n", + "## Further customization\n", "\n", - "If you have one backend selected, you will have available all the methods that the framework provides. For example, if you are using the `matplotlib` backend, you can use all the methods that matplotlib implements for the `Axes` object directly on the plot. You also have the figure (axes) under the `figure` (`axes`) attribute, for whatever you want to do.\n", - "\n", - "Let's for example draw a line, using `Axes.plot`:" + "If you are a master of some backend, you'll be happy to know that you can run any backend specific method on the plot. For example, plotly has a method called `add_vline` that draws a vertical line:" ] }, { @@ -400,15 +256,14 @@ "metadata": {}, "outputs": [], "source": [ - "rho_plot_from_colleague.plot([1,2,3,4], [0,0,1,2])\n", - "rho_plot_from_colleague" + "pdos_plot.add_vline(-1).add_vline(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "And now let's do the same with `plotly`. In this case, all methods are looked for in the `Figure` object that is stored under the `figure` attribute." + "In fact, if you need the raw figure for something, you can find it under the `figure` attribute." 
] }, { @@ -417,21 +272,18 @@ "metadata": {}, "outputs": [], "source": [ - "rho_plot_from_colleague.update_settings(backend=\"plotly\")\n", - "rho_plot_from_colleague.add_scatter(x=[1,2,3,4], y=[0,0,1,2])" + "type(pdos_plot.figure)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "At this point, you probably already know how you will be able to save these plots to images, html or whatever other format. Use **the methods that each framework provides**!\n", - "\n", - "Also, this will also allow you to modify the plot as you wish (*adding lines, changing titles, showing legends...*) once `sisl` has render it. Again, you just have to use the methods that the framework provides to do so :)\n", - "\n", "## Discover more\n", "\n", - "Until here, we have covered the most basic concepts of the framework. If you enjoyed it, we encourage you to check the rest of notebooks to find out about more specific and complex aspects of it." + "This notebook has shown you the most basic features of the framework with the hope that you will be hooked into it :)\n", + "\n", + "If it succeeded, we invite you to check the rest of the documentation. 
**It only gets better from here!**" ] }, { @@ -452,16 +304,23 @@ }, "outputs": [], "source": [ - "thumbnail_plot = rho_plot_from_colleague\n", + "thumbnail_plot = rho_file.plot(axes=\"xy\", nsc=[2,1,1], smooth=True)\n", "\n", "if thumbnail_plot:\n", " thumbnail_plot.show(\"png\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -475,9 +334,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.15" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/docs/visualization/viz_module/blender/First animation.rst b/docs/visualization/viz_module/blender/First animation.rst new file mode 100644 index 0000000000..fb127adb0b --- /dev/null +++ b/docs/visualization/viz_module/blender/First animation.rst @@ -0,0 +1,28 @@ +First animation +--------------- + +Below is a script that generates an animation of graphene breathing in blender: + +.. code-block:: python + + import sisl + import sisl.viz + from sisl.viz import merge_plots + + plots = [] + for color, opacity, scale in zip(["red", "orange", "green"], [1, 0.2, 1], [0.5, 1, 0.5]): + geom_plot = sisl.geom.graphene().plot(backend="blender", + atoms_style={"color": color, "opacity": opacity}, + bonds_scale=0.01, + atoms_scale=scale + ) + + plots.append(geom_plot) + + merge_plots(*plots, backend="blender", composite_method="animation", interpolated_frames=50).show() + +.. raw:: html + +
View post on imgur.com
+ + diff --git a/docs/visualization/viz_module/blender/Getting started.rst b/docs/visualization/viz_module/blender/Getting started.rst index 9fcd4931f9..6f89948b7a 100644 --- a/docs/visualization/viz_module/blender/Getting started.rst +++ b/docs/visualization/viz_module/blender/Getting started.rst @@ -23,22 +23,21 @@ shipped with. Being `blender` the name of the executable, you can run:: blender -b --python-expr "import sys; print(f'PYTHON VERSION: {sys.version}')" -In blender 2.93 it gives an output that looks like this:: +In blender 3.6 it gives an output that looks like this:: + + Blender 3.6.3 (hash d3e6b08276ba built 2023-09-21 06:13:29) + PYTHON VERSION: 3.10.12 (main, Aug 14 2023, 22:14:01) [GCC 11.2.1 20220127 (Red Hat 11.2.1-9)] - Blender 2.93.4 (hash b7205031cec4 built 2021-08-31 23:36:18) - Read prefs: /home/.config/blender/2.93/config/userpref.blend - PYTHON VERSION: 3.9.2 (default, Feb 25 2021, 12:19:39) - [GCC 9.3.1 20200408 (Red Hat 9.3.1-2)] Blender quit -Therefore, we know that **blender 2.93.4 uses python 3.9.2.** +Therefore, we know that **blender 3.6.3 uses python 3.10.12.** 3. **Create an environment with that python version** and install sisl (*skip if you have it already*). In this case, we will use conda as the environment manager, since it lets us very easily select the python version. -You probably don't need the exact micro version. In our case asking for ``3.9`` is enough:: +You probably don't need the exact micro version. 
In our case asking for ``3.10`` is enough:: - conda create -n blender-python python=3.9 + conda create -n blender-python python=3.10 Then install all the packages you want to use in blender:: @@ -91,7 +90,7 @@ We want to plot graphene, so the simplest way is import sisl import sisl.viz - geom_plot = sisl.geom.graphene().plot(backend="blender") + geom_plot = sisl.geom.graphene().plot(backend="blender", bonds_scale=0.01) geom_plot.show() If we write these lines on the console, we should get the graphene structure in the viewport. diff --git a/docs/visualization/viz_module/combining-plots/Intro to combining plots.ipynb b/docs/visualization/viz_module/combining-plots/Intro to combining plots.ipynb new file mode 100644 index 0000000000..31db0d09dd --- /dev/null +++ b/docs/visualization/viz_module/combining-plots/Intro to combining plots.ipynb @@ -0,0 +1,356 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Intro to combining plots\n", + "=====================\n", + "\n", + "In this notebook you will learn how to combine plots in a simple way." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Types of multiple plots\n", + "\n", + "There are four ways of combining your plots in the visualization framework, each with its associated class:\n", + "\n", + "- `\"multiple\"`: it's the most basic one. It just takes the drawings from all plots and displays them in the same plot.\n", + "- `\"subplots\"`: Creates a grid of subplots, where each item of the grid contains a plot.\n", + "- `\"multiple_x\"` and `\"multiple_y\"` (multiple_A): Creates a plot where a separate A axis is created for each plot, while the rest of axes are shared.\n", + "- `\"animation\"`: Creates an animation where each child plot is represented in a frame.\n", + "\n", + "They can all be achieved with the `merge_plots` function."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from sisl.viz import merge_plots" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's create a simple tight-binding model for *hBN* to experiment with it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sisl\n", + "import numpy as np\n", + "\n", + "r = np.linspace(0, 3.5, 50)\n", + "f = np.exp(-r)\n", + "\n", + "orb = sisl.AtomicOrbital('2pzZ', (r, f))\n", + "geom = sisl.geom.graphene(orthogonal=False, atoms=[sisl.Atom(5, orb), sisl.Atom(7, orb)])\n", + "geom = geom.move([0, 0, 5])\n", + "H = sisl.Hamiltonian(geom)\n", + "H.construct([(0.1, 1.44), (0, -2.7)], )\n", + "H[0, 0] = -0.7\n", + "H[1, 1] = 0.7" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Individual plots\n", + "\n", + "As an example, from the hamiltonian that we constructed, let's build a bands plot and a pdos plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "band_structure = sisl.BandStructure(\n", + " H, \n", + " [[0, 0, 0], [0, 0.5, 0],[1/3, 2/3, 0], [0, 0, 0]], \n", + " 400,\n", + " [r'Gamma', r'M',r'K', r'Gamma']\n", + ")\n", + "bands_plot = band_structure.plot()\n", + "pdos_plot = H.plot.pdos(data_Erange=[-10, 10], Erange=[-10,10], kgrid=[121, 121, 1], nE=1000).split_DOS(name=\"$species\")\n", + "\n", + "plots = [bands_plot, pdos_plot]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's check the plots individually:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bands_plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pdos_plot" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, 
+ "source": [ + "And now, we will merge them.\n", + "\n", + "Merging into a single plot\n", + "----" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "merge_plots(*plots)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By default, `merge_plots` uses the `\"multiple\"` method to merge the plots. In this case, it is not very nice, because the two axes are different for bands and pdos.\n", + "\n", + "However, they have one axis in common! The energy axis. We can use this fact to combine them in a way that they share the energy axis but have each a separate one for the other axis. \n", + "\n", + "Independent axes\n", + "-------------\n", + "\n", + "First, we need to make sure that both energy axis are on the X or Y axis." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pdos_plot = pdos_plot.update_inputs(E_axis=\"y\")\n", + "bands_plot = bands_plot.update_inputs(E_axis=\"y\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And then we can use `multiple_x` so that each plot has a separate X axis." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "merge_plots(*plots, composite_method=\"multiple_x\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Much better, right? Now we can easily see that B contributes more to the bottom band, while N contributes more to the top band." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Subplots\n", + "--------\n", + "\n", + "Let's try now to use the `\"subplots\"` method." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "merge_plots(*plots, composite_method=\"subplots\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By default it puts one plot on each row, but we can manage that with the arguments `rows` (number of rows), `cols` (number of columns), and `arrange` (if rows or cols are missing, way to determine the missing value, can be \"rows\", \"cols\" or \"square\").\n", + "\n", + "Let's put the two plots in separate columns:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "merge_plots(*plots, composite_method=\"subplots\", cols=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Merging merged plots\n", + "-------------------\n", + "\n", + "We can recursively merge plots. Unfortunately however, for the moment only the top level merge method is taken into account. The other levels are simply taken as `\"multiple\"`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "merged_plot = merge_plots(*plots, composite_method=\"multiple_x\")\n", + "\n", + "merge_plots(merged_plot, bands_plot, composite_method=\"subplots\", cols=2, backend=\"plotly\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the future, separate axes within subplots might be supported." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Animations\n", + "----------\n", + "\n", + "Animations can be very cool but they are sometimes hard to build. `merge_plots` makes it as easy as possible for you, you just need to use the `\"animation\"` method.\n", + "\n", + "Let's create an animation to see the convergence of graphene's PDOS with the number of k points. 
We first create the plots:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the number of k points that we are going to try.\n", + "# Do 1 by 1 from 1 to 12 and then in steps of 5 from 15 to 90.\n", + "ks = [*np.arange(1, 12), *np.arange(15, 90, 5)]\n", + "\n", + "# Generate all plots. \n", + "# We use the scatter trace instead of a line because it looks better in animations :)\n", + "pdos_plots = [\n", + " H.plot.pdos(\n", + " data_Erange=[-10, 10], Erange=[-10,10], kgrid=[k, k, 1], nE=1000, line_mode=\"scatter\", line_scale=2\n", + " ).split_DOS(name=\"$species\") for k in ks\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now all the heavy computation is done! We can merge the plots into an animation, using the ks as frame names. Other arguments that you can pass to an animation are `frame_duration` (in ms), `transition` (in ms) and `redraw` (Wether to redraw the whole plot for each frame).\n", + "\n", + "
\n", + " \n", + "Note\n", + " \n", + "We suggest that you go to the last frame and click the house icon to set the y axis range. Then press play and see the PDOS converge!\n", + " \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "merge_plots(*pdos_plots, composite_method=\"animation\", frame_names=ks)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "------------------------------------------\n", + "This next cell is just to create a thumbnail" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "nbsphinx-thumbnail" + ] + }, + "outputs": [], + "source": [ + "merge_plots(*plots, composite_method=\"subplots\", cols=2).show(\"png\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.15" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/visualization/viz_module/combining-plots/Intro to multiple plots.ipynb b/docs/visualization/viz_module/combining-plots/Intro to multiple plots.ipynb deleted file mode 100644 index f2479668da..0000000000 --- a/docs/visualization/viz_module/combining-plots/Intro to multiple plots.ipynb +++ /dev/null @@ -1,257 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Intro to multiple plots\n", - "=====================\n", - "\n", - "In this notebook you will learn what multiple plots can do for you and what is the best way to create them." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Multiple better than individual\n", - "\n", - "The power of multiple plots is threefold:\n", - "\n", - "- Multiple plot is just a coordinator. 
It doesn't hold any data related to it's childplots, which are full instances of `Plot` themselves. This allows you to **grab the separate child plots whenever you want**, as well as modifying only some of its child plots very easily.\n", - "- It will **create the layout for you** without requiring any effort on your side.\n", - "- If all your plots inside a multiple plot need to read from the same data, the plot will coordinate all plots so that they can share the data. Therefore, **data will be read only once, saving both time and memory**." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Types of multiple plots\n", - "\n", - "There are three ways of combining your plots in the visualization framework, each with its associated class:\n", - "\n", - "- In the same plot (`MultiplePlot`): it's the most basic one. It just takes the traces from all plots and displays them in the same plot.\n", - "- As subplots (`SubPlots`): Creates a grid of subplots, where each item of the grid contains a plot. Uses [plotly's subplots capabilities](https://plotly.com/python/subplots/).\n", - "- As frames of an animation (`Animation`): Creates an animation where each child plot is represented in a frame. [plotly's animation capabilities](https://plotly.com/python/animations/)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sisl.viz import MultiplePlot, Animation, SubPlots" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's create a simple tight-binding model for the plots in this notebook." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sisl\n", - "import numpy as np\n", - "\n", - "r = np.linspace(0, 3.5, 50)\n", - "f = np.exp(-r)\n", - "\n", - "orb = sisl.AtomicOrbital('2pzZ', (r, f))\n", - "geom = sisl.geom.graphene(orthogonal=True, atoms=sisl.Atom(6, orb))\n", - "geom = geom.move([0, 0, 5])\n", - "H = sisl.Hamiltonian(geom)\n", - "H.construct([(0.1, 1.44), (0, -2.7)], )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Merging existing plots\n", - "\n", - "This is the most simple way of creating a multiple plot: you just build your plots, and then pass them to the multiple plot constructor.\n", - "\n", - "However, this will miss one key feature of multiple plots. Since you've created each plot separately, **each plot has its own data**, even if they would be able to share it.\n", - "\n", - "Therefore, this is only recommended **when the plots are independent from each other**.\n", - "\n", - "As an example, from the hamiltonian that we constructed, let's build a wavefunction plot and a pdos plot:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wf_plot = H.plot.wavefunction(i=1, axes=\"xy\", transforms=[\"square\"], zsmooth=\"best\")\n", - "pdos_plot = H.plot.pdos(Erange=[-10,10])\n", - "\n", - "plots = [wf_plot, pdos_plot]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And now, we will merge them. 
There are two main ways to do the merge:\n", - "\n", - "- Calling the multiple plot class that we want to use (`MultiplePlot`, `Animation` or `Subplots`):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "# You just pass the plots and then any extra arguments for the plot class (see help(SubPlots))\n", - "SubPlots(plots=plots, cols=2)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "- Using the `merge` method that all plots have." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plots[0].merge(plots[1:], to=\"subplots\", cols=2)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Both are exactly equivalent, but this second one is probably better since you don't need to import the class.\n", - "\n", - "You do need to specify somehow how to merge the plots though! As you may have noticed, there's a `to` argument that lets you specify how you want the plots to be merged.\n", - "\n", - "Here are the docs for `Plot.merge`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "help(plots[0].merge)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Let `MultiplePlot` handle plot creation\n", - "\n", - "As already mentioned, creating your plots beforehand is only a good strategy if plots are independent from each other, in the sense that they can not share data.\n", - "\n", - "In cases where plots can share data, just let `MultiplePlot` create your plots. It is so easy that you will end up doing it in this way all the time, even in cases where it doesn't have efficiency benefits :)\n", - "\n", - "Everytime you create a plot, there are three special keyword arguments that you can pass: `varying`, `animate` and `subplots`. 
These keywords let you easily initialize `MultiplePlot`, `Animation` and `Subplots` instances, respectively. \n", - "\n", - "They can be used in two ways:\n", - "\n", - "- You can pass a dictionary with the keys of the settings that you want to vary and the value for each \"step\"." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "H.plot.wavefunction(axes=\"xy\", transforms=[\"square\"], animate={\"i\":[1,2], \"zsmooth\": [\"best\", False]})" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this case we animated the wavefunction plot to see the squares of wavefunctions 1 and 2. The second one, for some reason, we wanted it to display \"a bit\" pixelated.\n", - "\n", - "- You can also pass the list of values as regular settings and then inform the multiple plot keyword (in this case `subplots`) which settings to vary." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "H.plot.wavefunction(\n", - " colorscale=[\"temps\", \"portland\", \"peach\", \"viridis\"], i=[0,1,2,3], axes=\"xy\", transforms=[\"square\"], zsmooth=\"best\", \n", - " subplots=[\"colorscale\", \"i\"], rows=2\n", - ") " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "There you go, four wavefunctions, each one displayed in a different colorscale :)\n", - "\n", - "Remember that these subplots are all sharing the same data, so the eigenstates of the hamiltonian have only been stored once!" 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "------------------------------------------\n", - "This next cell is just to create a thumbnail" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "nbsphinx-thumbnail" - ] - }, - "outputs": [], - "source": [ - "_.show(\"png\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/docs/visualization/viz_module/diy/00-Intro.rst b/docs/visualization/viz_module/diy/00-Intro.rst deleted file mode 100644 index 41f33ae5c5..0000000000 --- a/docs/visualization/viz_module/diy/00-Intro.rst +++ /dev/null @@ -1,62 +0,0 @@ -Intro to the framework -====================== - -Before starting to show you how to build things, we might as well show you **what is it that will support your plots**. - - -The plotting backend --------------------- - -`Plotly `_ is the backend used to do all the plotting. You can check all the cool things that -can be done with it to get some inspiration. Its main strength is the interactivity it provides seamlessly, which gives -a dinamic feel to visualizations. - -.. note:: - In the future, plot classes might be able to support multiple plotting backends so that users can chose their preferred one. - - -sisl's wrapper --------------- - -If everything was done by plotting backend, we wouldn't call this the sisl visualization module. 
Sisl **wraps the -plotting process** to separate the data processing steps from the rendering steps, provide scientifically meaningful plot settings -that know which steps to run when updated, as well as scientifically meaningful methods to modify the plot and support to be displayed in -sisl's graphical interface. - - -The `Plot` class -################ - -Each representation is a python class that inherits from the :code:`Plot` class. We all have things in common, and so do plots. For this reason, we have put all the repetitive stuff in this class so that **you can focus on what makes your plot special**. - -But wait, there's more to this class. It will **control the flow of your plots** for you so that you don't need to think about it: - -*As an example, let's say you have developed a plot that reads data from a 20GB file and takes some bits of it to plot them. Now, 10 days later, another user, which is excited about the plot they got with almost no effort, wants to add a new line to the plot using the information already read. It would be a pity if the plot had to be reset and it took 5 more minutes to read the file again, right? This won't happen thanks to the :code:`Plot` class, because it automatically knows which methods to run in order to waste as little time as possible.* - -This control of the flow will also **make the behaviour of all the plots consistent**, so that you can confidently use a plot developed by another user because it will be familiar to you. - -This class is meant to **make your job as simple as possible**, so we encourage you to get familiar with it and understand all its possibilities. - -.. note :: - :code:`MultiplePlot`, :code:`SubPlots` and :code:`Animation` are classes that mostly work like :code:`Plot` but are adapted to particular use cases (and support multiprocessing to keep things fast). 
- - -The `Configurable` class -######################## - -Although you will probably not need to ever write this class' name in your code, it is good to know that every plot class you build automatically inherits from it. This will **make your plots automatically tunable** and it will provide them with some useful methods to **safely tweak parameters, keep a settings history**, etc... - -That's all you need to know for now, you will see more about the details in other notebooks. - - -The `Session` class -################### - -Just as :code:`Plot` is the parent of all plots, :code:`Session` is the parent of all sessions. **Sessions store plots and allow you to organize them into tabs.** They are specially useful for the `graphical user interface `_, where the users can easily see all their plots at the same time and easily modify them as they wish. - -However, clicking things to create your plots may be slow and specially annoying if you have to repeat the same process time and time again. That's why you have the possibility to **create custom sessions that will do all the repetitive work with very little input**, so that all the user needs to do is enjoy the beauty of their automatically created plots in the GUI. - -For an example on how to use sessions to your benefit, see `this notebook <../basic-tutorials/GUI%20with%20Python%20Demo.html>`_. - -.. note :: - You can find all these classes under :code:`sisl.viz.plotly`. 
diff --git a/docs/visualization/viz_module/diy/Adding new backends.ipynb b/docs/visualization/viz_module/diy/Adding new backends.ipynb index 869486f874..6cfc05574b 100644 --- a/docs/visualization/viz_module/diy/Adding new backends.ipynb +++ b/docs/visualization/viz_module/diy/Adding new backends.ipynb @@ -5,7 +5,11 @@ "id": "b8465167", "metadata": {}, "source": [ - "# Adding new backends " + "# Adding new backends\n", + "\n", + "This notebook displays how to integrate a new plotting backend to `sisl.viz`.\n", + "\n", + "Let's create a toy graphene band structure to illustrate the conceps throughout this notebook:" ] }, { @@ -18,7 +22,6 @@ "import sisl\n", "import sisl.viz\n", "\n", - "# This is a toy band structure to illustrate the concepts treated throughout the notebook\n", "geom = sisl.geom.graphene(orthogonal=True)\n", "H = sisl.Hamiltonian(geom)\n", "H.construct([(0.1, 1.44), (0, -2.7)], )\n", @@ -31,524 +34,185 @@ "id": "df72db95", "metadata": {}, "source": [ - "In the `sisl.viz` framework, the rendering part of the visualization is **completely detached from the processing part**. Because of that, we have the flexibility to add new ways of generating the final product by registering what we call `backends`.\n", - "\n", - "We will guide you through how you might customize this part of the framework. There are however, very distinct scenarios where you might find yourself. Each of the following sections explains the details of each situation, which are ordered in increasing complexity.\n", - "\n", - "
\n", - "\n", - "Note\n", - " \n", - "Even if you want to go to the most complex situation, make sure that you first understand the simpler ones!\n", - " \n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "2cae0efe", - "metadata": {}, - "source": [ - "## Extending an existing backend\n", - "\n", - "This is by far the easiest situation. For example, `sisl` **already provides a backend to plot bands with** `plotly`, but **you are not totally happy with the way it's done**.\n", - "\n", - "In this case, you grab the provided backend:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7baccb3a", - "metadata": {}, - "outputs": [], - "source": [ - "from sisl.viz.backends.plotly import PlotlyBandsBackend" - ] - }, - { - "cell_type": "markdown", - "id": "acf63183", - "metadata": {}, - "source": [ - "And then create your own class that inherits from it:" + "The final display in the visualization module is controlled by the `Figure` class." ] }, { "cell_type": "code", "execution_count": null, - "id": "8d32d73e", + "id": "739c2e2f-ef7f-48dc-a757-10a92be82f4e", "metadata": {}, "outputs": [], "source": [ - "class MyOwnBandsBackend(PlotlyBandsBackend):\n", - " pass" + "from sisl.viz import Figure" ] }, { "cell_type": "markdown", - "id": "6dc45378", + "id": "11089ee9-431f-484d-8b83-ceddd5b0b00d", "metadata": {}, "source": [ - "The only thing left to do now is to **let** `BandsPlot` **know that there's a new backend available**. This action is called *registering* a backend." + "And backends are stored in `sisl.viz.figure.BACKENDS`. It is just a dictionary containing extensions of the `Figure` class for particular plotting frameworks." 
] }, { "cell_type": "code", "execution_count": null, - "id": "062d33e4", + "id": "c16cebe2-faec-491b-89b8-12ef74d6c504", "metadata": {}, "outputs": [], "source": [ - "from sisl.viz import BandsPlot\n", + "from sisl.viz.figure import BACKENDS\n", "\n", - "BandsPlot.backends.register(\"plotly_myown\", MyOwnBandsBackend)\n", - "# Pass default=True if you want to make it the default backend" + "BACKENDS" ] }, { "cell_type": "markdown", - "id": "ac6d25e5", + "id": "fe8e1f24-a3f6-4b01-b980-7d4dce7cb74f", "metadata": {}, "source": [ - "All good, you can already use your new backend!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0c1192b3", - "metadata": {}, - "outputs": [], - "source": [ - "band_struct.plot(backend=\"plotly_myown\")" - ] - }, - { - "cell_type": "markdown", - "id": "71e6c80f", - "metadata": {}, - "source": [ - "Now that we know that it can be registered, we can try to add new functionality. But of course, **we need to know how the backend works** if we need to modify it. All backends to draw bands inherit from `BandsBackend`, and you can find some information there on how it works. Let's read its documentation:" + "Therefore, to add a new backend we must follow two steps:\n", + "1. **Subclass `Figure`**, adding backend specific functionality.\n", + "2. **Register** the backend.\n", + "\n", + "The documentation of the `Figure` class explains what you should do to extend it:" ] }, { "cell_type": "code", "execution_count": null, - "id": "56fd66cd", - "metadata": {}, + "id": "ae3dcf66-fd10-47bb-a706-b24b4bb2ba40", + "metadata": { + "scrolled": true + }, "outputs": [], "source": [ - "from sisl.viz.backends.templates import BandsBackend\n", - "\n", - "print(BandsBackend.__doc__)" + "help(Figure)" ] }, { "cell_type": "markdown", - "id": "cec0e02b", + "id": "33aabd81-749e-468f-9f3b-fd19cf8704aa", "metadata": {}, "source": [ - "
\n", - "\n", - "Note\n", - " \n", - "This already gives you an overview of how the backend works. If you want to know the very fine details, you can always go to the source code.\n", - " \n", - "
\n", + "Therefore, we need to implement some of the methods of the `Figure` class. The more we implement, the more we will support `sisl.viz`.\n", "\n", - "So, clearly `PlotlyBandsBackend` already contains the `draw_gap` method, otherwise it would not work. \n", - "\n", - "From the workflow description, we understand that each band is drawn with the `_draw_band` method, which calls the generic `draw_line` method. In plotly, line information is passed as dictionaries that contain several parameters. One of them is, for example, `showlegend`, which controls whether the line appears in the legend. We can use therefore our plotly knowledge to only show at the legend those bands that are below the fermi level:" + "Here's an example of a very simple backend that just writes text:" ] }, { "cell_type": "code", "execution_count": null, - "id": "4a72408b", + "id": "b61046a2-d4a9-4e49-b1a7-3fbf11d2eee7", "metadata": {}, "outputs": [], "source": [ - "# Create my new backend\n", - "class MyOwnBandsBackend(PlotlyBandsBackend):\n", + "import numpy as np\n", + "class TextFigure(Figure):\n", + " \n", + " def _init_figure(self, *args, **kwargs):\n", + " self.text = \"\"\n", " \n", - " def _draw_band(self, x, y, *args, **kwargs):\n", - " kwargs[\"showlegend\"] = bool(y.max() < 0)\n", - " super()._draw_band(x, y, *args, **kwargs)\n", + " def clear(self):\n", + " self.text = \"\"\n", " \n", - "# And register it again\n", - "BandsPlot.backends.register(\"plotly_myown\", MyOwnBandsBackend)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bd11fcac", - "metadata": {}, - "outputs": [], - "source": [ - "band_struct.plot(backend=\"plotly_myown\")" - ] - }, - { - "cell_type": "markdown", - "id": "5e2cf114", - "metadata": {}, - "source": [ - "This is not very interesting, but it does its job at illustrating the fact that you can register a slightly modified *backend*.\n", - "\n", - "You could use your fresh knowledge to, for example draw something after the bands are 
drawn:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bdc1df56", - "metadata": {}, - "outputs": [], - "source": [ - "class MyOwnBandsBackend(PlotlyBandsBackend):\n", + " def draw_line(self, x, y, name, **kwargs):\n", + " self.text += f\"\\nLINE: {name}\\n{np.array(x)}\\n{np.array(y)}\"\n", " \n", - " def draw_bands(self, *args, **kwargs):\n", - " super().draw_bands(*args, **kwargs)\n", - " # Now that all bands are drawn, draw a very interesting line at -2eV.\n", - " self.add_hline(y=-2, line_color=\"red\")\n", + " def draw_scatter(self, x, y, name, **kwargs):\n", + " self.text += f\"\\nSCATTER: {name}\\n{np.array(x)}\\n{np.array(y)}\"\n", " \n", - "BandsPlot.backends.register(\"plotly_myown\", MyOwnBandsBackend)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "81d832d0", - "metadata": {}, - "outputs": [], - "source": [ - "band_struct.plot(backend=\"plotly_myown\")" - ] - }, - { - "cell_type": "markdown", - "id": "274e7b6b", - "metadata": {}, - "source": [ - "We finish this section by stating that:\n", - "\n", - "- To extend a backend, you have to have **some knowledge about the corresponding framework** (in this case `plotly`)\n", - "- You **don't need to create a new backend for every modification**. You can modify plots interactively however you want after the plot is generated. Creating a backend that extends an existing one is only useful if **there are changes that you will always want to do** because of personal preference or because you are building a graphical interface, for example." - ] - }, - { - "cell_type": "markdown", - "id": "537045a7", - "metadata": {}, - "source": [ - "## Creating a backend for a supported framework\n", - "\n", - "Now imagine that, for some reason, `sisl` didn't provide a `PlotlyBandsBackend`. 
However, `sisl` does have a generic plotly backend:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "86fa0668", - "metadata": {}, - "outputs": [], - "source": [ - "from sisl.viz.backends.plotly import PlotlyBackend" - ] - }, - { - "cell_type": "markdown", - "id": "90dcf337", - "metadata": {}, - "source": [ - "And also a generic bands backend:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "659a35bf", - "metadata": {}, - "outputs": [], - "source": [ - "from sisl.viz.backends.templates import BandsBackend" - ] - }, - { - "cell_type": "markdown", - "id": "cacb06ba", - "metadata": {}, - "source": [ - "In these cases, your situation is not that bad. As you saw, the template backends make use of generic functions like `draw_line` as much as they can, so the effort to implement a plotly bands backend is reduced to those things that can't be generalized in that way.\n", + " def show(self):\n", + " print(self.text)\n", "\n", - "One thing is for sure, we need to combine the two pieces to create the backend that we want:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8b06d177", - "metadata": {}, - "outputs": [], - "source": [ - "class MyPlotlyBandsBackend(BandsBackend, PlotlyBackend):\n", - " pass" + " def _ipython_display_(self):\n", + " self.show()" ] }, { "cell_type": "markdown", - "id": "ab103394", + "id": "946c792e-5152-4d09-a5eb-2ff5d1f2e422", "metadata": {}, "source": [ - "But is this enough? Let's see the documentation of `BandsBackend` one more time:" + "And all that is left now is to register the backend by simply adding it to the `BACKENDS` dictionary." 
] }, { "cell_type": "code", "execution_count": null, - "id": "480b0224", + "id": "d1986892-98c0-4b9c-b8fe-735f00de30c2", "metadata": {}, "outputs": [], "source": [ - "print(BandsBackend.__doc__)" + "BACKENDS[\"text\"] = TextFigure" ] }, { "cell_type": "markdown", - "id": "0bef77a7", + "id": "5a7054c9-06ba-467a-9b40-f83a3612160d", "metadata": {}, "source": [ - "So, there are to things that need to be implemented: `draw_spin_textured_band` and `draw_gap`.\n", - "\n", - "We won't bother to give our backend support for spin texture representations, but the `draw_gap` method is compulsory, so we have no choice. Let's understand what is expected from this method:" + "Let's plot the bands to check that it works." ] }, { "cell_type": "code", "execution_count": null, - "id": "e5ca8118", + "id": "0af709ce-2b96-4a45-8264-7dcc87968b26", "metadata": {}, "outputs": [], "source": [ - "help(BandsBackend.draw_gap)" + "plot = band_struct.plot()\n", + "plot" ] }, { "cell_type": "markdown", - "id": "d21dcd7f", + "id": "c9ffe0d9-6ee8-4a61-996d-f30bd2ef2391", "metadata": {}, "source": [ - "Quite simple, isn't it? It seems like we are provided with the coordinates of the gap and then we can display it however we want." + "The default backend has been used, let's now change it to our new `\"text\"` backend." 
] }, { "cell_type": "code", "execution_count": null, - "id": "68cbdb4e", - "metadata": {}, - "outputs": [], - "source": [ - "class MyPlotlyBandsBackend(BandsBackend, PlotlyBackend):\n", - " \n", - " def draw_gap(self, ks, Es, color, name, **kwargs):\n", - " \n", - " self.draw_line(\n", - " ks, Es, name=name,\n", - " text=f\"{Es[1]- Es[0]:.2f} eV\",\n", - " mode=\"lines+markers\",\n", - " line={\"color\": color},\n", - " marker_symbol = [\"triangle-up\", \"triangle-down\"],\n", - " marker={\"color\": color, \"size\": 20},\n", - " **kwargs\n", - " )\n", - "\n", - "# Make it the default backend for bands, since it is awesome.\n", - "BandsPlot.backends.register(\"plotly_fromscratch\", MyPlotlyBandsBackend, default=True)" - ] - }, - { - "cell_type": "markdown", - "id": "233dbbbc", + "id": "df588bab-8bd0-4dd0-b7e8-9c354dcb4b65", "metadata": {}, - "source": [ - "Let's see our masterpiece:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b8a0f8cc", - "metadata": { - "scrolled": false - }, "outputs": [], "source": [ - "band_struct.plot(gap=True)" - ] - }, - { - "cell_type": "markdown", - "id": "54b7f92b", - "metadata": {}, - "source": [ - "Beautiful!\n", - "\n", - "So, to end this section, just two remarks:\n", - "\n", - "- We have understood that if the framework is supported, the starting point is to **combine the generic backend for the framework** (`PlotlyBackend`) **with the template backend of the specific plot** (`BandsBackend`). Afterwards, we may have to tweak things a little.\n", - "- **Knowing how the generic framework backend works** helps to make your code simpler. E.g. if you check `PlotlyBackend.__doc__`, you will find that we could have easily included some defaults for the axes titles." 
+ "plot.update_inputs(backend=\"text\")" ] }, { "cell_type": "markdown", - "id": "2b461994", - "metadata": {}, - "source": [ - "## Creating a backend for a non supported framework\n", - "\n", - "Armed with our knowledge from the previous sections, we face the most difficult of the challenges: *there's not even a generic backend for the framework that we want to use*.\n", - "\n", - "What we have to do is quite clear, **develop our own generic backend**. But how? Let's go to the `Backend` class for help:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5a3a8142", + "id": "0f2397d3", "metadata": {}, - "outputs": [], "source": [ - "from sisl.viz.backends.templates import Backend\n", + "Not a very visually appealing backend, but it serves the purpose of demonstrating how it is done. Now it is your turn!\n", "\n", - "print(Backend.__doc__)" - ] - }, - { - "cell_type": "markdown", - "id": "1c1a01b0", - "metadata": {}, - "source": [ "
\n", - "\n", - "Note\n", - " \n", - "You can always look at the help of each specific method to understand exactly what you need to implement. E.g. `help(Backend.draw_line)`.\n", " \n", - "
\n", - "\n", - "To make it simple, let's say we want to create a backend for \"text\". This backend will **store everything as text in its state**, and it will print it on `show`. Here would be a minimal design:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b5201f73", - "metadata": {}, - "outputs": [], - "source": [ - "class TextBackend(Backend):\n", - " \n", - " def __init__(self, *args, **kwargs):\n", - " super().__init__(*args, **kwargs)\n", - " \n", - " self.text = \"\"\n", - " \n", - " def clear(self):\n", - " self.text = \"\"\n", - " \n", - " def draw_line(self, x, y, name, **kwargs):\n", - " self.text += f\"\\nLINE: {name}\\n{x}\\n{y}\"\n", + "Note\n", " \n", - " def draw_scatter(self, x, y, name, **kwargs):\n", - " self.text += f\"\\nSCATTER: {name}\\n{x}\\n{y}\"\n", - " \n", - " def draw_on(self, other_backend):\n", - " # Set the text attribute to the other backend's text, but store ours\n", - " self_text = self.text\n", - " self.text = other_backend.text\n", - " # Make the plot draw the figure\n", - " self._plot.get_figure(backend=self._backend_name, clear_fig=False)\n", - " # Restore our text attribute\n", - " self.text = self_text\n", + "For a complex framework you might take inspiration from the already implemented backends in `sisl.viz.figure.*`.\n", " \n", - " def show(self):\n", - " print(self.text)" - ] - }, - { - "cell_type": "markdown", - "id": "201f4a52", - "metadata": {}, - "source": [ - "This could very well be our generic backend for the \"text\" framework. 
Now we can use the knowledge of the previous section to create a backend for the bands plot:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ac8369ad", - "metadata": {}, - "outputs": [], - "source": [ - "class TextBandsBackend(BandsBackend, TextBackend):\n", - " \n", - " def draw_gap(self, ks, Es, name, **kwargs):\n", - " self.draw_line(ks, Es, name=name)\n", - " \n", - "# Register it, as always\n", - "BandsPlot.backends.register(\"text\", TextBandsBackend)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c3d3d1db", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "band_struct.plot(backend=\"text\", gap=True, _debug=True)" - ] - }, - { - "cell_type": "markdown", - "id": "c2a7a05b", - "metadata": {}, - "source": [ - "And everything works great! Note that since the backend is **independent of the processing logic**, I can use any setting of `BandsPlot` and it will work:" + "" ] }, { "cell_type": "code", "execution_count": null, - "id": "4d420a8d", + "id": "543c11eb-c0ed-43bf-88bb-df1b7d7e982d", "metadata": {}, "outputs": [], - "source": [ - "bands_plot = band_struct.plot(backend=\"text\", gap=True, _debug=True)\n", - "bands_plot.update_settings(\n", - " bands_range=[0,1], \n", - " custom_gaps=[{\"from\": \"Gamma\", \"to\": \"Gamma\"}, {\"from\": \"X\", \"to\": \"X\"}]\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "0f2397d3", - "metadata": {}, - "source": [ - "*It wasn't that difficult, right?*\n", - "\n", - "We are very thankful that you took the time to understand how to build backends on top of the `sisl.viz` framework! Any feedback on it will be highly appreciated and we are looking forward to see your implementations!" 
- ] + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -562,7 +226,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.15" } }, "nbformat": 4, diff --git a/docs/visualization/viz_module/diy/Building a new plot.ipynb b/docs/visualization/viz_module/diy/Building a new plot.ipynb new file mode 100644 index 0000000000..b257d23cba --- /dev/null +++ b/docs/visualization/viz_module/diy/Building a new plot.ipynb @@ -0,0 +1,328 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Building a new plot\n", + "-----------" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Following this guide**, you will create a new plot in no time. Remember to check the [introduction notebook to the framework](../basic-tutorials/Demo.ipynb) to understand that:\n", + "- Your plot will support multiple plotting backends.\n", + "- Your plot will only recompute what is needed when its inputs are updated.\n", + "\n", + "Let's get started!\n", + "\n", + "## The tools\n", + "\n", + "We provide you with a set of tools to create plots. The most basic ones are two of them: `get_figure` and `plot_actions`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sisl.viz import get_figure, plot_actions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "They are what support the multibackend framework. 
Let's try them out:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "# We create an action.\n", +    "action = plot_actions.draw_line(x=[1, 2], y=[3, 4], line={\"color\": \"red\"})\n", +    "\n", +    "# And then we plot it in a figure\n", +    "get_figure(backend=\"plotly\", plot_actions=[action])" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "Simple, isn't it?\n", +    "\n", +    "As you might have imagined, we can ask for a matplotlib figure:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "get_figure(backend=\"matplotlib\", plot_actions=[action])" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "## A plot function\n", +    "\n", +    "It now feels reasonable to pack this very cool implementation into a function." +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "def a_cool_plot(color=\"red\", backend=\"plotly\"):\n", +    "\n", +    "    action = plot_actions.draw_line(x=[1, 2], y=[3, 4], line={\"color\": color})\n", +    "\n", +    "    return get_figure(backend=backend, plot_actions=[action])" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "And just like that, **you have your multi framework plot function**. It would be a shame to leave it unused." +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "a_cool_plot(color=\"green\")" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "What is there left to do then? 
Remember that we wanted our plot to be a workflow, and currently it isn't.\n", + "\n", + "## From function to `Plot`\n", + "\n", + "To convert our function to a workflow, we need to introduce a new tool, `Plot`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sisl.viz import Plot" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It is just an extension of sisl's `Workflow` class (see `sisl.nodes` documentation), so creating a `Plot` from a function is straightforward:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CoolPlot = Plot.from_func(a_cool_plot)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's now visualize our workflow!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CoolPlot.network.visualize(notebook=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There we go, our first multi-backend, updatable `Plot` :)\n", + "\n", + "Let's use it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot = CoolPlot(color=\"blue\")\n", + "plot" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And now, the moment we've all been waiting for. Let's update our plot: " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(backend=\"matplotlib\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Additional methods\n", + "------" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It might be useful sometimes to provide helper methods so that the users can interact quickly with your plot. E.g. 
to change inputs or to extract some information from it.\n", +    "\n", +    "In that case, you'll just have to define the plot class with `class` syntax and write the methods as you always do:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "class CoolPlot(Plot):\n", +    "\n", +    "    # The function that this workflow will execute\n", +    "    function = staticmethod(a_cool_plot)\n", +    "\n", +    "    # Additional methods.\n", +    "    def color_like(self, object):\n", +    "        \"\"\"Uses the latest AI to change the color of the plot matching a given object\"\"\"\n", +    "\n", +    "        color = None\n", +    "        if object == \"sun\":\n", +    "            color = \"orange\"\n", +    "        elif object == \"grass\":\n", +    "            color = \"green\"\n", +    "        else:\n", +    "            raise ValueError(f\"The AI could not determine the color of {object}\")\n", +    "\n", +    "        return self.update_inputs(color=color)" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "And then you just use it as you would expect:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "plot = CoolPlot()\n", +    "plot" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "plot.color_like(\"grass\")" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "Complex plots\n", +    "-------------\n", +    "\n", +    "As you probably have noticed, you can go as complex as you wish inside your plot function. If you want to convert it to a `Plot` however, it is important that you encapsulate sub-functionalities into separate functions so that the workflow doesn't become too complex, storing useless data and adding too much overhead (this is generic advice for `Workflow`s).\n", +    "\n", +    "In `sisl.viz`, you will find plenty of helper functions, especially in `sisl.viz.processors`, that you might benefit from. 
You might want to check the already implemented plots in `sisl.viz.plots` for inspiration." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "_______________\n", + "This next cell is just to create the thumbnail for the notebook in the docs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "thumbnail_plot = plot\n", + "\n", + "if thumbnail_plot:\n", + " thumbnail_plot.show(\"png\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.15" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/visualization/viz_module/diy/Building a plot class.ipynb b/docs/visualization/viz_module/diy/Building a plot class.ipynb deleted file mode 100644 index 2b9638f53c..0000000000 --- a/docs/visualization/viz_module/diy/Building a plot class.ipynb +++ /dev/null @@ -1,642 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Building a plot class\n", - "=================" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Following this guide**, you will not only **build a very flexible plot class** that you will be able to use in a wide range of cases, but also your class will be automatically recognized by the [graphical interface](https://github.com/pfebrer/sisl-gui). Therefore, **you will get visual interactivity for free**.\n", - "\n", - "
\n", - " \n", - "Warning\n", - " \n", - "Please make sure to read [this brief introduction to sisl's visualization framework](./00-Intro.html) before you go on with this notebook. It will only take a few minutes and you will understand the concepts much easier! :)\n", - " \n", - "
\n", - "\n", - "Let's begin!\n", - "\n", - "Class definition\n", - "------------------------\n", - "\n", - "*Things that don't start in the right way are not likely to end well.*\n", - "\n", - "Therefore, make sure that **all plot classes that you develop inherit from the parent class `Plot`**. \n", - " \n", - "That is, if you were to define a new class to plot, let's say, the happiness you feel for having found this notebook, you would define it as `class HappinessPlot(Plot):`. \n", - " \n", - "In this way, your plots will profit from all the generic methods and processes that are implemented there. The `Plot` class is meant for you to write as little code as possible while still getting a powerful and dynamic representation. \n", - " \n", - "More info on class inheritance: [written explanation](https://www.w3schools.com/python/python_inheritance.asp), [Youtube video](https://www.youtube.com/watch?v=Cn7AkDb4pIU).\n", - "\n", - "*Let's do it!*" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sisl.viz.plotly import Plot\n", - "\n", - "class HappinessPlot(Plot):\n", - " pass" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "*And just like that, you have your first plot class. Let's play with it:*" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt = HappinessPlot()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "*Well, that seems reasonable. 
Our plot has no data because our class does not know how to get it yet.*\n", - "\n", - "*However, we can already do all the general things a plot is expected to do:*" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(plt)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt.update_layout(xaxis_title = \"Meaningless axis (eV)\",\n", - " xaxis_showgrid = True, xaxis_gridcolor = \"red\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "*If you are done generating and playing with useless plot classes, let's continue our way to usefulness...*\n", - "\n", - "Parameters\n", - "---------------\n", - "\n", - "*It is only when you define something that it begins to exist.*\n", - "\n", - "Before starting to write methods for our new class, we will **write the parameters that define it**. We will store them in a **class variable** called `_parameters`. Here is the definition of the `_parameters` variable that your class should contain:\n", - "\n", - "`_parameters` (tuple of InputFields): it contains all the parameters that the user can tweak in your analysis. Each parameter or setting should use an input field object (see the cell below to see types of input fields that you can use). Why do we need to do it like this? 
Well, this has three main purposes:\n", - "\n", - "- If you use an input field, the graphical interface already knows how to display it.\n", - "- It will make documentation very consistent in the long term.\n", - "- You will be able to access their values very easily at any point in the plot's methods.\n", - "- Helpful methods can be implemented to input fields to facilitate some recurrent work on the inputs.\n", - "\n", - "*Let's begin populating our HappinessPlot class:*" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# These are some input fields that are available to you. \n", - "# The names are quite self-explanatory\n", - "from sisl.viz.plotly.input_fields import TextInput, SwitchInput, \\\n", - " ColorPicker, DropdownInput, IntegerInput, FloatInput, \\\n", - " RangeSlider, QueriesInput, ProgramaticInput\n", - "\n", - "class HappinessPlot(Plot):\n", - " \n", - " # The _plot_type variable is the name that will be displayed for the plot\n", - " # If not present, it will be the class name (HappinessPlot).\n", - " _plot_type = \"Happiness Plot\"\n", - " \n", - " _parameters = (\n", - " \n", - " # This is our first parameter\n", - " FloatInput(\n", - " # \"key\" will allow you to identify the parameter during your data processing\n", - " # (be patient, we are getting there)\n", - " key=\"init_happiness\",\n", - " # \"name\" is the name that will be displayed (because, you know, \n", - " # init_happiness is not a beautiful name to show to non-programmers) \n", - " name=\"Initial happiness level\",\n", - " # \"default\" is the default value for the parameter\n", - " default=0,\n", - " # \"help\" is a helper message that will be displayed to the user when\n", - " # they don't know what the parameter means. 
It will also be used in\n", - " # the automated docs of the plot class.\n", - " help=\"This is your level of happiness before reading this notebook.\",\n", - " ),\n", - " \n", - " # This is our second parameter\n", - " SwitchInput(\n", - " key=\"read_notebook\",\n", - " name=\"Notebook has been read?\",\n", - " default=False,\n", - " help=\"Whether you have read the DIY notebook yet.\",\n", - " )\n", - " \n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "*Now we have something! Let's check if it works:*" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt = HappinessPlot( init_happiness = 3 )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "print(plt)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "*You can see that our settings have appeared, but they are still meaningless, let's continue.*" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Flow methods\n", - "-----\n", - "\n", - "*Is this class just a poser or does it actually do something?*\n", - "\n", - "After defining the parameters that our analysis will depend on and that the user will be able to tweak, we can proceed to actually using them to **read, process and show data**.\n", - "\n", - "As mentioned in the [introductory page](./00-Intro.html), the `Plot` class will control the flow of our plot and will be in charge of managing how it needs to behave at each situation. Because `Plot` is an experienced class that has seen many child classes fail, it knows all the things that can go wrong and what is the best way to do things. Therefore, **all the methods called by the user** will actually be **methods of** `Plot`, not our class. \n", - "\n", - "
\n", - "\n", - "Note\n", - " \n", - "Don't worry, this is just true for the main plotting flow! Besides that, **you can add as much public methods as you wish** to make the usage of your class much more convenient.\n", - " \n", - "
\n", - "\n", - "However, `Plot` is of course not omniscient, so it needs the help of your class to do the particular analysis that you need. During the workflow, **there are many points where Plot will try to use methods of your class**, and that is where you can do the processing required for your plots. At first, this might seem annoying and limiting, but the flexibility provided is very high and in this way you can be 100% sure that your code is ran in the right moments without having to think much about it.\n", - "\n", - "The flow of the `Plot` class is quite simple. There are three main steps represented by three different methods: `read_data`, `set_data` and `get_figure`. The names can already give you a first idea of what each step does, but let's get to the details of each method and show you where you will be able to do your magic:\n", - "\n", - "
\n", - "\n", - "Note\n", - " \n", - "Following, you will find advice of what to do at each point of the workflow. But really, do whatever you need to do, don't feel limited by our advice!\n", - "\n", - "
\n", - "\n", - "- `.__init__()`, *the party starter*:\n", - "\n", - " Of course, before ever thinking of doing things with your plot, we need to initialize it. On initialization, your plot will inherit everything from the parent classes, and all the parameters under the `_parameters` variable (both in your class and in `Plot`) will be transferred to `self.settings`, a dictionary that will contain all the current values for each parameter. You will also get a full copy of `_parameters` under `self.params`, in case you need to check something at any point. \n", - " \n", - "
\n", - " \n", - " Warning\n", - " \n", - " Please **don't ever use `_parameters`** directly, as you would have the risk of **changing the parameters for the whole class**, not only your plot.\n", - "\n", - "
\n", - " \n", - " You should let `Plot.__init__()` do its thing, but after it is done, you have the first place where you can act. If your class has an `_after_init` method, it will be ran at this point. This is a good place to intialize your plot attributes if you are a clean coder and don't initialize attributes all over the place. But hey, we don't judge!\n", - " \n", - "\n", - "- `.read_data()`, *the heavy lifter*:\n", - "\n", - " This method will probably be **the most time and resource consuming** of your class, therefore we need to make sure that we **store all the important things inside our object** so that we don't have to use it frequently, only if there is a change in the reading method or the files that must be read.\n", - " \n", - " Our advice is that, at the end of this method, you end up with a [pandas dataframe](https://www.learnpython.org/en/Pandas%20Basics), [xarray Dataarray or Dataset](http://xarray.pydata.org/en/stable/) or whatever other **ordered way** to store the data, so that later operations that need to be run more frequently and will query bits of this data can be performed in a quick and efficient manner.\n", - " \n", - " `read_data` is a polite method, so it will let you do something first if you need to by using the `_before_read` method. We have not thinked of something that would be good to do here yet, but you may need it, so there you have it...\n", - " \n", - " After that, it will attempt to **initialize the plot from the different entry points** until it finds one that succeeds. 
Entry points are signalled with the `entry_point` wrapper, as follows:\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sisl.viz.plotly.plot import entry_point\n", - "\n", - "class HappinessPlot(Plot):\n", - " \n", - " @entry_point(\"my first entry point\") # This is the name of the entry point\n", - " def _just_continue():\n", - " \"\"\"Some docs for the entry point\"\"\"\n", - " pass" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "
\n", - " \n", - "Note\n", - " \n", - "The order in which `read_data` goes through entry points is the same in which you have defined them.\n", - "\n", - "
\n", - "\n", - "When an entry point succeeds (that is, ends without raising an exception), you will get the source of the data under `self.source` for if you need to know it further down in your processing. Then `Plot` will let you have one last word with the `_after_read` method, before moving on to the next step. This is a good point to update `self.params` or `self.settings` **according to the data you have read**. For instance, in a PDOS plot the orbitals, atomic species and atoms available are only known after you have read the data, so you will use `_after_read` to set the options of the corresponding input fields." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "- `.set_data()`, *the picky one*:\n", - "\n", - " Great! You have all the data in the world now stored in your plot object, but you sure enough don't want to plot it all. And even if you did, you probably don't want to display just all the numbers on the screen. In this step you should pick the data that you need from `self.df` (or whatever place you have stored your data), and organize it in plot elements (i.e. lines, scatter points, bars, pies...).\n", - " \n", - " In this method, you should end up populating the plot with all traces that are related to the data. Keep in mind that our plot is an extension of plotly's `Figure`, so you can use any of the methods they provide to add data. The most common ones are `self.add_trace()` and `self.add_traces()`, but they have plenty of them, so you can check [their documentation](https://plotly.com/python/). \n", - " \n", - " You are kind of alone in this step, as `Plot` will only ensure that the basics are there and execute your `_set_data()` method. By the way, you don't need to worry about cleaning previous data. Each time `_set_data` is called all traces are removed.\n", - " \n", - " \n", - "- `.get_figure()`, *the beautifier*:\n", - " \n", - " You can rest now, all the work is done. 
`Plot` will not need to do anything here, but other subclasses like `Animation` might need to set up some things in the figure.\n", - " \n", - " But hey, you still get the chance to give a final touch to your work with `._after_get_figure`, which is executed after the figure is built and before showing it to the world. You may want to add annotations, lines that highlight facts about your plot or whatever other thing here. By keeping it separate from the actual processing of your data, setting updates that only concern `._after_get_figure` will get executed much faster.\n", - "\n", - "## Accessing settings\n", - "\n", - "When you need to access the value of a setting inside a method, just add it as an argument." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class HappinessPlot(Plot):\n", - " \n", - " _parameters = (\n", - " IntegerInput(key=\"n\", name=\"Just a test setting\", default=3)\n", - " ,)\n", - " \n", - " def _method_that_uses_n(self, n):\n", - " pass\n", - " \n", - " def _method_that_uses_n_and_provides_default(self, n=5):\n", - " pass" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then the values will be directly provided to you and their use will be registered so that `Plot` knows what to run when a given setting is updated.\n", - "\n", - "After some thought, this turned up to be the best way of managing settings because **it allows you to use the methods even if you are not inside the plot class**. \n", - "\n", - "
\n", - "\n", - "Note\n", - " \n", - "The defaults specified in the method are ignored if the method is called within the plot instance. I.e: in `_method_that_uses_n_and_provides_default`, `n` will default to:\n", - " \n", - "- `3` if it's called from the plot instance.\n", - "- `5` if the method is used externally.\n", - " \n", - "
\n", - "\n", - "*Wow, that was long...* \n", - "\n", - "It might seem intimidating, but rest assured that your life will be **extremely easy after this**. Let's see an example of how to apply the knowledge that we acquired to our class:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class HappinessPlot(Plot):\n", - " \n", - " _plot_type = \"Happiness Plot\"\n", - " \n", - " _parameters = (\n", - " \n", - " FloatInput(\n", - " key=\"init_happiness\", \n", - " name=\"Initial happiness level\",\n", - " default=0,\n", - " help=\"This is your level of happiness before reading this notebook.\",\n", - " ),\n", - " \n", - " SwitchInput(\n", - " key=\"read_notebook\",\n", - " name=\"Notebook has been read?\",\n", - " default=False,\n", - " help=\"Whether you have read the DIY notebook yet.\",\n", - " )\n", - " \n", - " )\n", - " \n", - " # The _layout_defaults allow you to provide some default values\n", - " # for your plot's layout (See https://plotly.com/python/creating-and-updating-figures/#the-layout-key \n", - " # and https://plotly.com/python/reference/#layout)\n", - " # Let's help the users understand what they are seeing with axes titles\n", - " _layout_defaults = {\n", - " \"yaxis_title\": \"Happiness level\",\n", - " \"xaxis_title\": \"Time\"\n", - " }\n", - " \n", - " @entry_point(\"Previously happy\")\n", - " def _init_with_happiness(self, init_happiness):\n", - " \"\"\"Given that you were happy enough, sets the background color pink\"\"\"\n", - " if init_happiness <= 0:\n", - " raise ValueError(f\"Your level of happiness ({init_happiness}) is not enough to use this entry point.\")\n", - " self.update_layout(paper_bgcolor=\"pink\")\n", - " \n", - " \n", - " @entry_point(\"Being sad\")\n", - " def _init_with_sadness(self, init_happiness):\n", - " \"\"\"Lets you in if you're sad, that's all.\"\"\"\n", - " if init_happiness > 0:\n", - " raise ValueError(f\"You are too intrinsically happy to use 
this entry point\")\n", - " pass\n", - " \n", - " def _set_data(self, init_happiness, read_notebook):\n", - " # The _set_data method needs to generate the plot elements\n", - " # (in this case, a line)\n", - " \n", - " #Calculate the final happiness based on the settings values\n", - " if read_notebook:\n", - " final_happiness = (init_happiness + 1) * 100\n", - " else:\n", - " final_happiness = init_happiness\n", - " \n", - " # Define a line that goes from the initial happiness to the final happiness\n", - " self.add_trace({\n", - " # The type of element\n", - " 'type': 'scatter',\n", - " # Draw a line\n", - " 'mode': 'lines+markers',\n", - " # The values for Y (X will be automatic, we don't care now)\n", - " 'y': [init_happiness, final_happiness],\n", - " # Name that shows in the legend\n", - " 'name': 'Happiness evolution',\n", - " # Other plotly fancy stuff that we don't really need\n", - " 'hovertemplate': 'Happiness level: %{y}',\n", - " 'line': {\"color\": \"red\" if final_happiness <= init_happiness else \"green\"}\n", - " \n", - " })" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "*And just like this, we have our first \"meaningful\" plot!*" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt = HappinessPlot()\n", - "plt" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt.update_settings(init_happiness=100 ,read_notebook=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "*Simplicity is great, but that is too simple... 
let's add more things to our plot!*" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Additional public methods\n", - "\n", - "You might feel like you are always at the mercy of the `Plot` class, but that's not completely true.`Plot` expects your class to have certain methods and automatically provides your class with useful plot manipulation methods, but **you can always add methods that you think will be helpful for users that will use your particular plot**.\n", - "\n", - "
\n", - "\n", - "Note\n", - " \n", - "If you believe that a method can be useful for plots other than yours, consider contributing it to the `Plot` class :)\n", - " \n", - "
\n", - "\n", - "Let's see how this could work with our happiness plot. We will add a method `read_notebook`, which simulates that we just read the notebook. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class HappinessPlot(Plot):\n", - " \n", - " _plot_type = \"Happiness Plot\"\n", - " \n", - " _parameters = (\n", - " \n", - " FloatInput(\n", - " key=\"init_happiness\", \n", - " name=\"Initial happiness level\",\n", - " default=0,\n", - " help=\"This is your level of happiness before reading this notebook.\",\n", - " ),\n", - " \n", - " SwitchInput(\n", - " key=\"read_notebook\",\n", - " name=\"Notebook has been read?\",\n", - " default=False,\n", - " help=\"Whether you have read the DIY notebook yet.\",\n", - " )\n", - " \n", - " )\n", - " \n", - " _layout_defaults = {\n", - " \"yaxis_title\": \"Happiness level\",\n", - " \"xaxis_title\": \"Time\"\n", - " }\n", - " \n", - " @entry_point(\"Previously happy\")\n", - " def _init_with_happiness(self, init_happiness):\n", - " \"\"\"Given that you were happy enough, sets the background color pink\"\"\"\n", - " if init_happiness <= 0:\n", - " raise ValueError(f\"Your level of happiness ({init_happiness}) is not enough to use this entry point.\")\n", - " self.update_layout(paper_bgcolor=\"pink\")\n", - " \n", - " \n", - " @entry_point(\"Being sad\")\n", - " def _init_with_sadness(self, init_happiness):\n", - " \"\"\"Lets you in if you're sad, that's all.\"\"\"\n", - " if init_happiness > 0:\n", - " raise ValueError(f\"You are too intrinsically happy to use this entry point\")\n", - " pass\n", - " \n", - " def _set_data(self, init_happiness, read_notebook):\n", - " \n", - " #Calculate the final happiness based on the settings values\n", - " if read_notebook:\n", - " final_happiness = (init_happiness + 1) * 100\n", - " else:\n", - " final_happiness = init_happiness\n", - " \n", - " # Define a line that goes from the initial happiness to the final 
happiness\n", - " self.add_trace({\n", - " 'type': 'scatter',\n", - " 'mode': 'lines+markers',\n", - " 'y': [init_happiness, final_happiness],\n", - " 'name': 'Happiness evolution',\n", - " 'hovertemplate': 'Happiness level: %{y}',\n", - " 'line': {\"color\": \"red\" if final_happiness <= init_happiness else \"green\"}\n", - " \n", - " })\n", - " \n", - " def read_notebook(self, location=\"your computer\"):\n", - " \"\"\"Method that 'reads the notebook'.\"\"\"\n", - " import time\n", - " \n", - " # Let's do a little show\n", - " print(f\"Reading the notebook in {location}...\")\n", - " time.sleep(3)\n", - " self.update_settings(read_notebook=True)\n", - " print(\"Read\")\n", - " \n", - " return self" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt = HappinessPlot()\n", - "plt.show(\"png\")\n", - "plt.read_notebook()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Congratulations, you know everything now!**\n", - "\n", - "*Well, not really, because there are still some things missing like adding keyboard shortcuts or default animations. But yeah, you know some things...*\n", - "\n", - "*Just kidding, this is more than enough to get you started! Try to build your own plots and come back for more tutorials when you feel like it. We'll be waiting for you.*\n", - "\n", - "
\n", - " \n", - "Note\n", - " \n", - "Note that this plot class that we built here **is directly usable by the** [graphical user interface](https://github.com/pfebrer/sisl-gui). So its use does not end in a python script.\n", - "\n", - "
\n", - "\n", - "Cheers, checkin' out!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/docs/visualization/viz_module/index.rst b/docs/visualization/viz_module/index.rst index 0045fc3311..bca26717d0 100644 --- a/docs/visualization/viz_module/index.rst +++ b/docs/visualization/viz_module/index.rst @@ -35,6 +35,7 @@ The following notebooks will help you develop a deeper understanding of what eac :name: viz-plotly-showcase-gallery showcase/GeometryPlot.ipynb + showcase/SitesPlot.ipynb showcase/GridPlot.ipynb showcase/BandsPlot.ipynb showcase/FatbandsPlot.ipynb @@ -51,6 +52,7 @@ we dedicate this section to it with the hope of making the usage of it less conf :name: viz-plotly-blender blender/Getting started.rst + blender/First animation.rst Combining plots ^^^^^^^^^^^^^^^ @@ -61,7 +63,7 @@ to the right place! .. 
nbgallery:: :name: viz-plotly-combining-plots-gallery - combining-plots/Intro to multiple plots.ipynb + combining-plots/Intro to combining plots.ipynb Do it yourself ^^^^^^^^^^^^^^ diff --git a/docs/visualization/viz_module/showcase/BandsPlot.ipynb b/docs/visualization/viz_module/showcase/BandsPlot.ipynb index 014b12bf99..e7a469e051 100644 --- a/docs/visualization/viz_module/showcase/BandsPlot.ipynb +++ b/docs/visualization/viz_module/showcase/BandsPlot.ipynb @@ -84,7 +84,7 @@ "metadata": {}, "outputs": [], "source": [ - "bands_plot.update_settings(Erange=[-10, 10])" + "bands_plot.update_inputs(Erange=[-10, 10])" ] }, { @@ -102,7 +102,7 @@ "metadata": {}, "outputs": [], "source": [ - "bands_plot.update_settings(bands_range=[6, 15], Erange=None)" + "bands_plot.update_inputs(bands_range=[6, 15], Erange=None)" ] }, { @@ -120,7 +120,7 @@ "metadata": {}, "outputs": [], "source": [ - "bands_plot.update_settings(E0=-10, bands_range=None, Erange=None)" + "bands_plot.update_inputs(E0=-10, bands_range=None, Erange=None)" ] }, { @@ -137,7 +137,7 @@ "outputs": [], "source": [ "# Set them back to \"normal\"\n", - "bands_plot = bands_plot.update_settings(E0=0, bands_range=None, Erange=None)" + "bands_plot = bands_plot.update_inputs(E0=0, bands_range=None, Erange=None)" ] }, { @@ -151,9 +151,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Quick styling\n", + "## Bands styling\n", "\n", - "If all you want is to change the color and width of the bands, there's one simple solution: use the `bands_color` and `bands_width` settings.\n", + "If all you want is to change the color and width of the bands, there's one simple solution: use the `bands_style` input to tweak the line styles.\n", "\n", "Let's show them in red:" ] @@ -164,7 +164,7 @@ "metadata": {}, "outputs": [], "source": [ - "bands_plot.update_settings(bands_color=\"red\")" + "bands_plot.update_inputs(bands_style={\"color\": \"red\"})" ] }, { @@ -177,19 +177,19 @@ { "cell_type": "code", "execution_count": 
null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ - "bands_plot.update_settings(bands_color=\"green\", bands_width=3)" + "bands_plot.update_inputs(bands_style={\"color\": \"green\", \"width\": 3})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "If you have spin polarized bands, `bands_color` will modify the color of the first spin channel, while the second one can be tuned with `spindown_color`." + "If you have spin polarized bands, `bands_style` will tweak the colors for the first spin channel, while the second one can be tuned with `spindown_style`.\n", + "\n", + "Finally, you can pass functions to the keys of `bands_style` to customize the styles on a band basis, or even on a point basis. The functions should accept `data` as an argument, which will be an `xarray.Dataset` containing all the bands data. It should then return a single value or an array of values. It is best shown with examples. Let's create a function just to see what we receive as an input:" ] }, { @@ -198,18 +198,21 @@ "metadata": {}, "outputs": [], "source": [ - "bands_plot = bands_plot.update_settings(bands_color=\"black\", bands_width=1)" + "def color(data):\n", + " \"\"\"Dummy function to see what we receive.\"\"\"\n", + " print(data)\n", + " return \"green\"\n", + "\n", + "bands_plot.update_inputs(bands_style={\"color\": color})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Displaying the smallest gaps\n", - "\n", - "The easiest thing to do is to let `BandsPlot` discover where the (minimum) gaps are.\n", - "\n", - "This is indicated by setting the `gap` parameter to `True`. One can also use `gap_color` if a particular color is desired." + "So, you can see that we receive a `Dataset`. The most important variable is `E`, which contains the energy (that depends on `k` and `band`). 
Let's now play with it to do some custom styling:\n", + "- The **color** will be determined by **the slope of the band**.\n", + "- We will plot **bands that are closer to the fermi level bigger** because they are more important." ] }, { @@ -218,20 +221,37 @@ "metadata": {}, "outputs": [], "source": [ - "bands_plot.update_settings(gap=True, gap_color=\"green\", Erange=[-10,10]) # We reduce Erange just to see it better" + "def gradient(data):\n", + " \"\"\"Function that computes the absolute value of dE/dk.\n", + " \n", + " This returns a two dimensional array (gradient depends on k and band)\n", + " \"\"\"\n", + " return abs(data.E.differentiate(\"k\"))\n", + "\n", + "def band_closeness_to_Ef(data):\n", + " \"\"\"Computes how close one band is to the fermi level.\n", + " \n", + " This returns a one dimensional array (distance depends only on band)\n", + " \"\"\"\n", + " dist_from_Ef = abs(data.E).min(\"k\")\n", + " \n", + " return (1 / dist_from_Ef ** 0.4) * 5\n", + "\n", + "# Now we are going to set the width of the band according to the distance from the fermi level\n", + "# and the color according to the gradient. We are going to set the colorscale also, instead of using\n", + "# the default one.\n", + "bands_plot.update_inputs(\n", + " bands_style={\"width\": band_closeness_to_Ef, \"color\": gradient}, \n", + " colorscale=\"temps\",\n", + " Erange=[-10, 10]\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "This displays the minimum gaps. However there may be some issues with it: it will show **all** gaps with the minimum value. That is, if you have repeated points in the brillouin zone it will display multiple gaps that are equivalent. \n", - "\n", - "What's worse, if the region where your gap is is very flat, two consecutive points might have the same energy. 
Multiple gaps will be displayed one glued to another.\n", - "\n", - "To help cope with this issues, you have the `direct_gaps_only` and `gap_tol`.\n", - "\n", - "In this case, since we have no direct gaps, setting `direct_gaps_only` will hide them all:" + "You can see that by providing callables the possibilities are endless, you are only limited by your imagination!" ] }, { @@ -240,14 +260,18 @@ "metadata": {}, "outputs": [], "source": [ - "bands_plot.update_settings(direct_gaps_only=True)" + "bands_plot = bands_plot.update_inputs(bands_style={})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "This example is not meaningful for `gap_tol`, but it is illustrative of what `gap_tol` does. It is the **minimum k-distance between two points to consider them \"the same point\"** in the sense that only one of them will be used to show the gap. In this case, if we set `gap_tol` all the way up to 3, the plot will consider the two gamma points to be part of the same \"point\" and therefore it will only show the gap once." + "## Displaying the smallest gaps\n", + "\n", + "The easiest thing to do is to let `BandsPlot` discover where the (minimum) gaps are.\n", + "\n", + "This is indicated by setting the `gap` parameter to `True`. One can also use `gap_color` if a particular color is desired." ] }, { @@ -256,14 +280,20 @@ "metadata": {}, "outputs": [], "source": [ - "bands_plot.update_settings(direct_gaps_only=False, gap_tol=3)" + "bands_plot.update_inputs(gap=True, gap_color=\"green\", Erange=[-10,10])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "This is not what `gap_tol` is meant for, since it is thought to remediate the effect of locally flat bands, but still you can get the idea of what it does." + "This displays the minimum gaps. However there may be some issues with it: it will show **all** gaps with the minimum value. That is, if you have repeated points in the brillouin zone it will display multiple gaps that are equivalent. 
\n", + "\n", + "What's worse, if the region where your gap is is very flat, two consecutive points might have the same energy. Multiple gaps will be displayed one glued to another.\n", + "\n", + "To help cope with this issues, you have the `direct_gaps_only` and `gap_tol`.\n", + "\n", + "In this case, since we have no direct gaps, setting `direct_gaps_only` will hide them all:" ] }, { @@ -272,18 +302,14 @@ "metadata": {}, "outputs": [], "source": [ - "bands_plot = bands_plot.update_settings(gap=False, gap_tol=0.01)" + "bands_plot.update_inputs(direct_gaps_only=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Displaying custom gaps\n", - "\n", - "If you are not happy with the gaps that the plot is displaying for you or **you simply want gaps that are not the smallest ones**, you can always use `custom_gaps`.\n", - "\n", - "Custom gaps should be a list where each item specifies how to draw that given gap. See the setting's help message:" + "This example is not meaningful for `gap_tol`, but it is illustrative of what `gap_tol` does. It is the **minimum k-distance between two points to consider them \"the same point\"** in the sense that only one of them will be used to show the gap. In this case, if we set `gap_tol` all the way up to 3, the plot will consider the two gamma points to be part of the same \"point\" and therefore it will only show the gap once." ] }, { @@ -292,14 +318,14 @@ "metadata": {}, "outputs": [], "source": [ - "print(bands_plot.get_param(\"custom_gaps\").help)" + "bands_plot.update_inputs(direct_gaps_only=False, gap_tol=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "So, for example, if we want to plot the gamma-gamma gap:" + "This is not what `gap_tol` is meant for, since it is thought to remediate the effect of locally flat bands, but still you can get the idea of what it does." 
] }, { @@ -308,16 +334,20 @@ "metadata": {}, "outputs": [], "source": [ - "bands_plot.update_settings(custom_gaps=[{\"from\": \"Gamma\", \"to\": \"Gamma\", \"color\": \"red\"}])" + "bands_plot = bands_plot.update_inputs(gap=False, gap_tol=0.01)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Notice how we got the gap probably not where we wanted, since it would be better to have it in the middle `Gamma` point, which is more visible. As the help message of `custom_gaps` states, you can also pass the K value instead of a label.\n", + "## Displaying custom gaps\n", + "\n", + "If you are not happy with the gaps that the plot is displaying for you or **you simply want gaps that are not the smallest ones**, you can always use `custom_gaps`.\n", + "\n", + "Custom gaps should be a list where each item specifies how to draw that given gap. The key labels of each item are `from` and `to`, which specifies the k-points through which you want to draw the gap. The rest of labels are the typical styling labels: `color`, `width`...\n", "\n", - "Now, you'll be happy to know that you can easily access the k values of all labels, as they are stored as attributes in the bands dataarray, which you can find in `bands_plot.bands`:" + "For example, if we want to plot the gamma-gamma gap:" ] }, { @@ -326,14 +356,16 @@ "metadata": {}, "outputs": [], "source": [ - "bands_plot.bands.attrs" + "bands_plot.update_inputs(custom_gaps=[{\"from\": \"Gamma\", \"to\": \"Gamma\", \"color\": \"red\"}])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now all we need to do is to grab the value for the second gamma point:" + "Notice how we got the gap probably not where we wanted, since it would be better to have it in the middle `Gamma` point, which is more visible. 
Instead of the K point name, you can also pass the K value.\n", + "\n", + "Now, you'll be happy to know that you can easily access the k values of all labels, as they are stored as part of the attributes of the `k` coordinate in the bands dataarray:" ] }, { @@ -342,127 +374,83 @@ "metadata": {}, "outputs": [], "source": [ - "gap_k = None\n", - "for val, label in zip(bands_plot.bands.attrs[\"ticks\"], bands_plot.bands.attrs[\"ticklabels\"]):\n", - " if label == \"Gamma\":\n", - " gap_k = val\n", - "gap_k" + "bands_plot.nodes['bands_data'].get().k.axis" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "And use it to build a custom gap:" + "Now all we need to do is to grab the value for the second gamma point:" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ - "bands_plot.update_settings(custom_gaps=[{\"from\": gap_k, \"to\": gap_k, \"color\": \"orange\"}])" + "axis_info = bands_plot.nodes['bands_data'].get().k.axis\n", + "\n", + "gap_k = None\n", + "for val, label in zip(axis_info[\"tickvals\"], axis_info[\"ticktext\"]):\n", + " if label == \"Gamma\":\n", + " gap_k = val\n", + "gap_k" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Individual band styling\n", - "\n", - "The `bands_color` and `bands_width` should be enough for most uses. However, you may want to style each band differently. Since we can not support every possible case, you can pass a function to the `add_band_data`. 
Here's the help message:" + "And use it to build a custom gap:" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ - "print(bands_plot.get_param(\"add_band_data\").help)" + "bands_plot.update_inputs(custom_gaps=[{\"from\": gap_k, \"to\": gap_k, \"color\": \"orange\"}])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "You can build a dummy function to print the band and see how it looks like. Notice that you only get those bands that are inside the range specified for the plot, therefore the first band here is band 11!" + "## Displaying spin texture\n", + "\n", + "If your bands plot comes from a non-colinear spin calculation (or is using a `Hamiltonian` with non-colinear spin), you can pass `\"x\"`, `\"y\"` or `\"z\"` to the `spin` setting in order to get a display of the spin texture.\n", + "\n", + "Let's read in a hamiltonian coming from a spin orbit SIESTA calculation, which is obtained from [this fantastic spin texture tutorial](https://github.com/juijan/TopoToolsSiesta/tree/master/Tutorials/Exercise/TI_02):" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "def add_band_data(band, self):\n", - " \"\"\"Dummy function to see the band DataArray\"\"\"\n", - " if band.band == 11:\n", - " print(band)\n", - " \n", - " return {}\n", - "\n", - "bands_plot.update_settings(add_band_data=add_band_data)" - ] - }, - { - "cell_type": "markdown", "metadata": {}, + "outputs": [], "source": [ - "Just as an educational example, we are going to style the bands according to this conditions:\n", - "- If the band is +- 5 eV within the fermi level, we are going to draw markers whose **size is proportional to the gradient of the band** at each point.\n", - "- Otherwise, we will just display the bands as **purple dotted lines that fade** as we get far from the fermi level (just because we can!)\n", - "\n", - 
"**Note**: Of course, to modify traces, one must have some notion of how plotly traces work. Just hit plotly's visual reference page https://plotly.com/python/ for inspiration." + "import sisl\n", + "siesta_files = sisl._environ.get_environ_variable(\"SISL_FILES_TESTS\") / \"sisl\" / \"io\" / \"siesta\"" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "def draw_gradient(band, self):\n", - " \"\"\"\n", - " Takes a band and styles it according to its energy dispersion.\n", - " \n", - " NOTE: If it's to far from the fermi level, it fades it in purple for additional coolness. \n", - " \"\"\"\n", - " dist_from_Ef = np.max(abs(band))\n", - " if dist_from_Ef < 5:\n", - " return {\n", - " \"mode\": \"lines+markers\",\n", - " \"marker_size\": np.abs(np.gradient(band))*40,\n", - " }\n", - " else:\n", - " return {\n", - " \"line_color\": \"purple\",\n", - " \"line_dash\": \"dot\",\n", - " \"opacity\": 1-float(dist_from_Ef/10)\n", - " }\n", - " \n", - "bands_plot.update_settings(add_band_data=draw_gradient)" + "H = sisl.get_sile(siesta_files / \"Bi2D_BHex.TSHS\").read_hamiltonian()\n", + "H.spin.is_spinorbit" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Displaying spin texture\n", - "\n", - "If your bands plot comes from a non-colinear spin calculation (or is using a `Hamiltonian` with non-colinear spin), you can pass `\"x\"`, `\"y\"` or `\"z\"` to the `spin` setting in order to get a display of the spin texture.\n", - "\n", - "Let's read in a hamiltonian coming from a spin orbit SIESTA calculation, which is obtained from [this fantastic spin texture tutorial](https://github.com/juijan/TopoToolsSiesta/tree/master/Tutorials/Exercise/TI_02):" + "Generate the path for our band structure:" ] }, { @@ -471,15 +459,17 @@ "metadata": {}, "outputs": [], "source": [ - "H = sisl.get_sile(siesta_files / \"Bi2D_BHex.TSHS\").read_hamiltonian()\n", - 
"H.spin.is_spinorbit" + "band_struct = sisl.BandStructure(H, points=[[1./2, 0., 0.], [0., 0., 0.],\n", + " [1./3, 1./3, 0.], [1./2, 0., 0.]],\n", + " divisions=301,\n", + " names=['M', r'Gamma', 'K', 'M'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Generate the path for our band structure:" + "And finally generate the plot:" ] }, { @@ -488,17 +478,17 @@ "metadata": {}, "outputs": [], "source": [ - "band_struct = sisl.BandStructure(H, points=[[1./2, 0., 0.], [0., 0., 0.],\n", - " [1./3, 1./3, 0.], [1./2, 0., 0.]],\n", - " divisions=301,\n", - " names=['M', r'$\\Gamma$', 'K', 'M'])" + "spin_texture_plot = band_struct.plot.bands(Erange=[-2,2])\n", + "spin_texture_plot" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "And finally generate the plot:" + "Now it's time to add spin texture to these bands. Remember the section on styling bands? If you haven't checked it, take a quick look at it, because it will come handy now. The main point to take from that section for our purpose here is that each key in the styles accepts a callable.\n", + "\n", + "As in other cases through the `sisl.viz` module, we provide callables that will work out of the box for the most common styling. In this case, what we need is the `SpinMoment` node. We will import it and use it simply by specifying the axis." ] }, { @@ -507,15 +497,21 @@ "metadata": {}, "outputs": [], "source": [ - "spin_texture_plot = band_struct.plot(Erange=[-2,2])\n", - "spin_texture_plot" + "from sisl.viz.data_sources import SpinMoment\n", + "\n", + "spin_texture_plot.update_inputs(\n", + " bands_style={\"color\": SpinMoment(\"x\"), \"width\": 3}\n", + ")\n", + "\n", + "# We hide the legend so that the colorbar can be easily seen.\n", + "spin_texture_plot.update_layout(showlegend=False) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "These are the bands, now let's ask for a particular spin texture:" + "There is nothing magic about the `SpinMoment` node. 
If you pass a dummy callable as we did in the styling section, you will see that the bands data now contains a `spin_moments` variable since it comes from a non-colinear calculation. It is just a matter of grabbing that variable:" ] }, { @@ -524,14 +520,19 @@ "metadata": {}, "outputs": [], "source": [ - "spin_texture_plot.update_settings(spin=\"x\", bands_width=3)" + "def color(data):\n", + " \"\"\"Dummy function to see what we receive.\"\"\"\n", + " print(data)\n", + " return \"green\"\n", + "\n", + "spin_texture_plot.update_inputs(bands_style={\"color\": color})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "And let's change the colorscale for the spin texture:" + "Note that, as shown in the styling section, you can use the `colorscale` input to change the colorscale, or use the `SpinMoment` node for the other styling keys. For example, we can set the width of the band to display whether there is some spin moment, and the color can show the sign." ] }, { @@ -540,7 +541,9 @@ "metadata": {}, "outputs": [], "source": [ - "spin_texture_plot.update_settings(backend=\"plotly\", spin_texture_colorscale=\"temps\")" + "spin_texture_plot.update_inputs(\n", + " bands_style={\"color\": SpinMoment(\"x\"), \"width\": abs(SpinMoment(\"x\")) * 40}\n", + ").update_layout(showlegend=False)" ] }, { @@ -551,6 +554,8 @@ ] }, "source": [ + "Notice how we did some postprocessing to adapt the values of the spin moment to some number that is suitable for the width. This is possible thanks to the magic of nodes!\n", + "\n", "We hope you enjoyed what you learned!" 
] }, @@ -592,7 +597,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -606,9 +611,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.15" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/docs/visualization/viz_module/showcase/FatbandsPlot.ipynb b/docs/visualization/viz_module/showcase/FatbandsPlot.ipynb index 2a4b696809..14811a063f 100644 --- a/docs/visualization/viz_module/showcase/FatbandsPlot.ipynb +++ b/docs/visualization/viz_module/showcase/FatbandsPlot.ipynb @@ -77,7 +77,7 @@ "source": [ "band = sisl.BandStructure(H, [[0., 0.], [2./3, 1./3],\n", " [1./2, 1./2], [1., 1.]], 301,\n", - " [r'$\\Gamma$', 'K', 'M', r'$\\Gamma$'])" + " [r'Gamma', 'K', 'M', r'Gamma'])" ] }, { @@ -90,9 +90,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "fatbands = band.plot.fatbands()\n", @@ -112,7 +110,7 @@ "source": [ "## Requesting specific weights\n", "\n", - "The fatbands that the plot draws are controlled by the `groups` setting." + "The fatbands that the plot draws are controlled by the `groups` setting. This setting works exactly like the `groups` setting in `PdosPlot`, which is documented [here](./PdosPlot.ipynb). Therefore we won't give an extended description of it, but just quickly show that you can autogenerate the groups:" ] }, { @@ -121,23 +119,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(fatbands.get_param(\"groups\").help)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This setting works exactly like the `requests` setting in `PdosPlot`, which is documented [here](./PdosPlot.ipynb). 
Therefore we won't give an extended description of it, but just quickly show that you can autogenerate the groups:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fatbands.split_groups(on=\"species\")" + "fatbands.split_orbs(on=\"species\", name=\"$species\")" ] }, { @@ -153,7 +135,7 @@ "metadata": {}, "outputs": [], "source": [ - "fatbands.update_settings(groups=[\n", + "fatbands.update_inputs(groups=[\n", " {\"species\": \"N\", \"color\": \"blue\", \"name\": \"Nitrogen\"},\n", " {\"species\": \"B\", \"color\": \"red\", \"name\": \"Boron\"}\n", "])" @@ -174,23 +156,7 @@ "metadata": {}, "outputs": [], "source": [ - "fatbands.update_settings(scale=2)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can also use the `scale_fatbands` method, which additionally lets you choose if you want to rescale from the current size or just set the value of `scale`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fatbands.scale_fatbands(0.5, from_current=True)" + "fatbands.update_inputs(fatbands_scale=2)" ] }, { @@ -199,7 +165,7 @@ "source": [ "## Use BandsPlot settings\n", "\n", - "All settings of `BandsPlot` work as well for `FatbandsPlot`. Even spin texture!" + "All settings of `BandsPlot` work as well for `FatbandsPlot`." 
] }, { @@ -251,7 +217,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -265,9 +231,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.15" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/docs/visualization/viz_module/showcase/GeometryPlot.ipynb b/docs/visualization/viz_module/showcase/GeometryPlot.ipynb index 44989f0a3d..7c8edda812 100644 --- a/docs/visualization/viz_module/showcase/GeometryPlot.ipynb +++ b/docs/visualization/viz_module/showcase/GeometryPlot.ipynb @@ -96,7 +96,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"xy\")" + "plot.update_inputs(axes=\"xy\")" ] }, { @@ -109,12 +109,10 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"x\")" + "plot.update_inputs(axes=\"x\",)" ] }, { @@ -126,17 +124,6 @@ "It can be an array that **explicitly sets the values**:" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "plot.update_settings(axes=\"x\", dataaxis_1d=plot.geometry.atoms.Z)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -147,12 +134,10 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(dataaxis_1d=np.sin)" + "plot.update_inputs(dataaxis_1d=np.sin)" ] }, { @@ -168,7 +153,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"xyz\")" + "plot.update_inputs(axes=\"xyz\")" ] }, { @@ -187,7 +172,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=[\"x\", \"y\"])" + "plot.update_inputs(axes=[\"x\", \"y\"])" ] }, { @@ -203,7 +188,7 @@ "metadata": {}, "outputs": [], 
"source": [ - "plot.update_settings(axes=\"xy\")" + "plot.update_inputs(axes=\"xy\")" ] }, { @@ -219,7 +204,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"yx\")" + "plot.update_inputs(axes=\"yx\")" ] }, { @@ -235,7 +220,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"ab\")" + "plot.update_inputs(axes=\"ab\")" ] }, { @@ -251,7 +236,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=[[1,1,0], [1, -1, 0]])" + "plot.update_inputs(axes=[[1,1,0], [1, -1, 0]])" ] }, { @@ -267,7 +252,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=[[1,1,0], [2, -2, 0]])" + "plot.update_inputs(axes=[[1,1,0], [2, -2, 0]])" ] }, { @@ -283,7 +268,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=[\"x\", [1,1,0]])" + "plot.update_inputs(axes=[\"x\", [1,1,0]])" ] }, { @@ -381,12 +366,10 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"xy\", show_cell=False, show_atoms=False)" + "plot.update_inputs(axes=\"xy\", show_cell=False, show_atoms=False)" ] }, { @@ -406,7 +389,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(atoms=[1,2,3,4,5], show_atoms=True, show_cell=\"axes\")\n", + "plot.update_inputs(atoms=[1,2,3,4,5], show_atoms=True, show_cell=\"axes\")\n", "#show_cell accepts \"box\", \"axes\" and False" ] }, @@ -425,7 +408,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(atoms={\"neighbours\": 3}, show_cell=\"box\")" + "plot.update_inputs(atoms={\"neighbours\": 3}, show_cell=\"box\")" ] }, { @@ -441,7 +424,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(bind_bonds_to_ats=False)" + "plot.update_inputs(bind_bonds_to_ats=False)" ] }, { @@ -450,7 +433,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot = plot.update_settings(atoms=None, bind_bonds_to_ats=True)" + "plot 
= plot.update_inputs(atoms=None, bind_bonds_to_ats=True)" ] }, { @@ -470,7 +453,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(atoms_scale=0.6)" + "plot.update_inputs(atoms_scale=0.6)" ] }, { @@ -479,7 +462,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(atoms_scale=1)" + "plot.update_inputs(atoms_scale=1)" ] }, { @@ -499,7 +482,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(atoms=None, axes=\"yx\", atoms_style={\"color\": \"green\", \"size\": 14})" + "plot.update_inputs(atoms=None, axes=\"yx\", atoms_style={\"color\": \"green\", \"size\": 0.6})" ] }, { @@ -512,12 +495,10 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(atoms_style={\"color\": \"green\", \"size\": [12, 20]})" + "plot.update_inputs(atoms_style={\"color\": \"green\", \"size\": [0.6, 0.8]})" ] }, { @@ -535,9 +516,9 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(\n", + "plot.update_inputs(\n", " atoms_style=[\n", - " {\"color\": \"green\", \"size\": [12, 20], \"opacity\": [1, 0.3]},\n", + " {\"color\": \"green\", \"size\": [0.6, 0.8], \"opacity\": [1, 0.3]},\n", " {\"atoms\": [0,1], \"color\": \"orange\"}\n", " ]\n", ")" @@ -562,7 +543,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(atoms_style=[{\"atoms\": [0,1], \"color\": \"orange\"}])" + "plot.update_inputs(atoms_style=[{\"atoms\": [0,1], \"color\": \"orange\"}])" ] }, { @@ -578,7 +559,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(atoms_style=[\n", + "plot.update_inputs(atoms_style=[\n", " {\"atoms\": {\"fx\": (None, 0.4)}, \"color\": \"orange\"},\n", " {\"atoms\": sisl.geom.AtomOdd(), \"opacity\":0.3},\n", "])" @@ -602,7 +583,7 @@ "# Get the Y coordinates\n", "y = plot.geometry.xyz[:,1]\n", "# And color atoms according to it\n", - "plot.update_settings(atoms_style=[\n", + 
"plot.update_inputs(atoms_style=[\n", " {\"color\": y}, \n", " {\"atoms\": sisl.geom.AtomOdd(), \"opacity\":0.3},\n", "], atoms_colorscale=\"viridis\")" @@ -628,7 +609,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"xyz\")" + "plot.update_inputs(axes=\"xyz\")" ] }, { @@ -646,7 +627,83 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"yx\", bonds_style={\"color\": \"orange\", \"width\": 5, \"opacity\": 0.5})" + "plot.update_inputs(axes=\"yx\", bonds_style={\"color\": \"orange\", \"width\": 5, \"opacity\": 0.5}).get()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As in the case of atoms, the styling attributes can also be lists:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(\n", + " bonds_style={\"color\": ['blue'] * 10 + ['orange'] * 19, \"width\": np.linspace(3, 7, 29)}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, in this case, providing a list is more difficult than in the atoms case, because you don't know beforehand how many bonds are going to be drawn (in this case 29) or which atoms will correspond to each bond.\n", + "\n", + "For this reason, in this case it is much better to provide a callable that receives `geometry` and `bonds` and returns the property:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def color_bonds(geometry: sisl.Geometry, bonds: \"xr.DataArray\"):\n", + " # We are going to color the bonds based on how far they go in the Y axis\n", + " return abs(geometry[bonds[:, 0], 1] - geometry[bonds[:, 1], 1])\n", + "\n", + "plot.update_inputs(bonds_style={\"color\": color_bonds, \"width\": 5})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It is even better to use nodes, because they will not recompute the property if the styles 
need to be recomputed but the geometry and bonds haven't changed.\n", + "\n", + "In `sisl.viz.data_sources` you can find several `Bond*` nodes already prepared for you. `BondLength` is probably the most common to use, but in this case all bonds have the same length, so we are going to use `BondRandom` just for fun :)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sisl.viz.data_sources import BondLength, BondDataFromMatrix, BondRandom\n", + "\n", + "plot.update_inputs(axes=\"yx\", bonds_style={\"color\": BondRandom(), \"width\": BondRandom() * 10, \"opacity\": 0.5})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As with atoms, you can change the colorscale of the bonds with `bonds_colorscale`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(bonds_colorscale=\"viridis\")" ] }, { @@ -655,10 +712,10 @@ "source": [ "
\n", " \n", - "Coloring individual bonds\n", + "Bicolor bonds\n", " \n", - "It is **not possible to style bonds individually** yet, e.g. using a colorscale. However, it is one of the goals to improve ``GeometryPlot`` and some thought has already been put into how to design the interface to make it as usable as possible. Rest assured that when the right interface is found, coloring individual bonds will be allowed, as well as drawing bicolor bonds, as most rendering softwares do.\n", - "\n", + "Most rendering softwares display **bonds with two colors, one for each half of the bond**. This is not supported yet in `sisl`, but it is probably going to be supported in the future.\n", + " \n", "
" ] }, @@ -668,7 +725,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot = plot.update_settings(axes=\"xyz\", bonds_style={})" + "plot = plot.update_inputs(axes=\"xyz\", bonds_style={})" ] }, { @@ -688,7 +745,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(arrows={\"data\": [0,0,2], \"name\": \"Upwards force\"})" + "plot.update_inputs(arrows={\"data\": [0,0,2], \"name\": \"Upwards force\"})" ] }, { @@ -705,7 +762,7 @@ "outputs": [], "source": [ "forces = np.linspace([0,0,2], [0,3,1], 18)\n", - "plot.update_settings(arrows={\"data\": forces, \"name\": \"Force\", \"color\": \"orange\", \"width\": 4})" + "plot.update_inputs(arrows={\"data\": forces, \"name\": \"Force\", \"color\": \"orange\", \"width\": 4})" ] }, { @@ -721,7 +778,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(arrows=[\n", + "plot.update_inputs(arrows=[\n", " {\"data\": forces, \"name\": \"Force\", \"color\": \"orange\", \"width\": 4},\n", " {\"data\": [0,0,2], \"name\": \"Upwards force\", \"color\": \"red\"}\n", "])" @@ -740,7 +797,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(arrows=[\n", + "plot.update_inputs(arrows=[\n", " {\"data\": forces, \"name\": \"Force\", \"color\": \"orange\", \"width\": 4},\n", " {\"atoms\": {\"fy\": (0, 0.5)} ,\"data\": [0,0,2], \"name\": \"Upwards force\", \"color\": \"red\"}\n", "])" @@ -759,7 +816,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"yz\")" + "plot.update_inputs(axes=\"yz\")" ] }, { @@ -794,7 +851,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"xyz\", nsc=[2,1,1])" + "plot.update_inputs(axes=\"xyz\", nsc=[2,1,1])" ] }, { @@ -842,7 +899,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -856,9 +913,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": 
"3.8.15" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/docs/visualization/viz_module/showcase/GridPlot.ipynb b/docs/visualization/viz_module/showcase/GridPlot.ipynb index 1e4dcea46f..b9a924db5a 100644 --- a/docs/visualization/viz_module/showcase/GridPlot.ipynb +++ b/docs/visualization/viz_module/showcase/GridPlot.ipynb @@ -32,7 +32,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [], "source": [ "import sisl\n", @@ -64,7 +66,7 @@ "plot = grid.plot()\n", "\n", "# All siles that implement read_grid can also be directly plotted\n", - "plot = sisl.get_sile(rho_file).plot()" + "plot = sisl.get_sile(rho_file).plot.grid()" ] }, { @@ -109,23 +111,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"xy\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note that you can make 2d representations look smoother without having to make the grid finer by using the `zsmooth` setting, which is part of plotly's `go.Heatmap` trace options." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plot.update_settings(zsmooth=\"best\")" + "plot.update_inputs(axes=\"xy\")" ] }, { @@ -141,7 +127,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"xyz\")" + "plot.update_inputs(axes=\"xyz\")" ] }, { @@ -150,7 +136,7 @@ "source": [ "Specifying the axes\n", "----\n", - "You may want to see your grid in The most common one is to display the cartesian coordinates. You indicate that you want cartesian coordinates by passing `{\"x\", \"y\", \"z\"}`. You can pass them as a list or as a multicharacter string:" + "You may want to see your grid in different ways. The most common one is to display the cartesian coordinates. You indicate that you want cartesian coordinates by passing `{\"x\", \"y\", \"z\"}`. 
You can pass them as a list or as a multicharacter string:" ] }, { @@ -159,7 +145,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"xy\")" + "plot.update_inputs(axes=\"xy\")" ] }, { @@ -175,7 +161,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"yx\")" + "plot.update_inputs(axes=\"yx\")" ] }, { @@ -188,12 +174,10 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"ab\")" + "plot.update_inputs(axes=\"ab\")" ] }, { @@ -220,23 +204,14 @@ "source": [ "## Dimensionality reducing method\n", "\n", - "As we mentioned, the dimensions that are not displayed in the plot are reduced. The setting that controls how this process is done is `reduce_method`. Let's see what are the options:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plot.get_param(\"reduce_method\").options" + "As we mentioned, the dimensions that are not displayed in the plot are reduced. The setting that controls how this process is done is `reduce_method`, which can be either `\"average\"` or `\"sum\"`." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We can test different reducing methods in a 1D representation to see the effects:" + "We can test the different reducing methods in a 1D representation to see the effects:" ] }, { @@ -245,7 +220,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"z\", reduce_method=\"average\")" + "plot.update_inputs(axes=\"z\", reduce_method=\"average\")" ] }, { @@ -254,7 +229,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(reduce_method=\"sum\")" + "plot.update_inputs(reduce_method=\"sum\")" ] }, { @@ -270,7 +245,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot = plot.update_settings(axes=\"xyz\")" + "plot = plot.update_inputs(axes=\"xyz\")" ] }, { @@ -281,16 +256,12 @@ "\n", "There's one parameter that controls both the display of isosurfaces (in 3d) and contours (in 2d): `isos`.\n", "\n", - "`isos` is a list of dicts where each dict asks for an isovalue. See the help message:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(plot.get_param(\"isos\").help)" + "`isos` is a list of dicts where each dict asks for an isovalue. The possible keys are:\n", + "\n", + "- `val`: Value where to draw the isosurface.\n", + "- `frac`: Same as value, but indicates the fraction between the minimum and maximum value, useful if you don't know the range of the values.\n", + "- `color`, `opacity`: control the aesthetics of the isosurface.\n", + "- `name`: Name of the isosurface, e.g. to display on the plot legend." ] }, { @@ -299,7 +270,7 @@ "source": [ "If no `isos` is provided, 3d representations plot the 0.3 and 0.7 (`frac`) isosurfaces. This is what you can see in the 3d plot that we displayed above.\n", "\n", - "Let's play a bit with `isos`. 
The first thing I will do is change the opacity of the outer isosurface, since there's no way to see the inner one right now (although you can toggle it by clicking at the legend, courtesy of plotly :))." + "Let's play a bit with `isos`. The first thing we will do is to change the opacity of the outer isosurface, since there's no way to see the inner one right now (although you can toggle it by clicking at the legend, courtesy of plotly :))." ] }, { @@ -308,7 +279,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(isos=[{\"frac\": 0.3, \"opacity\": 0.4}, {\"frac\": 0.7}])" + "plot.update_inputs(isos=[{\"frac\": 0.3, \"opacity\": 0.4}, {\"frac\": 0.7}])" ] }, { @@ -326,7 +297,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"xy\")" + "plot.update_inputs(axes=\"xy\")" ] }, { @@ -354,7 +325,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(\n", + "plot.update_inputs(\n", " axes=\"xyz\", isos=[{\"frac\":frac, \"opacity\": frac/2, \"color\": \"green\"} for frac in np.linspace(0.1, 0.8, 20)],\n", ")" ] @@ -380,7 +351,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot = plot.update_settings(axes=\"xy\", isos=[])" + "plot = plot.update_inputs(axes=\"xy\", isos=[])" ] }, { @@ -398,7 +369,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(colorscale=\"temps\")" + "plot.update_inputs(colorscale=\"temps\")" ] }, { @@ -416,7 +387,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(crange=[50, 200])" + "plot.update_inputs(crange=[50, 200])" ] }, { @@ -433,12 +404,10 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(z_range=[1,3])" + "plot.update_inputs(z_range=[1,3])" ] }, { @@ -465,7 +434,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(transforms=[abs, \"numpy.sin\"], crange=None)" + "plot.update_inputs(transforms=[abs, 
\"numpy.sin\"], crange=None)" ] }, { @@ -481,7 +450,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(transforms=[\"sin\", abs], crange=None) \n", + "plot.update_inputs(transforms=[\"sin\", abs], crange=None) \n", "# If a string is provided with no module, it will be interpreted as a numpy function\n", "# Therefore \"sin\" == \"numpy.sin\" and abs != \"abs\" == \"numpy.abs\"" ] @@ -501,92 +470,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(nsc=[1,3,1])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Performing scans\n", - "\n", - "We can use the `scan` method to create a scan of the grid along a given direction." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plot.scan(\"z\", num=5)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Notice how the scan respected our `z_range` from 1 to 3. If we want the rest of the grid, we can set `z_range` back to `None` before creating the scan, or we can indicate the bounds of the scan." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plot.scan(\"z\", start=0, stop=plot.grid.cell[2,2], num=10)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This is the `\"moving_slice\"` scan, but we can also display the scan as an animation:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "scan = plot.scan(\"z\", mode=\"as_is\", num=15)\n", - "scan" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If we are using the `plotly` backend, the axes of the animation will not be correctly scaled, but we can easily solve this:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "scan.update_layout(xaxis_scaleanchor=\"y\", xaxis_scaleratio=1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This mode is called `\"as_is\"` because **it creates an animation of the current representation**. 
That is, it can scan through 1d, 2d and 3d representations and it keeps displaying the supercell.\n", - "\n", - "Here's a scan of 1d data:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plot.update_settings(axes=\"z\").scan(\"y\",mode=\"as_is\", num=15)" + "plot.update_inputs(nsc=[1,3,1])" ] }, { @@ -618,7 +502,7 @@ }, "outputs": [], "source": [ - "thumbnail_plot = plot.update_settings(axes=\"yx\", z_range=[1.7, 1.9])\n", + "thumbnail_plot = plot.update_inputs(axes=\"yx\", z_range=[1.7, 1.9])\n", "\n", "if thumbnail_plot:\n", " thumbnail_plot.show(\"png\")" @@ -638,7 +522,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -652,9 +536,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.15" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/docs/visualization/viz_module/showcase/PdosPlot.ipynb b/docs/visualization/viz_module/showcase/PdosPlot.ipynb index a776dad424..13205c7cdf 100644 --- a/docs/visualization/viz_module/showcase/PdosPlot.ipynb +++ b/docs/visualization/viz_module/showcase/PdosPlot.ipynb @@ -24,7 +24,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [], "source": [ "import sisl\n", @@ -37,7 +39,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We are going to get the PDOS from a SIESTA `.PDOS` file, but we could get it from a hamiltonian as well." + "We are going to get the PDOS from a SIESTA `.PDOS` file, but we could get it from some other source, e.g. a hamiltonian." ] }, { @@ -69,9 +71,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## PDOS requests\n", + "## PDOS groups\n", "\n", - "There's a very important setting in the `PdosPlot`: `requests`. 
This setting expects a list of PDOS requests, where each request is a dictionary that can specify \n", + "There's a very important setting in the `PdosPlot`: `groups`. This setting expects a list of orbital groups, where each group is a dictionary that can specify \n", "- `species`\n", "- `atoms`\n", "- `orbitals` (the orbital name)\n", @@ -79,9 +81,9 @@ "- `Z` (the Z shell of the orbital)\n", "- `spin`\n", "\n", - "involved in the PDOS line that you want to draw. Apart from that, a request also accepts the `name`, `color`, `linewidth` and `dash` keys that manage the aesthetics of the line and `normalize`, which indicates if the PDOS should be normalized (divided by number of orbitals).\n", + "involved in the PDOS line that you want to draw. Apart from that, a group also accepts the `name`, `color`, `linewidth` and `dash` keys that manage the aesthetics of the line and `reduce`, which indicates the method to use for accumulating orbital contributions: `\"mean\"` averages over orbitals while `\"sum\"` simply accumulates all contributions. 
Finally, `scale` lets you multiply the DOS of the group by whatever factor you want.\n", "\n", - "Here is an example of how to use the `requests` setting to create a line that displays the Oxygen 2p PDOS:" + "Here is an example of how to use the `groups` setting to create a line that displays the Oxygen 2p PDOS:" ] }, { @@ -90,9 +92,9 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(requests=[{\"name\": \"My first PDOS (Oxygen)\", \"species\": [\"O\"], \"n\": 2, \"l\": 1}])\n", + "plot.update_inputs(groups=[{\"name\": \"My first PDOS (Oxygen)\", \"species\": [\"O\"], \"n\": 2, \"l\": 1}])\n", "# or (it's equivalent)\n", - "plot.update_settings(requests=[{\n", + "plot.update_inputs(groups=[{\n", " \"name\": \"My first PDOS (Oxygen)\", \"species\": [\"O\"],\n", " \"orbitals\": [\"2pzZ1\", \"2pzZ2\", \"2pxZ1\", \"2pxZ2\", \"2pyZ1\", \"2pyZ2\"]\n", "}])" @@ -111,10 +113,10 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(requests=[\n", - " {\"name\": \"Oxygen\", \"species\": [\"O\"], \"color\": \"darkred\", \"dash\": \"dash\", \"normalize\": True},\n", - " {\"name\": \"Titanium\", \"species\": [\"Ti\"], \"color\": \"grey\", \"linewidth\": 3, \"normalize\": True},\n", - " {\"name\": \"Sr\", \"species\": [\"Sr\"], \"color\": \"green\", \"normalize\": True},\n", + "plot.update_inputs(groups=[\n", + " {\"name\": \"Oxygen\", \"species\": [\"O\"], \"color\": \"darkred\", \"dash\": \"dash\", \"reduce\": \"mean\"},\n", + " {\"name\": \"Titanium\", \"species\": [\"Ti\"], \"color\": \"gray\", \"size\": 3, \"reduce\": \"mean\"},\n", + " {\"name\": \"Sr\", \"species\": [\"Sr\"], \"color\": \"green\", \"reduce\": \"mean\"},\n", "], Erange=[-5, 5])" ] }, @@ -122,7 +124,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "It's interesting to note that the `atoms` key of each request accepts the same possibilities as the `atoms` argument of the `Geometry` methods. 
Therefore, **you can use indices, categories, dictionaries, strings...**\n", + "It's interesting to note that the `atoms` key of each group accepts the same possibilities as the `atoms` argument of the `Geometry` methods. Therefore, **you can use indices, categories, dictionaries, strings...**\n", "\n", "For example:" ] @@ -130,15 +132,13 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "# Let's import the AtomZ and AtomOdd categories just to play with them\n", "from sisl.geom import AtomZ, AtomOdd\n", "\n", - "plot.update_settings(requests=[\n", + "plot.update_inputs(groups=[\n", " {\"atoms\": [0,1], \"name\": \"Atoms 0 and 1\"},\n", " {\"atoms\": {\"Z\": 8}, \"name\": \"Atoms with Z=8\"},\n", " {\"atoms\": AtomZ(8) & ~ AtomOdd(), \"name\": \"Oxygens with even indices\"}\n", @@ -151,7 +151,7 @@ "source": [ "## Easy and fast DOS splitting\n", "\n", - "As you might have noticed, sometimes it might be cumbersome to build all the requests you want. If your needs are simple and you don't need the flexibility of defining every parameter by yourself, there is a set of methods that will help you explore your PDOS data faster than ever before. These are: `split_DOS`, `split_requests`, `update_requests`, `remove_requests` and `add_requests`.?\n", + "As you might have noticed, sometimes it might be cumbersome to build all the groups you want. If your needs are simple and you don't need the flexibility of defining every parameter by yourself, there is a set of methods that will help you explore your PDOS data faster than ever before. These are: `split_DOS`, `split_groups`, `update_groups`, `remove_groups` and `add_groups`.\n", "\n", "Let's begin with `split_DOS`. 
As you can imagine, this method splits the density of states:" ] @@ -187,7 +187,7 @@ "source": [ "Now we have the contribution of each atom.\n", "\n", - "But here comes the powerful part: `split_DOS` accepts as keyword arguments all the keys that a request accepts. Then, it adds that extra constrain to the splitting by adding the value to each request. So, if we want to get the separate contributions of all oxygen atoms, **we can impose an extra constraint** on species:" + "But here comes the powerful part: `split_DOS` accepts as keyword arguments all the keys that a group accepts. Then, it adds that extra constrain to the splitting by adding the value to each group. So, if we want to get the separate contributions of all oxygen atoms, **we can impose an extra constraint** on species:" ] }, { @@ -205,7 +205,7 @@ "source": [ "and then we have only the oxygen atoms, which are all equivalent.\n", "\n", - "Note that we also set a name for all requests, with the additional twist that we used the templating supported by `split_DOS`. If you are splitting on `parameter`, you can use `$parameter` inside your name and the method will replace it with the value for each request. In this case `parameter` was `atoms`, but it could be anything you are splitting the DOS on.\n", + "Note that we also set a name for all groups, with the additional twist that we used the templating supported by `split_DOS`. If you are splitting on `parameter`, you can use `$parameter` inside your name and the method will replace it with the value for each group. 
In this case `parameter` was `atoms`, but it could be anything you are splitting the DOS on.\n", "\n", "You can also **exclude some values of the parameter you are splitting on**:" ] @@ -255,11 +255,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Managing existing requests\n", + "## Managing existing groups\n", "\n", - "Not only you can create requests easily with `split_DOS`, but it's also easy to manage the requests that you have created. \n", + "Not only you can create groups easily with `split_DOS`, but it's also easy to manage the groups that you have created. \n", "\n", - "The methods that help you accomplish this are `split_requests`, `update_requests`, `remove_requests`. All three methods accept an undefined number of arguments that are used to select the requests you want to act on. You can refer to requests by their name (using a `str`) or their position (using an `int`). It's very easy to understand with examples. Then, keyword arguments depend on the functionality of each method.\n", + "The methods that help you accomplish this are `split_groups`, `update_groups`, `remove_groups`. All three methods accept an undefined number of arguments that are used to select the groups you want to act on. You can refer to groups by their name (using a `str`) or their position (using an `int`). It's very easy to understand with examples. Then, keyword arguments depend on the functionality of each method.\n", "\n", "For example, let's say that we have splitted the DOS on species" ] @@ -286,16 +286,16 @@ "metadata": {}, "outputs": [], "source": [ - "plot.remove_requests(\"Sr\", 2)" + "plot.remove_groups(\"Sr\", 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We have indicated that we wanted to remove the request with name `\"Sr\"` and the 2nd request. Simple, isn't it?\n", + "We have indicated that we wanted to remove the group with name `\"Sr\"` and the 2nd group. 
Simple, isn't it?\n", "\n", - "Now that we know how to indicate the requests that we want to act on, let's use it to get the total `Sr` contribution, and then the `Ti` and `O` contributions splitted by `n` and `l`.\n", + "Now that we know how to indicate the groups that we want to act on, let's use it to get the total `Sr` contribution, and then the `Ti` and `O` contributions splitted by `n` and `l`.\n", "\n", "It sounds difficult, but it's actually not. Just split the DOS on species:" ] @@ -306,14 +306,14 @@ "metadata": {}, "outputs": [], "source": [ - "plot.split_DOS(name=\"$species\", normalize=True)" + "plot.split_DOS(name=\"$species\", reduce=\"mean\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "And then use `split_requests` to split only the requests that we want to split:" + "And then use `split_groups` to split only the groups that we want to split:" ] }, { @@ -322,27 +322,25 @@ "metadata": {}, "outputs": [], "source": [ - "plot.split_requests(\"Sr\", 2, on=\"n+l\", dash=\"dot\")" + "plot.split_groups(\"Sr\", 2, on=\"n+l\", dash=\"dot\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Notice how we've also set `dash` for all the requests that `split_requests` has generated. We can do this because `split_requests` works exactly as `split_DOS`, with the only difference that splits specific requests.\n", + "Notice how we've also set `dash` for all the groups that `split_groups` has generated. 
We can do this because `split_groups` works exactly as `split_DOS`, with the only difference that splits specific groups.\n", "\n", - "Just as a last thing, we will let you figure out how `update_requests` works:" + "Just as a last thing, we will let you figure out how `update_groups` works:" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ - "plot.update_requests(\"Ti\", color=\"red\", linewidth=2)" + "plot.update_groups(\"Ti\", color=\"red\", size=4)" ] }, { @@ -374,7 +372,7 @@ }, "outputs": [], "source": [ - "thumbnail_plot = plot.update_requests(\"Ti\", color=None, linewidth=1)\n", + "thumbnail_plot = plot.update_groups(\"Ti\", color=None, size=1)\n", "\n", "if thumbnail_plot:\n", " thumbnail_plot.show(\"png\")" @@ -390,11 +388,18 @@ "source": [ "-------------" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -408,9 +413,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.15" } }, "nbformat": 4, - "nbformat_minor": 2 -} \ No newline at end of file + "nbformat_minor": 4 +} diff --git a/docs/visualization/viz_module/showcase/SitesPlot.ipynb b/docs/visualization/viz_module/showcase/SitesPlot.ipynb new file mode 100644 index 0000000000..4ca58075eb --- /dev/null +++ b/docs/visualization/viz_module/showcase/SitesPlot.ipynb @@ -0,0 +1,160 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "tags": [ + "notebook-header" + ] + }, + "source": [ + "[![GitHub issues by-label](https://img.shields.io/github/issues-raw/pfebrer/sisl/SitesPlot?style=for-the-badge)](https://github.com/pfebrer/sisl/labels/SitesPlot)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " 
\n", + " \n", + "SitesPlot\n", + "=========\n", + "\n", + "The `SitesPlot` is simply an adaptation of `GeometryPlot`'s machinery to any class that can be represented as sites in space. The main difference is that it doesn't show bonds, and also inputs with the word `atoms` are renamed to `sites`. Therefore, see `GeometryPlot`'s showcase notebook to understand the full customization possibilities.\n", + "\n", + "We are just going to show how you can plot the k points of a `BrillouinZone` object with it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sisl\n", + "import sisl.viz\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's create a circle of K points:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sisl.geom.graphene()\n", + "\n", + "# Create the circle\n", + "bz = sisl.BrillouinZone.param_circle(\n", + " g,\n", + " kR=0.0085,\n", + " origin= [0.0, 0.0, 0.0],\n", + " normal= [0.0, 0.0, 1.0],\n", + " N_or_dk=25,\n", + " loop=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And then generate some fake vectorial data for it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data = np.zeros((len(bz), 3))\n", + "\n", + "data[:, 0] = - bz.k[:, 1]\n", + "data[:, 1] = bz.k[:, 0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And now plot the k points, showing the vectorial data as arrows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot k points as sites\n", + "bz.plot.sites(\n", + " axes=\"xy\", drawing_mode=\"line\", sites_style={\"color\": \"black\", \"size\": 2},\n", + " arrows={\"data\": data, \"color\": \"red\", \"width\": 3, \"name\": \"Force\"}\n", + ")" + ] + }, 
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "-----\n", + "This next cell is just to create the thumbnail for the notebook in the docs " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "nbsphinx-thumbnail" + ] + }, + "outputs": [], + "source": [ + "thumbnail_plot = _\n", + "\n", + "if thumbnail_plot:\n", + " thumbnail_plot.show(\"png\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [ + "notebook-footer" + ] + }, + "source": [ + "-------------" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.15" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/visualization/viz_module/showcase/WavefunctionPlot.ipynb b/docs/visualization/viz_module/showcase/WavefunctionPlot.ipynb index c0f0385864..37190f3623 100644 --- a/docs/visualization/viz_module/showcase/WavefunctionPlot.ipynb +++ b/docs/visualization/viz_module/showcase/WavefunctionPlot.ipynb @@ -26,6 +26,8 @@ " \n", "`WavefunctionPlot` is just an extension of `GridPlot`, so everything in [the GridPlot notebook](./GridPlot.html) applies and this notebook **will only display the additional features**.\n", "\n", + "`WavefunctionPlot` changes the defaults of the `axes` and `plot_geom` inputs so that by default the grid is shown in 3D and displaying the geometry.\n", + "\n", "" ] }, @@ -86,7 +88,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "That truly is an ugly wavefunction." 
+ "This is a very delocalized state, so its representation in 3D is not very interesting" ] }, { @@ -140,7 +142,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"xy\", k=(0,0,0), transforms=[\"square\"]) # by default grid_prec is 0.2 Ang" + "plot.update_inputs(axes=\"xy\", transforms=[\"square\"]) # by default grid_prec is 0.2 Ang" ] }, { @@ -149,14 +151,14 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(grid_prec=0.05)" + "plot.update_inputs(grid_prec=0.05)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Much better, isn't it? Notice how it didn't look that bad in 3d, because the grid is smooth, so it's values are nicely interpolated. You can also appreciate this by setting `zsmooth` to `\"best\"` in 2D, which does an \"OK job\" at guessing the values." + "Much better, isn't it? Notice how it didn't look that bad in 3d, because the grid is smooth, so it's values are nicely interpolated. You can also appreciate this by setting `smooth` to `True` in 2D, which does an \"OK job\" at guessing the values." 
] }, { @@ -165,7 +167,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(grid_prec=0.2, zsmooth=\"best\")" + "plot.update_inputs(grid_prec=0.2, smooth=True)" ] }, { @@ -196,7 +198,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_settings(axes=\"xyz\", nsc=[2,2,1], grid_prec=0.1, transforms=[],\n", + "plot.update_inputs(axes=\"xyz\", nsc=[2,2,1], grid_prec=0.1, transforms=[],\n", " isos=[\n", " {\"val\": -0.07, \"opacity\": 1, \"color\": \"salmon\"},\n", " {\"val\": 0.07, \"opacity\": 0.7, \"color\": \"blue\"}\n", @@ -254,7 +256,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -268,9 +270,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.15" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/src/sisl/io/siesta/tests/test_eig.py b/src/sisl/io/siesta/tests/test_eig.py index ccd44e2e25..21a56da9fa 100644 --- a/src/sisl/io/siesta/tests/test_eig.py +++ b/src/sisl/io/siesta/tests/test_eig.py @@ -10,7 +10,6 @@ from sisl.io.siesta.eig import * from sisl.io.siesta.fdf import * - pytestmark = [pytest.mark.io, pytest.mark.siesta] _dir = osp.join("sisl", "io", "siesta") diff --git a/src/sisl/nodes/__init__.py b/src/sisl/nodes/__init__.py index b07f9b554d..d16664a624 100644 --- a/src/sisl/nodes/__init__.py +++ b/src/sisl/nodes/__init__.py @@ -1,6 +1,4 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
from .context import SISL_NODES_CONTEXT, NodeContext, temporal_context from .node import Node +from .utils import nodify_module from .workflow import Workflow diff --git a/src/sisl/nodes/context.py b/src/sisl/nodes/context.py index cb14872673..85efa410f7 100644 --- a/src/sisl/nodes/context.py +++ b/src/sisl/nodes/context.py @@ -1,6 +1,3 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. import contextlib from collections import ChainMap from typing import Any, Union @@ -11,9 +8,8 @@ lazy=True, # On initialization, should the node compute? If None, defaults to `lazy`. lazy_init=None, - # Debugging options - debug=False, - debug_show_inputs=False + # The level of logs stored in the node. + log_level="INFO" ) # Temporal contexts stack. It should not be used directly by users, the aim of this diff --git a/src/sisl/nodes/dispatcher.py b/src/sisl/nodes/dispatcher.py index a5aecd5ebe..0cb7bbb65c 100644 --- a/src/sisl/nodes/dispatcher.py +++ b/src/sisl/nodes/dispatcher.py @@ -1,6 +1,3 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. from .context import lazy_context from .node import Node diff --git a/src/sisl/nodes/node.py b/src/sisl/nodes/node.py index f8f6d79553..c94635c937 100644 --- a/src/sisl/nodes/node.py +++ b/src/sisl/nodes/node.py @@ -1,11 +1,10 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
from __future__ import annotations import inspect +import logging from collections import ChainMap -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from io import StringIO +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union from numpy.lib.mixins import NDArrayOperatorsMixin @@ -32,7 +31,6 @@ def __init__(self, node, error, inputs): def __str__(self): return (f"Couldn't generate an output for {self._node} with the current inputs.") - class NodeInputError(NodeError): def __init__(self, node, error, inputs): @@ -43,7 +41,6 @@ def __str__(self): # Should make this more specific return (f"Some input is not right in {self._node} and could not be parsed") - class Node(NDArrayOperatorsMixin): """Generic class for nodes. @@ -79,7 +76,7 @@ class Node(NDArrayOperatorsMixin): _prev_evaluated_inputs: Dict[str, Any] # Current output value of the node - _output: Any + _output: Any = _blank # Nodes that are connected to this node's inputs _input_nodes: Dict[str, Node] @@ -90,6 +87,13 @@ class Node(NDArrayOperatorsMixin): _nupdates: int # Whether the node's output is currently outdated. _outdated: bool + # Whether the node has errored during the last execution + # with the current inputs. + _errored: bool + + # Logs of the node's execution. + _logger: logging.Logger + logs: str # Contains the raw function of the node. function: Callable @@ -104,11 +108,15 @@ def __init__(self, *args, **kwargs): if not lazy_init: self.get() + + def __call__(self, *args, **kwargs): + self.update_inputs(*args, **kwargs) + return self.get() def setup(self, *args, **kwargs): """Sets up the node based on its initial inputs.""" # Parse inputs into arguments. 
- bound_params = self.__class__.__signature__.bind_partial(*args, **kwargs) + bound_params = inspect.signature(self.function).bind_partial(*args, **kwargs) bound_params.apply_defaults() self._inputs = bound_params.arguments @@ -124,6 +132,13 @@ def setup(self, *args, **kwargs): self._nupdates = 0 self._outdated = True + self._errored = False + + self._logger = logging.getLogger( + str(id(self)) + ) + self._log_formatter = logging.Formatter(fmt='%(asctime)s | %(levelname)-8s :: %(message)s') + self.logs = "" self.context = self.__class__.context.new_child({}) @@ -161,7 +176,7 @@ def __init_subclass__(cls): init_sig = sig if "self" not in init_sig.parameters: init_sig = sig.replace(parameters=[ - inspect.Parameter("self", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("self", kind=inspect.Parameter.POSITIONAL_ONLY), *sig.parameters.values() ]) @@ -177,13 +192,12 @@ def __init_subclass__(cls): if parameter.kind == parameter.VAR_KEYWORD: cls._kwargs_inputs_key = key - cls.__init__.__signature__ = init_sig cls.__signature__ = no_self_sig return super().__init_subclass__() @classmethod - def from_func(cls, func: Optional[Callable] = None, context: Optional[dict] = None): + def from_func(cls, func: Union[Callable, None] = None, context: Union[dict, None] = None): """Builds a node from a function. Parameters @@ -204,6 +218,11 @@ def from_func(cls, func: Optional[Callable] = None, context: Optional[dict] = No if isinstance(func, type) and issubclass(func, Node): return func + if isinstance(func, Node): + node = func + + return CallableNode(func=node) + if func in cls._known_function_nodes: return cls._known_function_nodes[func] @@ -314,34 +333,56 @@ def _sanitize_inputs(self, inputs: Dict[str, Any]) -> Tuple[Tuple[Any, ...], Dic kwargs.update(kwargs_inputs) return args, kwargs + + @staticmethod + def evaluate_input_node(node: Node): + return node.get() def get(self): # Map all inputs to their values. 
That is, if they are nodes, call the get # method on them so that we get the updated output. This recursively evaluates nodes. + self._logger.setLevel(getattr(logging, self.context['log_level'].upper())) + + logs = logging.StreamHandler(StringIO()) + self._logger.addHandler(logs) + + logs.setFormatter(self._log_formatter) + + self._logger.debug("Getting output from node...") + self._logger.debug(f"Raw inputs: {self._inputs}") + evaluated_inputs = self.map_inputs( inputs=self._inputs, - func=lambda node: node.get(), + func=self.evaluate_input_node, only_nodes=True, ) + self._logger.debug(f"Evaluated inputs: {evaluated_inputs}") + if self._outdated or self.is_output_outdated(evaluated_inputs): try: args, kwargs = self._sanitize_inputs(evaluated_inputs) self._output = self.function(*args, **kwargs) - if self.context['debug']: - if self.context['debug_show_inputs']: - info(f"{self}: evaluated with inputs {evaluated_inputs}: {self._output}.") - else: - info(f"{self}: evaluated because inputs changed.") + + self._logger.info(f"Evaluated because inputs changed.") except Exception as e: + self._logger.exception(e) + self.logs += logs.stream.getvalue() + logs.close() + self._errored = True raise NodeCalcError(self, e, evaluated_inputs) self._nupdates += 1 self._prev_evaluated_inputs = evaluated_inputs self._outdated = False + self._errored = False else: - if self.context['debug']: - info(f"{self}: no need to evaluate") + self._logger.info(f"No need to evaluate") + + self._logger.debug(f"Output: {self._output}.") + + self.logs += logs.stream.getvalue() + logs.close() return self._output @@ -359,7 +400,7 @@ def get_tree(self): @property def default_inputs(self): - params = self.__class__.__signature__.bind_partial() + params = inspect.signature(self.function).bind_partial() params.apply_defaults() return params.arguments @@ -371,18 +412,94 @@ def get_input(self, key: str): input_val = self.inputs[key] return input_val + + def recursive_update_inputs(self, cls: 
Optional[Union[Type, Tuple[Type, ...]]] = None, **inputs): + """Updates the inputs of the node recursively. + + This method updates the inputs of the node and all its children. + + Parameters + ---------- + cls : Optional[Union[Type, Tuple[Type, ...]]], optional + Only update nodes of this class. If None, update all nodes. + inputs : Dict[str, Any] + The inputs to update. + """ + from .utils import traverse_tree_backward + + def _update(node): + + if cls is None or isinstance(self, cls): + node.update_inputs(**inputs) + + update_inputs = {} + # Update the inputs of the node + for k in self.inputs: + if k in inputs: + update_inputs[k] = inputs[k] + + self.update_inputs(**update_inputs) - def update_inputs(self, *args, **inputs): + traverse_tree_backward([self], _update) + + def update_inputs(self, **inputs): + """Updates the inputs of the node. + + Note that you can not pass positional arguments to this method. + The positional arguments must be passed also as kwargs. + + This is because there would not be a well defined way to update the + variadic positional arguments. + + E.g. if the function signature is (a: int, *args), there is no way + to pass *args without passing a value for a. + + This means that one must also pass the *args also as a key: + ``update_inputs(args=(2, 3))``. Beware that functions not necessarily + name their variadic arguments ``args``. If the function signature is + ``(a: int, *arguments)`` then the key that you need to use is `arguments`. + + Similarly, the **kwargs can be passed either as a dictionary in the key ``kwargs`` + (or whatever the name of the variadic keyword arguments is). This indicates that + the whole kwargs is to be replaced by the new value. Alternatively, you can pass + the kwargs as separate key-value arguments, which means that you want to update the + kwargs dictionary, but keep the old values. In this second option, you can indicate + that a key should be removed by passing ``Node.DELETE_KWARG`` as the value. 
+ + Parameters + ---------- + **inputs : + The inputs to update. + """ # If no new inputs were provided, there's nothing to do - if not inputs and len(args) == 0: + if not inputs: return - - bound = self.__class__.__signature__.bind_partial(*args, **inputs) + + # Pop the args key (if any) so that we can parse the inputs without errors. + args = None + if self._args_inputs_key: + args = inputs.pop(self._args_inputs_key, None) + # Pop also the kwargs key (if any) + explicit_kwargs = None + if self._kwargs_inputs_key: + explicit_kwargs = inputs.pop(self._kwargs_inputs_key, None) + + # Parse the inputs. We do this to separate the kwargs from the rest of the inputs. + bound = inspect.signature(self.function).bind_partial(**inputs) inputs = bound.arguments + + # Now that we have parsed the inputs, put back the args key (if any). + if args is not None: + inputs[self._args_inputs_key] = args - # If kwargs inputs are provided, add them to the previous input kwargs. - if self._kwargs_inputs_key is not None: - new_kwargs = bound.kwargs + if explicit_kwargs is not None: + # If a kwargs dictionary has been passed, this means that the user wants to replace + # the whole kwargs dictionary. So, we just update the inputs with the new kwargs. + inputs[self._kwargs_inputs_key] = explicit_kwargs + elif self._kwargs_inputs_key is not None: + # Otherwise, update the old kwargs with the new separate arguments that have been passed. + # Here we give the option to delete individual kwargs by passing the DELETE_KWARG indicator. 
+ new_kwargs = inputs.get(self._kwargs_inputs_key, {}) if len(new_kwargs) > 0: kwargs = self._inputs.get(self._kwargs_inputs_key, {}).copy() kwargs.update(new_kwargs) @@ -411,7 +528,12 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): return UfuncNode(ufunc=ufunc, method=method, input_kwargs=kwargs, **inputs) def __getitem__(self, key): - return GetItemNode(data=self, key=key) + return GetItemNode(obj=self, key=key) + + def __getattr__(self, key): + if key.startswith('_'): + raise super().__getattr__(key) + return GetAttrNode(obj=self, key=key) def _update_connections(self, inputs): @@ -489,6 +611,7 @@ def _inform_outdated(self): def _receive_outdated(self): # Mark the node as outdated self._outdated = True + self._errored = False # If automatic recalculation is turned on, recalculate output self._maybe_autoupdate() # Inform to the nodes that use our input that they are outdated @@ -500,7 +623,6 @@ def _maybe_autoupdate(self): if not self.context['lazy']: self.get() - class DummyInputValue(Node): """A dummy node that can be used as a placeholder for input values.""" @@ -516,24 +638,37 @@ def value(self): def function(input_key: str, value: Any = Node._blank): return value - class FuncNode(Node): @staticmethod - def function(func: Callable, **kwargs): - return func(**kwargs) - + def function(*args, func: Callable, **kwargs): + return func(*args, **kwargs) + +class CallableNode(FuncNode): + + def __call__(self, *args, **kwargs): + self.update_inputs(*args, **kwargs) + return self class GetItemNode(Node): @staticmethod - def function(data: Any, key: Any): - return data[key] + def function(obj: Any, key: Any): + return obj[key] + +class GetAttrNode(Node): + @staticmethod + def function(obj: Any, key: str): + return getattr(obj, key) class UfuncNode(Node): """Node that wraps a numpy ufunc.""" + def __call__(self, *args, **kwargs): + self.recursive_update_inputs(*args, **kwargs) + return self.get() + @staticmethod def function(ufunc, method: str, 
input_kwargs: Dict[str, Any], **kwargs): # We need to @@ -545,4 +680,11 @@ def function(ufunc, method: str, input_kwargs: Dict[str, Any], **kwargs): break inputs.append(kwargs.pop(key)) i += 1 - return getattr(ufunc, method)(*inputs, **input_kwargs) + return getattr(ufunc, method)(*inputs, **input_kwargs) + +class ConstantNode(Node): + """Node that just returns its input value.""" + + @staticmethod + def function(value: Any): + return value diff --git a/src/sisl/nodes/syntax_nodes.py b/src/sisl/nodes/syntax_nodes.py new file mode 100644 index 0000000000..bcf4719a9d --- /dev/null +++ b/src/sisl/nodes/syntax_nodes.py @@ -0,0 +1,24 @@ +from .node import Node + + +class SyntaxNode(Node): + ... + +class ListSyntaxNode(SyntaxNode): + + @staticmethod + def function(*items): + return list(items) + + +class TupleSyntaxNode(SyntaxNode): + + @staticmethod + def function(*items): + return tuple(items) + +class DictSyntaxNode(SyntaxNode): + + @staticmethod + def function(**items): + return items \ No newline at end of file diff --git a/src/sisl/nodes/tests/__init__.py b/src/sisl/nodes/tests/__init__.py index 210912af7a..e69de29bb2 100644 --- a/src/sisl/nodes/tests/__init__.py +++ b/src/sisl/nodes/tests/__init__.py @@ -1,4 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -""" tests for sisl.nodes """ diff --git a/src/sisl/nodes/tests/test_context.py b/src/sisl/nodes/tests/test_context.py index 9c7f83791f..79d06d7345 100644 --- a/src/sisl/nodes/tests/test_context.py +++ b/src/sisl/nodes/tests/test_context.py @@ -1,6 +1,3 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
import pytest from sisl.nodes import Node, Workflow diff --git a/src/sisl/nodes/tests/test_node.py b/src/sisl/nodes/tests/test_node.py index 1147152d55..a45bf3680c 100644 --- a/src/sisl/nodes/tests/test_node.py +++ b/src/sisl/nodes/tests/test_node.py @@ -1,6 +1,3 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. import pytest from sisl.nodes import Node, temporal_context @@ -192,28 +189,6 @@ def reduce_(*nums, factor: int = 1): assert val2.get() == 16 assert val._nupdates == 1 -@temporal_context(lazy=True) -def test_update_args(): - """When calling update_inputs with args, the old args should be completely - discarded and replaced by the ones provided. - """ - - @Node.from_func - def reduce_(*nums, factor: int = 1): - - val = 0 - for num in nums: - val += num - return val * factor - - val = reduce_(1, 2, 3, factor=2) - - assert val.get() == 12 - - val.update_inputs(4, 5, factor=3) - - assert val.get() == 27 - @temporal_context(lazy=True) def test_node_links_args(): @@ -233,25 +208,6 @@ def my_node(*some_args): assert 'some_args[2]' in node2._input_nodes assert node2._input_nodes['some_args[2]'] is node1 - # Now check that if we update node2, the connections - # will be removed. - node2.update_inputs(2) - - assert len(node2._input_nodes) == 0 - assert len(node1._output_links) == 0 - - # Check that connections are properly built when - # updating inputs with a value containing a node. - node2.update_inputs(node1) - - # Check that node1 knows that node2 uses its output - assert len(node1._output_links) == 1 - assert node1._output_links[0] is node2 - - # And that node2 knows it's using node1 as an input. 
- assert len(node2._input_nodes) == 1 - assert 'some_args[0]' in node2._input_nodes - assert node2._input_nodes['some_args[0]'] is node1 @temporal_context(lazy=True) def test_kwargs(): @@ -331,4 +287,14 @@ def my_node(**some_kwargs): # And that node2 knows it's using node3 as an input. assert len(node2._input_nodes) == 1 assert 'some_kwargs[a]' in node2._input_nodes - assert node2._input_nodes['some_kwargs[a]'] is node3 \ No newline at end of file + assert node2._input_nodes['some_kwargs[a]'] is node3 + +def test_ufunc(sum_node): + + node = sum_node(1, 3) + + assert node.get() == 4 + + node2 = node + 6 + + assert node2.get() == 10 \ No newline at end of file diff --git a/src/sisl/nodes/tests/test_syntax_nodes.py b/src/sisl/nodes/tests/test_syntax_nodes.py new file mode 100644 index 0000000000..6d607bc62c --- /dev/null +++ b/src/sisl/nodes/tests/test_syntax_nodes.py @@ -0,0 +1,29 @@ +from sisl.nodes.syntax_nodes import DictSyntaxNode, ListSyntaxNode, TupleSyntaxNode +from sisl.nodes.workflow import Workflow + + +def test_list_syntax_node(): + assert ListSyntaxNode("a", "b", "c").get() == ["a", "b", "c"] + +def test_tuple_syntax_node(): + assert TupleSyntaxNode("a", "b", "c").get() == ("a", "b", "c") + +def test_dict_syntax_node(): + assert DictSyntaxNode(a="b", c="d", e="f").get() == {"a": "b", "c": "d", "e": "f"} + +def test_workflow_with_syntax(): + + def f(a): + return [a] + + assert Workflow.from_func(f)(2).get() == [2] + + def f(a): + return (a,) + + assert Workflow.from_func(f)(2).get() == (2,) + + def f(a): + return {"a": a} + + assert Workflow.from_func(f)(2).get() == {"a": 2} diff --git a/src/sisl/nodes/tests/test_workflow.py b/src/sisl/nodes/tests/test_workflow.py index 07b4aa27a6..1b05a1ad1b 100644 --- a/src/sisl/nodes/tests/test_workflow.py +++ b/src/sisl/nodes/tests/test_workflow.py @@ -1,6 +1,3 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. 
If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. from typing import Type import pytest diff --git a/src/sisl/nodes/utils.py b/src/sisl/nodes/utils.py index 7cb4356b28..4c76fc2443 100644 --- a/src/sisl/nodes/utils.py +++ b/src/sisl/nodes/utils.py @@ -1,7 +1,6 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from typing import Any, Callable, Sequence +import inspect +from types import FunctionType, ModuleType +from typing import Any, Callable, Dict, Sequence, Type from .node import Node @@ -57,4 +56,82 @@ def visit_all_connected(nodes: Sequence[Node], func: Callable[[Node], Any], _see _seen_nodes.append(id(node)) traverse_tree_forward((node, ), func=lambda node: visit_all_connected((node, ), func=func, _seen_nodes=_seen_nodes)) - traverse_tree_backward((node, ), func=lambda node: visit_all_connected((node, ), func=func, _seen_nodes=_seen_nodes)) \ No newline at end of file + traverse_tree_backward((node, ), func=lambda node: visit_all_connected((node, ), func=func, _seen_nodes=_seen_nodes)) + +def nodify_module(module: ModuleType, node_class: Type[Node] = Node) -> ModuleType: + """Returns a copy of a module where all functions are replaced with nodes. + + This new nodified module contains only nodes (coming from functions or classes). + The rest of variables are not copied. In fact, the new module uses the variables + from the original module. + + Also, some functions might not be convertable to nodes and therefore won't be found + in the new module. + + Parameters + ---------- + module : ModuleType + The module to nodify. + node_class : Type[Node], optional + The class from which the created nodes will inherit, by default Node. + This can be useful for example to convert to workflows, if you pass + the Workflow class. 
+ + Returns + ------- + ModuleType + A new module with all functions replaced with nodes. + """ + + # Function that recursively traverses the module and replaces functions with nodes. + def _nodified_module(module: ModuleType, visited: Dict[ModuleType, ModuleType], main_module: str) -> ModuleType: + # This module has already been visited, so do return the already nodified module. + if module in visited: + return visited[module] + + # Create a copy of this module, with the nodified_ prefix in the name. + noded_module = ModuleType(f"nodified_{module.__name__}") + # Register the module as visited. + visited[module] = noded_module + + all_vars = vars(module).copy() + + # Loop through all the variables in the module. + for k, variable in all_vars.items(): + if k.startswith("__"): + continue + + # Initialize the noded variable to None. + noded_variable = None + + if isinstance(variable, (type, FunctionType)): + # If the variable was not defined in the module that we are nodifying, + # skip it. This is to avoid nodifying variables that were imported + # from other modules. + module_name = getattr(variable, "__module__", "") or "" + if not (isinstance(module_name, str) and module_name.startswith(main_module)): + continue + + # If the variable is a function or a class, try to create a node from it. + # There are some reasons why a function or class with exotic properties + # might not be able to be converted to a node. We do not aim at nodifying them. + try: + noded_variable = node_class.from_func(variable) + noded_variable.__module__ = f"nodified_{variable.__module__}" + except: + ... + elif inspect.ismodule(variable): + module_name = getattr(variable, "__name__", "") or "" + if not (isinstance(module_name, str) and module_name.startswith(main_module)): + continue + + # If the variable is a module, recursively nodify it. + noded_variable = _nodified_module(variable, visited, main_module=main_module) + + # Add the new noded variable to the new module. 
+ if noded_variable is not None: + setattr(noded_module, k, noded_variable) + + return noded_module + + return _nodified_module(module, visited={}, main_module=module.__name__) \ No newline at end of file diff --git a/src/sisl/nodes/workflow.py b/src/sisl/nodes/workflow.py index 42d433548b..77d7353544 100644 --- a/src/sisl/nodes/workflow.py +++ b/src/sisl/nodes/workflow.py @@ -1,32 +1,20 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. from __future__ import annotations import ast import html import inspect import textwrap +from _ast import Dict from collections import ChainMap from types import FunctionType -from typing import ( - Any, - Callable, - Dict, - Iterable, - List, - Optional, - Sequence, - Tuple, - Type, - Union, -) +from typing import Any, Callable, Dict, Iterable, List, Sequence, Tuple, Type, Union from sisl._environ import get_environ_variable, register_environ_variable from sisl.messages import warn from .context import temporal_context from .node import DummyInputValue, Node +from .syntax_nodes import DictSyntaxNode, ListSyntaxNode, TupleSyntaxNode from .utils import traverse_tree_backward, traverse_tree_forward register_environ_variable( @@ -151,7 +139,7 @@ def to_pyvis(self, colorscale: str = "viridis", show_workflow_inputs: bool = Fal notebook: bool = False, hierarchial: bool = True, inputs_props: Dict[str, Any] = {}, node_props: Dict[str, Any] = {}, leafs_props: Dict[str, Any] = {}, output_props: Dict[str, Any] = {}, auto_text_color: bool = True, - to_export: Optional[bool] = None, + to_export: Union[bool, None] = None, ): """Convert a Workflow class to a pyvis network for visualization. 
@@ -357,7 +345,7 @@ def rgb2gray(rgb): return net @staticmethod - def _show_pyvis(net: "pyvis.Network", notebook: bool, to_export: Optional[bool]): + def _show_pyvis(net: "pyvis.Network", notebook: bool, to_export: Union[bool, None]): """Shows a pyvis network. This is implemented here because pyvis implementation of `show` is very dangerous, @@ -415,7 +403,7 @@ def visualize(self, colorscale: str = "viridis", show_workflow_inputs: bool = Tr edge_labels: bool = True, node_help: bool = True, notebook: bool = False, hierarchial: bool = True, node_props: Dict[str, Any] = {}, inputs_props: Dict[str, Any] = {}, leafs_props: Dict[str, Any] = {}, output_props: Dict[str, Any] = {}, - to_export: Optional[bool] = None, + to_export: Union[bool, None] = None, ): """Visualize the workflow's network in a plot. @@ -500,6 +488,69 @@ def from_workflow_run(cls, inputs: Dict[str, WorkflowInput], output: WorkflowOut break return cls(inputs=inputs, workers=workers, output=output, named_vars=_named_vars) + + @classmethod + def from_node_tree(cls, output_node): + + # Gather all worker nodes inside the workflow. + workers = cls.gather_from_inputs_and_output([], output=output_node) + + # Dictionary that will store the workflow input nodes. + wf_inputs = {} + # The workers found by traversing are node instances that might be in use + # by the user, so we should create copies of them and store them in this new_workers + # dictionary. Additionally, we need a mapping from old nodes to new nodes in order + # to update the links between nodes. + new_workers = {} + old_to_new = {} + # Loop through the workers. + for k, node in workers.items(): + # Find out the inputs that we should connect to the workflow inputs. We connect all inputs + # that are not nodes, and that are not the args or kwargs inputs. 
+ node_inputs = { + param_k: WorkflowInput(input_key=f"{node.__class__.__name__}_{param_k}", value=node.inputs[param_k]) + for param_k, v in node.inputs.items() if not ( + isinstance(v, Node) or param_k == node._args_inputs_key or param_k == node._kwargs_inputs_key + ) + } + + # Create a new node using the newly determined inputs. However, we keep the links to the old nodes + # These inputs will be updated later. + with temporal_context(lazy=True): + new_workers[k] = node.__class__().update_inputs(**{**node.inputs, **node_inputs}) + + # Register this new node in the mapping from old to new nodes. + old_to_new[id(node)] = new_workers[k] + + # Update the workflow inputs dictionary with the inputs that we have determined + # to be connected to this node. We use the node class name as a prefix to avoid + # name clashes. THIS IS NOT PERFECT, IF THERE ARE TWO NODES OF THE SAME CLASS + # THERE CAN BE A CLASH. + wf_inputs.update({ + f"{node.__class__.__name__}_{param_k}": v for param_k, v in node_inputs.items() + }) + + # Now that we have all the node copies, update the links to old nodes with + # links to new nodes. + for k, node in new_workers.items(): + + new_node_inputs = {} + for param_k, v in node.inputs.items(): + if param_k == node._args_inputs_key: + new_node_inputs[param_k] = [old_to_new[id(n)] if isinstance(n, Node) else n for n in v] + elif param_k == node._args_inputs_key: + new_node_inputs[param_k] = {k: old_to_new[id(n)] if isinstance(n, Node) else n for k, n in v.items()} + elif isinstance(v, Node) and not isinstance(v, WorkflowInput): + new_node_inputs[param_k] = old_to_new[id(v)] + + with temporal_context(lazy=True): + node.update_inputs(**new_node_inputs) + + # Create the workflow output. + new_output = WorkflowOutput(value=old_to_new[id(output_node)]) + + # Initialize and return the WorkflowNodes object. 
+ return cls(inputs=wf_inputs, workers=new_workers, output=new_output, named_vars={}) def __dir__(self) -> Iterable[str]: return dir(self.named_vars) + dir(self._all_nodes) @@ -642,7 +693,51 @@ class Workflow(Node): network = NetworkDescriptor() + @classmethod + def from_node_tree(cls, output_node: Node, workflow_name: Union[str, None] = None): + """Creates a workflow class from a node. + + It does so by recursively traversing the tree in the inputs direction until + it finds the leaves. + All the nodes found are included in the workflow. For each node, inputs + that are not nodes are connected to the inputs of the workflow. + + Parameters + ---------- + output_node: Node + The final node, that should be connected to the output of the workflow. + workflow_name: str, optional + The name of the new workflow class. If None, the name of the output node + will be used. + + Returns + ------- + Workflow + The newly created workflow class. + """ + # Create the node manager for the workflow. + dryrun_nodes = WorkflowNodes.from_node_tree(output_node) + + # Create the signature of the workflow from the inputs that were determined + # by the node manager. + signature = inspect.Signature(parameters=[ + inspect.Parameter(inp.input_key, inspect.Parameter.KEYWORD_ONLY, default=inp.value) for inp in dryrun_nodes.inputs.values() + ]) + + def function(*args, **kwargs): + raise NotImplementedError("Workflow class created from node tree. Calling it as a function is not supported.") + + function.__signature__ = signature + + # Create the class and return it. 
+ return type( + workflow_name or output_node.__class__.__name__, + (cls,), + {"dryrun_nodes": dryrun_nodes, "__signature__": signature, "function": staticmethod(function)} + ) + def setup(self, *args, **kwargs): + self.nodes = self.dryrun_nodes super().setup(*args, **kwargs) self.nodes = self.dryrun_nodes.copy(inputs=self._inputs) @@ -651,6 +746,9 @@ def __init_subclass__(cls): # If this is just a subclass of Workflow that is not meant to be ran, continue if not hasattr(cls, "function"): return super().__init_subclass__() + # Also, if the node manager has already been created, continue. + if "dryrun_nodes" in cls.__dict__: + return super().__init_subclass__() # Otherwise, do all the setting up of the class @@ -659,6 +757,11 @@ def __init_subclass__(cls): named_vars = {} def assign_workflow_var(value: Any, var_name: str): + original_name = var_name + repeats = 0 + while var_name in named_vars: + repeats += 1 + var_name = f"{original_name}_{repeats}" if var_name in named_vars: raise ValueError(f"Variable {var_name} has already been assigned a value, in workflows you can't overwrite variables.") named_vars[var_name] = value @@ -713,7 +816,7 @@ def find_node_key(cls, node, *args) -> str: if len(args) == 1: return args[0] - raise ValueError(f"Could not find node {node} in the workflow. Workflow nodes {node}") + raise ValueError(f"Could not find node {node} in the workflow. Workflow nodes {cls.dryrun_nodes.items()}") def get(self): """Returns the up to date output of the workflow. 
@@ -736,11 +839,19 @@ def update_inputs(self, **inputs): self._inputs.update(inputs) return self + + def _get_output(self): + return self.nodes.output._output + + def _set_output(self, value): + self.nodes.output._output = value + + _output = property(_get_output, _set_output) class NodeConverter(ast.NodeTransformer): """AST transformer that converts a function into a workflow.""" - def __init__(self, *args, assign_fn: Optional[str] = None, node_cls_name: str = "Node", **kwargs): + def __init__(self, *args, assign_fn: Union[str, None] = None, node_cls_name: str = "Node", **kwargs): super().__init__(*args, **kwargs) self.assign_fn = assign_fn @@ -758,7 +869,6 @@ def visit_Call(self, node): ast.fix_missing_locations(node2) - return node2 def visit_Assign(self, node): @@ -766,6 +876,8 @@ def visit_Assign(self, node): if self.assign_fn is None: return self.generic_visit(node) + if len(node.targets) > 1 or not isinstance(node.targets[0], ast.Name): + return self.generic_visit(node) node.value = ast.Call( func=ast.Name(id=self.assign_fn, ctx=ast.Load()), @@ -780,10 +892,62 @@ def visit_Assign(self, node): return node + def visit_List(self, node): + """Converts the list syntax into a call to the ListSyntaxNode.""" + if all(isinstance(elt, ast.Constant) for elt in node.elts): + return self.generic_visit(node) + + new_node = ast.Call( + func=ast.Name(id="ListSyntaxNode", ctx=ast.Load()), + args=[self.visit(elt) for elt in node.elts], + keywords=[] + ) + + ast.fix_missing_locations(new_node) + + return new_node + + def visit_Tuple(self, node): + """Converts the tuple syntax into a call to the TupleSyntaxNode.""" + if all(isinstance(elt, ast.Constant) for elt in node.elts): + return self.generic_visit(node) + + new_node = ast.Call( + func=ast.Name(id="TupleSyntaxNode", ctx=ast.Load()), + args=[self.visit(elt) for elt in node.elts], + keywords=[] + ) + + ast.fix_missing_locations(new_node) + + return new_node + + def visit_Dict(self, node: ast.Dict) -> Any: + """Converts the 
dict syntax into a call to the DictSyntaxNode.""" + if all(isinstance(elt, ast.Constant) for elt in node.values): + return self.generic_visit(node) + if not all(isinstance(elt, ast.Constant) for elt in node.keys): + return self.generic_visit(node) + + new_node = ast.Call( + func=ast.Name(id="DictSyntaxNode", ctx=ast.Load()), + args=[], + keywords=[ + ast.keyword(arg=key.value, value=self.visit(value)) + for key, value in zip(node.keys, node.values) + ], + ) + + ast.fix_missing_locations(new_node) + + return new_node + + + def nodify_func( func: FunctionType, transformer_cls: Type[NodeConverter] = NodeConverter, - assign_fn: Optional[Callable] = None, + assign_fn: Union[Callable, None] = None, node_cls: Type[Node] = Node ) -> FunctionType: """Converts all calculations of a function into nodes. @@ -799,7 +963,7 @@ def nodify_func( The function to convert. transformer_cls : Type[NodeConverter], optional The NodeTransformer class to that is used to transform the AST. - assign_fn : Callable, optional + assign_fn : Union[Callable, None], optional A function that will be placed as middleware for variable assignments. It will be called with the following arguments: - value: The value assigned to the variable. @@ -850,7 +1014,11 @@ def nodify_func( code_obj = compile(new_tree, "compiled_workflows", "exec") # Add the needed variables into the namespace. 
- namespace = {node_cls_name: node_cls, **func_namespace} + namespace = { + node_cls_name: node_cls, + "ListSyntaxNode": ListSyntaxNode, "TupleSyntaxNode": TupleSyntaxNode, "DictSyntaxNode": DictSyntaxNode, + **func_namespace, + } if assign_fn_key is not None: namespace[assign_fn_key] = assign_fn diff --git a/src/sisl/viz/.coverage b/src/sisl/viz/.coverage new file mode 100644 index 0000000000..b486dd959a Binary files /dev/null and b/src/sisl/viz/.coverage differ diff --git a/src/sisl/viz/__init__.py b/src/sisl/viz/__init__.py index 39644e8ec2..81b391cbc3 100644 --- a/src/sisl/viz/__init__.py +++ b/src/sisl/viz/__init__.py @@ -5,17 +5,8 @@ Visualization utilities ======================= -Various visualization modules are described here. - - -Plotly -====== - -The plotly backend. """ -# from ._presets import * -# from ._templates import * -# from ._user_customs import import_user_plots, import_user_presets, import_user_sessions, import_user_plugins + import os from sisl._environ import register_environ_variable @@ -29,17 +20,10 @@ description="Maximum number of processors used for parallel plotting", process=int) -# isort: split -from .plot import Animation, MultiplePlot, Plot, SubPlots - -# isort: split - +from . import _xarray_accessor from ._plotables import register_plotable from ._plotables_register import * -from .backends import load_backends +from .figure import Figure, get_figure +from .plot import Plot from .plots import * -from .plotutils import load -from .session import Session -from .sessions import * - -load_backends() +from .plotters import plot_actions diff --git a/src/sisl/viz/_doc_updater.py b/src/sisl/viz/_doc_updater.py deleted file mode 100644 index 4bfef079a9..0000000000 --- a/src/sisl/viz/_doc_updater.py +++ /dev/null @@ -1,100 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-"""This file makes documentation of plots and sessions easy based on their parameters. - -Basically, you just need to have your class defined and have a flag in the docs -specifying where the settings documentation should go. This script will fill it -for you. - -Example: - -class FakePlot(Plot): - ''' - This plot does really nothing useful - - Parameters - ----------- - %%configurable_settings%% - ''' - -%%configurable_settings%% is the key to let the script now where to put the documentation. - -IF YOU HAVE MORE THAN ONE PLOT CLASS IN A FILE, YOU SHOULD SPECIFY %%FakePlot_configurable_settings%% - -Then, just run `python -m sisl.viz._doc_updater`. -Or you can use fill_class_docs to only update a certain class. -""" -import inspect - -from sisl.viz.plot import Animation, MultiplePlot, Plot, SubPlots -from sisl.viz.plotutils import get_plot_classes, get_session_classes -from sisl.viz.session import Session - - -def get_parameters_docstrings(cls): - """ - Returns the documentation for the configurable's parameters. - - Parameters - ----------- - cls: - the class you want the docstring for - - Returns - ----------- - str: - the docs with the settings added. - """ - import re - - if isinstance(cls, type): - params = cls._get_class_params()[0] - doc = cls.__doc__ - if doc is None: - doc = "" - else: - # It's really an instance, not the class - params = cls.params - doc = "" - - configurable_settings = "\n".join( - [param._get_docstring() for param in params]) - - html_cleaner = re.compile('<.*?>') - configurable_settings = re.sub(html_cleaner, '', configurable_settings) - - return configurable_settings - - -def fill_class_docs(cls): - """ Fills the documentation for a class that inherits from Configurable - - You just need to use the placeholder %%configurable_settings%% or - %%ClassName_configurable_settings%% for more specificity, where ClassName is the name - of the class that you want to document. 
Then, this function replaces that placeholder - with the documentation for all the settings. - - Parameters - ----------- - cls: - the class you want to document. - - """ - filename = inspect.getfile(cls) - parameters_docs = "\n ".join( - get_parameters_docstrings(cls) - .split("\n") - ) - - with open(filename, 'r') as fi: - lines = fi.read() - new_lines = lines.replace("%%configurable_settings%%", parameters_docs) - new_lines = new_lines.replace(f"%%{cls.__name__}_configurable_settings%%", parameters_docs) - - open(filename, 'w').write(new_lines) - - -if __name__ == "__main__": - for cls in [*get_plot_classes(), Plot, MultiplePlot, Animation, SubPlots, Session, *get_session_classes().values()]: - fill_class_docs(cls) diff --git a/src/sisl/viz/_input_field.py b/src/sisl/viz/_input_field.py deleted file mode 100644 index 4ae2cf4c02..0000000000 --- a/src/sisl/viz/_input_field.py +++ /dev/null @@ -1,316 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -#This file defines all the currently available input fields so that it is easier to develop plots -import json -from copy import copy, deepcopy - -import numpy as np - -from .plotutils import get_nested_key, modify_nested_dict - -__all__ = ["InputField"] - - -class InputField: - """ This class is meant to help a smooth interface between python and the GUI. - - A class that inherits from Configurable should have all its settings defined as `ÌnputField`s. In - this way, the GUI will know how to display an input field to let the user interact with the - settings plot. - - This is just the base class of all input fields. Each type of field has its own class. The most - simple one is `TextInput`, which just renders an input of type text in the GUI. 
- - Input fields also help documenting the class and parsing the user's input to normalize it (with - the `parse` method.). - - Finally, since all input fields are copied to the instance, the classes that define input fields - can have methods that help making your life easier. See the `OrbitalQueries` input field for an example - of that. In that case, the input field can update its own options based on a geometry that is passed. - - Parameters - ---------- - key: str - The key with which you will be able to access the value of the setting in your class - name: str - The name that you want to show for this setting in the GUI. - default: optional (None) - The default value for the setting. If it is not provided it will be None. - params: dict, optional - A dictionary with parameters that you want to add to the params key of the input field. - If a key is already in the defaults, your provided value will have preference. - style: dict, optional - A dictionary with parameters that you want to add to the style key of the input field. - If a key is already in the defaults, your provided value will have preference. - - The keys inside style determine the aesthetical appearance of the input field. This is passed directly - to the style property of the container of the input field. Therefore, one could pass any react CSS key. - - If you don't know what CSS is, don't worry, it's easy and cool. The only thing that you need to know is - that the style dictionary contains keys that determine how something looks in the web. For example: - - {backgroundColor: "red", padding: 30} - - would render a container with red background color and a padding of 30px. - - This links provide info on: - - What CSS is: https://www.youtube.com/watch?v=4BEyFVufmM8&list=PL4cUxeGkcC9gQeDH6xYhmO-db2mhoTSrT&index=2 - - React CSS examples (you can play with them): https://www.w3schools.com/react/react_css.asp - - Just remember that you want to pass REACT CSS keys, NOT CSS. 
Basically the difference is that "-" are replaced by - capital letters: - Normal CSS: {font-size: 10} React CSS: {fontSize: 10} - - You probably won't need to style anything and the defaults are good enough, but we still give this option for more flexibility. - inputFieldAttrs: dict, optional - A dictionary with additional keys that you want to add to the inputField dictionary. - group: str, optional - Group of parameters to which the parameter belongs - subGroup: str, optional - If the setting belongs to a group, the subgroup it is in (if any). - help: str, optional - Help message to guide the user on what the parameter does. They will appear as tooltips in the GUI. - - Supports html tags, so one can write
to generate a new line or mylink to display a link, for example. - - This parameter is optional but extremely adviseable. - **kwargs: - All keyword arguments passed will be added to the parameter, overwriting any existing value in case there is one. - """ - - dtype = None - - def __init__(self, key, name, default=None, params={}, style={}, input_field_attrs={}, group=None, subGroup=None, dtype=None, help="", **kwargs): - self.key = key - self.name = name - self.default = default - self.group = group - self.subGroup = subGroup - self.help = help - - if dtype is not None: - self.dtype = dtype - - default_input = deepcopy(getattr(self, "_default", {})) - - setattr(self, "inputField", { - 'type': copy(getattr(self, '_type', None)), - **default_input, - "params": { - **default_input.get("params", {}), - **params - }, - "style": { - **default_input.get("style", {}), - **style - }, - **input_field_attrs - }) - - for key, value in kwargs.items(): - - setattr(self, key, value) - - def __getitem__(self, key): - """ Gets a key from the input field, even if it is nested """ - if isinstance(key, str): - return get_nested_key(self.__dict__, key) - - return None - - def __str__(self): - """ String representation of the structure of the input field """ - return str(vars(self)) - - def __repr__(self): - """ String representation of the structure of the input field """ - return self.__str__() - - def modify(self, *args): - """ Modifies the parameter - - See *args to know how can it be used. - - This is a general schema of how an input field parameter looks internally, so that you - can know what do you want to change: - - (Note that it is very easy to modify nested values, more on this in *args explanation) - - { - "key": whatever, - "name": whatever, - "default": whatever, - . - . (keys that affect, let's say, the programmatic functionality of the parameter, - . they can be modified with Configurable.modify_param) - . 
- "inputField": { - "type": whatever, - "params": { they can be modified with Configurable.modifyInputField) - whatever - }, - "style": { - whatever - } - - } - } - - Arguments - -------- - *args: - Depending on what you pass the setting will be modified in different ways: - - Two arguments: - the first argument will be interpreted as the attribute that you want to change, - and the second one as the value that you want to set. - - Ex: obj.modify_param("length", "default", 3) - will set the default attribute of the parameter with key "length" to 3 - - Modifying nested keys is possible using dot notation. - - Ex: obj.modify_param("length", "inputField.params.min", 3) - will modify the min key inside inputField params on the schema above. - - The last key, but only the last one, will be created if it does not exist. - - Ex: obj.modify_param("length", "inputField.params.min", 3) - will only work if all the path before `min` exists and the value of `params` is a dictionary. - - Otherwise you could go like this: obj.modify_param("length", "inputField.params", {"min": 3}) - - - One argument and it is a dictionary: - the keys will be interpreted as attributes that you want to change and the values - as the value that you want them to have. - - Each key-value pair in the dictionary will be updated in exactly the same way as - it is in the previous case. - - - One argument and it is a function: - - the function will recieve the parameter and can act on it in any way you like. - It doesn't need to return the parameter, just modify it. - In this function, you can call predefined methods of the parameter, for example. - - Ex: obj.modify_param("length", lambda param: param.incrementByOne() ) - - given that you know that this type of parameter has this method. - - Returns - -------- - self: - The configurable object. 
- """ - if len(args) == 2: - - modFunction = lambda obj: modify_nested_dict(obj.__dict__, *args) - - elif isinstance(args[0], dict): - - def modFunction(obj): - for attr, val in args[0].items(): - modify_nested_dict(obj.__dict__, attr, val) - - elif callable(args[0]): - - modFunction = args[0] - - modFunction(self) - - return self - - def to_json(self): - """ Helps converting the input field to json so that it can be sent to the GUI - - Returns - --------- - dict - the dict ready to be jsonified. - """ - def default(obj): - - if isinstance(obj, np.integer): - return int(obj) - elif isinstance(obj, np.floating): - return float(obj) - elif isinstance(obj, np.ndarray): - return obj.tolist() - elif isinstance(obj, np.generic): - return obj.item() - else: - return getattr(obj, '__dict__', str(obj)) - - return json.loads( - json.dumps(self, default=default) - ) - - def parse(self, val): - """ Parses the user input to the actual values that will be used - - This method may be overwritten, but you probably need to still call - it with `super().parse(val)`, because it implements the basic functionality - for the `splot` commmand to understand the values that receives. - - Parameters - ----------- - val: any - the value to parse - - Returns - ----------- - self.dtype - the parsed value, which will be of the datatype specified by the input. 
- """ - if val is None: - return None - - dtypes = self.dtype - - if dtypes is None: - return val - - if not isinstance(dtypes, tuple): - dtypes = (dtypes, ) - - for dtype in dtypes: - try: - if dtype == bool and isinstance(val, str): - val = val.lower() not in ('false', 'f', 'no', 'n') - elif dtype in [list, int, float]: - val = dtype(val) - except Exception: - continue - - return val - - def _get_docstring(self): - """ Generates the docstring for this input field """ - import textwrap - - valid_vals = getattr(self, "valid_vals", None) - - if valid_vals is None: - - dtypes = getattr(self, "dtype", "") - if dtypes is None: - dtypes = "" - - if not isinstance(dtypes, tuple): - dtypes = (dtypes,) - - vals_help = " or ".join([getattr(dtype, "__name__", str(dtype)) for dtype in dtypes]) - - else: - vals_help = '{' + ', '.join(valid_vals) + '}' - - help_message = getattr(self, "help", "") - tw = textwrap.TextWrapper(width=70, initial_indent="\t", subsequent_indent="\t") - help_message = tw.fill(help_message) - - doc = f'{self.key}: {vals_help}{"," if vals_help else ""} optional\n{help_message}' - - return doc - - def _raise_type_error(self, val): - raise TypeError(f"{self.__class__.__name__} received input of type {type(val)}: {val}") diff --git a/src/sisl/viz/_plotables.py b/src/sisl/viz/_plotables.py index c49fc3a2dc..12107edf4a 100644 --- a/src/sisl/viz/_plotables.py +++ b/src/sisl/viz/_plotables.py @@ -4,28 +4,53 @@ """ This file provides tools to handle plotability of objects """ +import inspect +from collections import ChainMap +from typing import Sequence, Type + from sisl._dispatcher import AbstractDispatch, ClassDispatcher, ObjectDispatcher -__all__ = ["register_plotable"] +__all__ = ["register_plotable", "register_data_source", "register_sile_method"] class ClassPlotHandler(ClassDispatcher): """Handles all plotting possibilities for a class""" - def __init__(self, *args, **kwargs): + def __init__(self, cls, *args, inherited_handlers = (), **kwargs): + 
self._cls = cls if not "instance_dispatcher" in kwargs: kwargs["instance_dispatcher"] = ObjectPlotHandler kwargs["type_dispatcher"] = None - super().__init__(*args, **kwargs) + super().__init__(*args, inherited_handlers=inherited_handlers, **kwargs) + + self._dispatchs = ChainMap(self._dispatchs, *[handler._dispatchs for handler in inherited_handlers]) + + def set_default(self, key: str): + """Sets the default plotting function for the class.""" + if key not in self._dispatchs: + raise KeyError(f"Cannot set {key} as default since it is not registered.") + self._default = key + class ObjectPlotHandler(ObjectDispatcher): """Handles all plotting possibilities for an object.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if self._default is not None: + default_call = getattr(self, self._default) + + self.__doc__ = default_call.__doc__ + self.__signature__ = default_call.__signature__ + def __call__(self, *args, **kwargs): """If the plot handler is called, we will run the default plotting function unless the keyword method has been passed.""" - return getattr(self, kwargs.pop("method", self._default) or self._default)(*args, **kwargs) + if self._default is None: + raise TypeError(f"No default plotting function has been defined for {self._obj.__class__.__name__}.") + return getattr(self, self._default)(*args, **kwargs) class PlotDispatch(AbstractDispatch): @@ -50,17 +75,16 @@ def create_plot_dispatch(function, name): return type( f"Plot{name.capitalize()}Dispatch", (PlotDispatch, ), - {"_plot": staticmethod(function), "__doc__": function.__doc__} + {"_plot": staticmethod(function), "__doc__": function.__doc__, "__signature__": inspect.signature(function)} ) -def _get_plotting_func(PlotClass, setting_key): - """ - Generates a plotting function for an object. +def _get_plotting_func(plot_cls, setting_key): + """Generates a plotting function for an object. 
Parameters ----------- - PlotClass: child of Plot + plot_cls: subclass of Plot the plot class that you want to use to plot the object. setting_key: str the setting where the plotable should go @@ -75,21 +99,25 @@ def _get_plotting_func(PlotClass, setting_key): """ def _plot(obj, *args, **kwargs): - return PlotClass(*args, **{setting_key: obj, **kwargs}) + return plot_cls(*args, **{setting_key: obj, **kwargs}) - _plot.__doc__ = f"""Builds a {PlotClass.__name__} by setting the value of "{setting_key}" to the current object. - - Apart from this specific parameter ,it accepts the same arguments as {PlotClass.__name__}. - - Documentation for {PlotClass.__name__} - ------------- + _plot.__doc__ = f"""Builds a {plot_cls.__name__} by setting the value of "{setting_key}" to the current object. - {PlotClass.__doc__} + Documentation for {plot_cls.__name__} + =========== + {inspect.cleandoc(plot_cls.__doc__) if plot_cls.__doc__ is not None else None} """ + + sig = inspect.signature(plot_cls) + + # The signature will be the same as the plot class, but without the setting key, which + # will be added by the _plot function + _plot.__signature__ = sig.replace(parameters=[p for p in sig.parameters.values() if p.name != setting_key]) + return _plot -def register_plotable(plotable, PlotClass=None, setting_key=None, plotting_func=None, name=None, default=False, plot_handler_attr='plot', engine=None, **kwargs): +def register_plotable(plotable, plot_cls=None, setting_key=None, plotting_func=None, name=None, default=False, plot_handler_attr='plot', **kwargs): """ Makes the sisl.viz module aware of which sisl objects can be plotted and how to do it. @@ -101,11 +129,11 @@ def register_plotable(plotable, PlotClass=None, setting_key=None, plotting_func= plotable: any any class or object that you want to make plotable. Note that, if it's an object, the plotting capabilities will be attributed to all instances of the object's class. 
- PlotClass: child of sisl.Plot, optional + plot_cls: child of sisl.Plot, optional The class of the Plot that we want this object to use. setting_key: str, optional The key of the setting where the object must go. This works together with - the PlotClass parameter. + the plot_cls parameter. plotting_func: function the function that takes care of the plotting. It should accept (self, *args, **kwargs) and return a plot object. @@ -121,30 +149,226 @@ def register_plotable(plotable, PlotClass=None, setting_key=None, plotting_func= the attribute where the plot handler is or should be located in the class that you want to register. """ - # If no plotting function is provided, we will try to create one by using the PlotClass + # If no plotting function is provided, we will try to create one by using the plot_cls # and the setting_key that have been provided if plotting_func is None: - plotting_func = _get_plotting_func(PlotClass, setting_key) + plotting_func = _get_plotting_func(plot_cls, setting_key) - if name is None: + if name is None and plot_cls is not None: # We will take the name of the plot class as the name - name = PlotClass.suffix() + name = plot_cls.plot_class_key() # Check if we already have a plot_handler - plot_handler = plotable.__dict__.get(plot_handler_attr, None) + plot_handler = getattr(plotable, plot_handler_attr, None) # If it's the first time that the class is being registered, # let's give the class a plot handler - if not isinstance(plot_handler, ClassPlotHandler): + if not isinstance(plot_handler, ClassPlotHandler) or plot_handler._cls is not plotable: + + if isinstance(plot_handler, ClassPlotHandler): + inherited_handlers = [plot_handler] + else: + inherited_handlers = [] # If the user is passing an instance, we get the class if not isinstance(plotable, type): plotable = type(plotable) - setattr(plotable, plot_handler_attr, ClassPlotHandler(plot_handler_attr)) + setattr(plotable, plot_handler_attr, ClassPlotHandler(plotable, plot_handler_attr, 
inherited_handlers=inherited_handlers)) plot_handler = getattr(plotable, plot_handler_attr) plot_dispatch = create_plot_dispatch(plotting_func, name) # Register the function in the plot_handler plot_handler.register(name, plot_dispatch, default=default, **kwargs) + +def register_data_source( + data_source_cls, plot_cls, setting_key, name=None, default: Sequence[Type] = [], plot_handler_attr='plot', + data_source_init_kwargs: dict = {}, + **kwargs +): + + plot_cls_params = { + name: param.replace(kind=inspect.Parameter.KEYWORD_ONLY) + for name, param in inspect.signature(plot_cls).parameters.items() if name != setting_key + } + + for plotable, cls_method in data_source_cls.new.dispatcher.registry.items(): + + func = cls_method.__get__(None, data_source_cls) + + signature = inspect.signature(func) + + register_this = True + for k in data_source_init_kwargs.keys(): + if k not in signature.parameters: + register_this = False + break + + if not register_this: + continue + + new_parameters = [] + data_args = [] + replaced_data_args = {} + data_var_kwarg = None + for param in list(signature.parameters.values())[1:]: + if param.kind == param.VAR_KEYWORD: + data_var_kwarg = param.name + replaced_data_args[f'data_{param.name}'] = param.name + param = param.replace(name=f'data_{param.name}', kind=param.KEYWORD_ONLY, default={}) + elif param.name in plot_cls_params: + replaced_data_args[f'data_{param.name}'] = param.name + param = param.replace(name=f'data_{param.name}') + + data_args.append(param.name) + new_parameters.append(param) + + new_parameters.extend(list(plot_cls_params.values())) + + signature = signature.replace(parameters=new_parameters) + + params_info = { + "data_args": data_args, + "replaced_data_args": replaced_data_args, + "data_var_kwarg": data_var_kwarg, + "plot_var_kwarg": new_parameters[-1].name if new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD else None + } + + def _plot(obj, *args, __params_info=params_info, __signature=signature, **kwargs): + 
sig = __signature + params_info = __params_info + + bound = sig.bind_partial(**kwargs) + + try: + data_kwargs = {} + for k in params_info['data_args']: + if k not in bound.arguments: + continue + + data_key = params_info['replaced_data_args'].get(k, k) + if params_info['data_var_kwarg'] == data_key: + data_kwargs.update(bound.arguments[k]) + else: + data_kwargs[data_key] = bound.arguments.pop(k) + except Exception as e: + raise TypeError(f"Error while parsing arguments to create the {data_source_cls.__name__}") + + for k, v in data_source_init_kwargs.items(): + if k not in data_kwargs: + data_kwargs[k] = v + + data = data_source_cls.new(obj, *args, **data_kwargs) + + plot_kwargs = bound.arguments.pop(params_info['plot_var_kwarg'], {}) + + return plot_cls(**{setting_key: data, **bound.arguments, **plot_kwargs}) + + _plot.__signature__ = signature + doc = f"Read data into {data_source_cls.__name__} and create a {plot_cls.__name__} from it.\n\n" + + doc += "This function accepts the arguments for creating both the data source and the plot. 
The following"\ + " arguments of the data source have been renamed so that they don't clash with the plot arguments:\n" + \ + '\n'.join( f' - {v} -> {k}' for k, v in replaced_data_args.items()) + \ + f"\n\nDocumentation for the {data_source_cls.__name__} creator ({func.__name__})"\ + f"\n=============\n{inspect.cleandoc(func.__doc__) if func.__doc__ is not None else None}"\ + f"\n\nDocumentation for {plot_cls.__name__}:"\ + f"\n=============\n{inspect.cleandoc(plot_cls.__doc__) if plot_cls.__doc__ is not None else None}" + + _plot.__doc__ = doc + + try: + this_default = plotable in default + except: + this_default = False + + try: + register_plotable( + plotable, plot_cls=plot_cls, + plotting_func=_plot, name=name, default=this_default, plot_handler_attr=plot_handler_attr, + **kwargs + ) + except TypeError: + pass + +def register_sile_method(sile_cls, method: str, plot_cls, setting_key, name=None, default=False, plot_handler_attr='plot', **kwargs): + + plot_cls_params = { + name: param.replace(kind=inspect.Parameter.KEYWORD_ONLY) + for name, param in inspect.signature(plot_cls).parameters.items() if name != setting_key + } + + func = getattr(sile_cls, method) + + signature = inspect.signature(getattr(sile_cls, method)) + + new_parameters = [] + data_args = [] + replaced_data_args = {} + data_var_kwarg = None + for param in list(signature.parameters.values())[1:]: + if param.kind == param.VAR_KEYWORD: + data_var_kwarg = param.name + replaced_data_args[param.name] = f'data_{param.name}' + param = param.replace(name=f'data_{param.name}', kind=param.KEYWORD_ONLY, default={}) + elif param.name in plot_cls_params: + replaced_data_args[param.name] = f'data_{param.name}' + param = param.replace(name=f'data_{param.name}') + + data_args.append(param.name) + new_parameters.append(param) + + new_parameters.extend(list(plot_cls_params.values())) + + params_info = { + "data_args": data_args, + "replaced_data_args": replaced_data_args, + "data_var_kwarg": data_var_kwarg, + 
"plot_var_kwarg": new_parameters[-1].name if len(new_parameters) > 0 and new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD else None + } + + signature = signature.replace(parameters=new_parameters) + + def _plot(obj, *args, **kwargs): + + bound = signature.bind_partial(**kwargs) + + try: + data_kwargs = {} + for k in params_info['data_args']: + if k not in bound.arguments: + continue + + data_key = params_info['replaced_data_args'].get(k, k) + if params_info['data_var_kwarg'] == data_key: + data_kwargs.update(bound.arguments[k]) + else: + data_kwargs[data_key] = bound.arguments.pop(k) + except: + raise TypeError(f"Error while parsing arguments to create the call {method}") + + data = func(obj, *args, **data_kwargs) + + plot_kwargs = bound.arguments.pop(params_info['plot_var_kwarg'], {}) + + return plot_cls(**{setting_key: data, **bound.arguments, **plot_kwargs}) + + _plot.__signature__ = signature + doc = f"Calls {method} and creates a {plot_cls.__name__} from its output.\n\n" + + doc += f"This function accepts the arguments both for calling {method} and creating the plot. 
The following"\ + f" arguments of {method} have been renamed so that they don't clash with the plot arguments:\n" + \ + '\n'.join( f' - {k} -> {v}' for k, v in replaced_data_args.items()) + \ + f"\n\nDocumentation for {method} "\ + f"\n=============\n{inspect.cleandoc(func.__doc__) if func.__doc__ is not None else None}"\ + f"\n\nDocumentation for {plot_cls.__name__}:"\ + f"\n=============\n{inspect.cleandoc(plot_cls.__doc__) if plot_cls.__doc__ is not None else None}" + + _plot.__doc__ = doc + + register_plotable( + sile_cls, plot_cls=plot_cls, + plotting_func=_plot, name=name, default=default, plot_handler_attr=plot_handler_attr, + **kwargs + ) \ No newline at end of file diff --git a/src/sisl/viz/_plotables_register.py b/src/sisl/viz/_plotables_register.py index ee051359f2..d66e150971 100644 --- a/src/sisl/viz/_plotables_register.py +++ b/src/sisl/viz/_plotables_register.py @@ -8,66 +8,55 @@ """ import sisl import sisl.io.siesta as siesta -import sisl.io.tbtrans as tbtrans +# import sisl.io.tbtrans as tbtrans from sisl.io.sile import BaseSile, get_siles -from ._plotables import register_plotable -from .plot import Plot +from ._plotables import register_data_source, register_plotable, register_sile_method +from .data import * from .plots import * -from .plotutils import get_plot_classes + +# from .old_plot import Plot +# from .plotutils import get_plot_classes + __all__ = [] + # ----------------------------------------------------- -# Register plotable siles +# Register data sources # ----------------------------------------------------- -register = register_plotable - -for GridSile in get_siles(attrs=["read_grid"]): - register(GridSile, GridPlot, 'grid_file', default=True) - -for GeomSile in get_siles(attrs=["read_geometry"]): - register(GeomSile, GeometryPlot, 'geom_file', default=True) - register(GeomSile, BondLengthMap, 'geom_file') +# This will automatically register as plotable everything that +# the data source can digest -for HSile in 
get_siles(attrs=["read_hamiltonian"]): - register(HSile, WavefunctionPlot, 'H', default=HSile != siesta.fdfSileSiesta) - register(HSile, PdosPlot, "H") - register(HSile, BandsPlot, "H") - register(HSile, FatbandsPlot, "H") +register_data_source(PDOSData, PdosPlot, "pdos_data", default=[siesta.pdosSileSiesta]) +register_data_source(BandsData, BandsPlot, "bands_data", default=[siesta.bandsSileSiesta]) +register_data_source(BandsData, FatbandsPlot, "bands_data", data_source_init_kwargs={"extra_vars": ("norm2", )}) +register_data_source(EigenstateData, WavefunctionPlot, "eigenstate", default=[sisl.EigenstateElectron]) -for cls in get_plot_classes(): - register(siesta.fdfSileSiesta, cls, "root_fdf", overwrite=True) +# ----------------------------------------------------- +# Register plotable siles +# ----------------------------------------------------- -# register(siesta.outSileSiesta, ForcesPlot, 'out_file', default=True) +register = register_plotable -register(siesta.bandsSileSiesta, BandsPlot, 'bands_file', default=True) -register(siesta.bandsSileSiesta, FatbandsPlot, 'bands_file') +for GeomSile in get_siles(attrs=["read_geometry"]): + register_sile_method(GeomSile, "read_geometry", GeometryPlot, 'geometry') -register(siesta.pdosSileSiesta, PdosPlot, 'pdos_file', default=True) -register(tbtrans.tbtncSileTBtrans, PdosPlot, 'tbt_out', default=True) +for GridSile in get_siles(attrs=["read_grid"]): + register_sile_method(GridSile, "read_grid", GridPlot, 'grid', default=True) -# ----------------------------------------------------- -# Register plotable sisl objects -# ----------------------------------------------------- +# # ----------------------------------------------------- +# # Register plotable sisl objects +# # ----------------------------------------------------- -# Geometry +# # Geometry register(sisl.Geometry, GeometryPlot, 'geometry', default=True) -register(sisl.Geometry, BondLengthMap, 'geometry') -# Grid +# # Grid register(sisl.Grid, GridPlot, 'grid', 
default=True) -# Hamiltonian -register(sisl.Hamiltonian, WavefunctionPlot, 'H', default=True) -register(sisl.Hamiltonian, PdosPlot, "H") -register(sisl.Hamiltonian, BandsPlot, "H") -register(sisl.Hamiltonian, FatbandsPlot, "H") - -# Band structure -register(sisl.BandStructure, BandsPlot, "band_structure", default=True) -register(sisl.BandStructure, FatbandsPlot, "band_structure") +# Brillouin zone +register(sisl.BrillouinZone, SitesPlot, 'sites_obj') -# Eigenstate -register(sisl.EigenstateElectron, WavefunctionPlot, 'eigenstate', default=True) +sisl.BandStructure.plot.set_default("bands") diff --git a/src/sisl/viz/_presets.py b/src/sisl/viz/_presets.py index f05f3a377a..bbd04afda8 100644 --- a/src/sisl/viz/_presets.py +++ b/src/sisl/viz/_presets.py @@ -1,6 +1,8 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at https://mozilla.org/MPL/2.0/. +"""Presets are not used for now, but I would like to use them again at some point""" + __all__ = ["add_presets", "get_preset"] PRESETS = { diff --git a/src/sisl/viz/_shortcuts.py b/src/sisl/viz/_shortcuts.py deleted file mode 100644 index 0541138fe8..0000000000 --- a/src/sisl/viz/_shortcuts.py +++ /dev/null @@ -1,138 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from functools import partial - -import numpy as np - -__all__ = ["ShortCutable"] - - -class ShortCutable: - """ - Class that adds hot key functionality to those that inherit from it. - - Shortcuts help quickly executing common actions without needing to click - or write a line of code. - - They are supported both in the GUI and in the jupyter notebook.
- - """ - - def __init__(self, *args, **kwargs): - self._shortcuts = {} - - super().__init__(*args, **kwargs) - - def shortcut(self, keys): - """ - Gets the dict that represents a shortcut. - - Parameters - ----------- - keys: str - the sequence of keys that trigger the shortcut. - """ - return self._shortcuts.get(keys, None) - - def add_shortcut(self, _keys, _name, func, *args, _description=None, **kwargs): - """ - Makes a new shortcut available to the instance. - - You will see that argument names here are marked with "_", in an - attempt to avoid interfering with the action function's arguments. - - Parameters - ----------- - _keys: str - the sequence of keys that trigger the shortcut (e.g.: ctrl+e). - _name: str - a short name for the shortcut that should give a first idea of what the shortcut does. - func: function - the function to execute when the shortcut is called. - *args: - positional arguments that go to the function's execution. - _description: str, optional - a longer description of what the shortcut does, maybe including some tips and gotcha's. - **kwargs: - keyword arguments that go to the function's execution. - """ - self._shortcuts[_keys] = { - "name": _name, - "description": _description, - "action": partial(func, *args, **kwargs) - } - - def remove_shortcut(self, keys): - """ - Unregisters a given shortcut. - - Parameters - ------------ - keys: str - the sequence of keys that trigger the shortcut. - """ - if keys in self._shortcuts: - del self._shortcuts[keys] - - def call_shortcut(self, keys, *args, **kwargs): - """ - Programatic way to call a shortcut. - - In fact, this is the method that is executed when a keypress is detected - in the GUI or the jupyter notebook. - - Parameters - ----------- - keys: str - the sequence of keys that trigger the shortcut. - *args and **kwargs: - extra arguments that you pass to the function call. 
- """ - self._shortcuts[keys]["action"](*args, **kwargs) - - return self - - def has_shortcut(self, keys): - """ - Checks if a shortcut is already registered. - - Parameters - ----------- - keys: str - the sequence of keys that trigger the shortcut. - """ - return keys in self._shortcuts - - @property - def shortcuts_for_json(self): - """ - Returns a jsonifiable object with information of the shortcuts - - This is meant to be passed to the GUI, so that it knows which shortcuts are available. - """ - #Basically we are going to remove the action - return {key: {key: val for key, val in info.items() if key != 'action'} for key, info in self._shortcuts.items()} - - def shortcuts_summary(self, format="str"): - """ - Gets a formatted summary of the shortcuts. - """ - if format == "str": - return "\n".join([f'{key}: {shortcut["name"]}' for key, shortcut in self._shortcuts.items()]) - elif format == "html": - summ = "Available keyboard shortcuts:
" - - def get_shortcut_div(key, shortcut): - - key_span = "".join([f'{key}' for key in key.split()]) - - name_span = f'{shortcut["name"]}' - - description_div = f'
{shortcut["description"] or ""}
' - - return f'
{key_span}{name_span}{description_div}
' - - summ += "".join([get_shortcut_div(key, shortcut) for key, shortcut in self._shortcuts.items()]) - - return f'
{summ}
' diff --git a/src/sisl/viz/_single_dispatch.py b/src/sisl/viz/_single_dispatch.py new file mode 100644 index 0000000000..d74ee76528 --- /dev/null +++ b/src/sisl/viz/_single_dispatch.py @@ -0,0 +1,15 @@ +# This is a single dispatch method that works with class methods that have annotations. +from functools import singledispatchmethod as real_singledispatchmethod + + +class singledispatchmethod(real_singledispatchmethod): + def register(self, cls, method=None): + if hasattr(cls, '__func__'): + setattr(cls, '__annotations__', cls.__func__.__annotations__) + return self.dispatcher.register(cls, func=method) + + def __get__(self, obj, cls=None): + _method = super().__get__(obj, cls) + _method.dispatcher = self.dispatcher + return _method + \ No newline at end of file diff --git a/src/sisl/viz/splot.py b/src/sisl/viz/_splot.py similarity index 98% rename from src/sisl/viz/splot.py rename to src/sisl/viz/_splot.py index 9a4f00a430..7bf5ec744b 100644 --- a/src/sisl/viz/splot.py +++ b/src/sisl/viz/_splot.py @@ -3,6 +3,9 @@ # file, You can obtain one at https://mozilla.org/MPL/2.0/. """ Easy plotting from the command line. 
+ +NOT FUNCTIONAL WITH THE NEW REFACTORING OF SISL.VIZ, MUST UPDATE +(AND PROBABLY MERGE INTO A ROOT sisl CLI) """ import argparse @@ -13,7 +16,7 @@ from sisl.utils import cmd from ._user_customs import PLOTS_FILE, PRESETS_FILE, PRESETS_VARIABLE -from .plot import Plot +from .old_plot import Plot from .plotutils import find_plotable_siles, get_avail_presets, get_plot_classes __all__ = ["splot"] diff --git a/src/sisl/viz/_xarray_accessor.py b/src/sisl/viz/_xarray_accessor.py new file mode 100644 index 0000000000..f543a0e922 --- /dev/null +++ b/src/sisl/viz/_xarray_accessor.py @@ -0,0 +1,63 @@ +"""This module creates the sisl accessor in xarray to facilitate operations +on scientifically meaningful indices.""" + +import functools +import inspect + +import xarray as xr + +from .figure import get_figure +from .plotters.xarray import draw_xarray_xy +from .processors.atom import reduce_atom_data +from .processors.orbital import reduce_orbital_data, split_orbitals +from .processors.xarray import group_reduce + + +def wrap_accessor_method(fn): + @functools.wraps(fn) + def _method(self, *args, **kwargs): + return fn(self._obj, *args, **kwargs) + + return _method + +def plot_xy(*args, backend: str ="plotly", **kwargs): + + plot_actions = draw_xarray_xy(*args, **kwargs) + + return get_figure(plot_actions=plot_actions, backend=backend) + +sig = inspect.signature(draw_xarray_xy) +plot_xy.__signature__ = sig.replace(parameters=[ + *sig.parameters.values(), + inspect.Parameter("backend", inspect.Parameter.KEYWORD_ONLY, default="plotly") +]) + +@xr.register_dataarray_accessor("sisl") +class SislAccessorDataArray: + def __init__(self, xarray_obj): + self._obj = xarray_obj + + group_reduce = wrap_accessor_method(group_reduce) + + reduce_orbitals = wrap_accessor_method(reduce_orbital_data) + + split_orbitals = wrap_accessor_method(split_orbitals) + + reduce_atoms = wrap_accessor_method(reduce_atom_data) + + plot_xy = wrap_accessor_method(plot_xy) + 
+@xr.register_dataset_accessor("sisl") +class SislAccessorDataset: + def __init__(self, xarray_obj): + self._obj = xarray_obj + + group_reduce = wrap_accessor_method(group_reduce) + + reduce_orbitals = wrap_accessor_method(reduce_orbital_data) + + split_orbitals = wrap_accessor_method(split_orbitals) + + reduce_atoms = wrap_accessor_method(reduce_atom_data) + + plot_xy = wrap_accessor_method(plot_xy) \ No newline at end of file diff --git a/src/sisl/viz/backends/__init__.py b/src/sisl/viz/backends/__init__.py deleted file mode 100644 index 3ffa19f8c2..0000000000 --- a/src/sisl/viz/backends/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import importlib - -__all__ = ["load_backend", "load_backends"] - - -def load_backend(backend): - """ Load backend from this module level - - Parameters - ---------- - backend : str - name of backend to load - - Raises - ------ - ModuleNotFoundError - """ - importlib.import_module(f".{backend}", __name__) - - -def load_backends(): - """ Loads all available backends from this module level - - Will *not* raise any errors. - """ - for backend in ("templates", "plotly", "matplotlib", "blender"): - try: - load_backend(backend) - except ModuleNotFoundError: - pass diff --git a/src/sisl/viz/backends/_plot_backends.py b/src/sisl/viz/backends/_plot_backends.py deleted file mode 100644 index d17c8a6779..0000000000 --- a/src/sisl/viz/backends/_plot_backends.py +++ /dev/null @@ -1,109 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-__all__ = [] - - -class Backends: - """The backends manager for a plot class""" - - def __init__(self, plot_cls): - self._backends = {} - self._template = None - self._children = [] - - self._cls = plot_cls - - self._cls._backend = None - - def register(self, backend_name, backend, default=False): - """Register a new backend to the available backends. - - Note that if there is a template registered, you can only register backends that - inherit from that template, otherwise a `TypeError` will be raised. - - Parameters - ----------- - backend_name: str - The name of the backend being registered. Users will need to pass this value - in order to choose this backend. - backend: Backend - The backend class to be registered - default: bool, optional - Whether this backend should be the default one. - """ - if self._template is not None: - if not issubclass(backend, self._template): - raise TypeError(f"Error registering '{backend_name}': Backends for {self._cls.__name__} should inherit from {self._template.__name__}") - - # Update the options of the backend setting - backend_param = self._cls.get_class_param("backend") - backend_param.options = [*backend_param.get_options(raw=True), {"label": backend_name, "value": backend_name}] - if backend_param.default is None or default: - backend_param.default = backend_name - - backend._backend_name = backend_name - self._backends[backend_name] = backend - - for child in self._children: - child.register(backend_name, backend, default=default) - - def setup(self, plot, backend_name): - """Sets up the backend for a given plot. - - Note that if the current backend of the plot is already `backend_name`, then nothing is done. - Also, if the requested `backend_name` is not available, a `NotImplementedError` is raised. - Parameters - ----------- - plot: Plot - The plot for which we want to set up a backend. - backend_name: str - The name of the backend we want to initialize. 
- """ - current_backend = getattr(plot, "_backend", None) - if current_backend is None or current_backend._backend_name != backend_name: - if backend_name not in self._backends: - raise NotImplementedError(f"There is no '{backend_name}' backend implemented for {self._cls.__name__} or the backend has not been loaded.") - plot._backend = self._backends[backend_name](plot) - - def register_template(self, template): - """Sets a template that all registered backends have to satisfy. - - That is, any backend that you want to register here needs to inherit from this template. - - Parameters - ----------- - template: Backend - The backend class that should be used as a template. - """ - self._template = template - for child in self._children: - child.register_template(template) - - def register_child(self, child): - """Registers a backend manager to follow this one. - - This is useful if an extension of a plot class can use exactly the same - backends. - - Parameters - ----------- - child: Backends - The backends manager that you want to make follow this one. - - Examples - ----------- - `WavefunctionPlot` is an extension of `GridPlot`, but it can use the same - backends. - - >>> GridPlot.backends.register_child(WavefunctionPlot.backends) - - will make the backends registered at `GridPlot` automatically available for `WavefunctionPlot`. - Note that the opposite is not True, so you can register wavefunction specific backends without - problem. - """ - self._children.append(child) - - @property - def options(self): - return list(self._backends) diff --git a/src/sisl/viz/backends/blender/__init__.py b/src/sisl/viz/backends/blender/__init__.py deleted file mode 100644 index bfd4bc3246..0000000000 --- a/src/sisl/viz/backends/blender/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-r"""Blender -========== - -Blender is an open source general use 3D software. - -In science, we can profit from its excellent features to generate very nice images! - -Currently, the following plots have a blender drawing backend implemented: - GridPlot -""" - -import bpy - -from ._helpers import delete_all_objects -from ._plots import * -from .backend import BlenderBackend, BlenderMultiplePlotBackend diff --git a/src/sisl/viz/backends/blender/_helpers.py b/src/sisl/viz/backends/blender/_helpers.py deleted file mode 100644 index c795eb98dc..0000000000 --- a/src/sisl/viz/backends/blender/_helpers.py +++ /dev/null @@ -1,10 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import bpy - - -def delete_all_objects(): - """Deletes all objects present in the scene""" - bpy.ops.object.select_all(action='SELECT') - bpy.ops.object.delete(use_global=False, confirm=False) diff --git a/src/sisl/viz/backends/blender/_plots/__init__.py b/src/sisl/viz/backends/blender/_plots/__init__.py deleted file mode 100644 index 83551fc702..0000000000 --- a/src/sisl/viz/backends/blender/_plots/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from .geometry import BlenderGeometryBackend -from .grid import BlenderGridBackend diff --git a/src/sisl/viz/backends/blender/_plots/geometry.py b/src/sisl/viz/backends/blender/_plots/geometry.py deleted file mode 100644 index 8eb73070b1..0000000000 --- a/src/sisl/viz/backends/blender/_plots/geometry.py +++ /dev/null @@ -1,99 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. 
If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import bpy - -from ....plots import GeometryPlot -from ...templates import GeometryBackend -from ..backend import BlenderBackend - -__all__ = ["BlenderGeometryBackend"] - - -def add_atoms_frame(ani_objects, child_objects, frame): - """Creates the frames for a child plot atoms. - - Given the objects of the Atoms collection in the animation, it uses - the corresponding atoms in the child to set keyframes. - - Parameters - ----------- - ani_objects: CollectionObjects - the objects of the Atoms collection in the animation. - child_objects: CollectionObjects - the objects of the Atoms collection in the child plot. - frame: int - the frame number to which the keyframe values should be set. - """ - # Loop through all objects in the collections - for ani_obj, child_obj in zip(ani_objects, child_objects): - # Set the atom position - ani_obj.location = child_obj.location - ani_obj.keyframe_insert(data_path="location", frame=frame) - - # Set the atom size - ani_obj.scale = child_obj.scale - ani_obj.keyframe_insert(data_path="scale", frame=frame) - - # Set the atom color and opacity - ani_mat_inputs = ani_obj.data.materials[0].node_tree.nodes["Principled BSDF"].inputs - child_mat_inputs = child_obj.data.materials[0].node_tree.nodes["Principled BSDF"].inputs - - for input_key in ("Base Color", "Alpha"): - ani_mat_inputs[input_key].default_value = child_mat_inputs[input_key].default_value - ani_mat_inputs[input_key].keyframe_insert(data_path="default_value", frame=frame) - - -class BlenderGeometryBackend(BlenderBackend, GeometryBackend): - - _animatable_collections = { - **BlenderBackend._animatable_collections, - "Atoms": {"add_frame": add_atoms_frame}, - "Unit cell": BlenderBackend._animatable_collections["Lines"] - } - - def draw_1D(self, backend_info, **kwargs): - raise NotImplementedError("A way of drawing 1D geometry representations is not implemented for 
blender") - - def draw_2D(self, backend_info, **kwargs): - raise NotImplementedError("A way of drawing 2D geometry representations is not implemented for blender") - - def _draw_single_atom_3D(self, xyz, size, color="gray", name=None, opacity=1, vertices=15, **kwargs): - - try: - atom = self._template_atom.copy() - atom.data = self._template_atom.data.copy() - except Exception: - bpy.ops.surface.primitive_nurbs_surface_sphere_add(radius=1, enter_editmode=False, align='WORLD') - self._template_atom = bpy.context.object - atom = self._template_atom - bpy.context.collection.objects.unlink(atom) - - atom.location = xyz - atom.scale = (size, size, size) - - # Link the atom to the atoms collection - atoms_col = self.get_collection("Atoms") - atoms_col.objects.link(atom) - - atom.name = name - atom.data.name = name - - self._color_obj(atom, color, opacity=opacity) - - def _draw_bonds_3D(self, *args, line=None, **kwargs): - # Multiply the width of the bonds to 0.2, otherwise they look gigantic. - line = line or {} - line["width"] = 0.2 * line.get("width", 1) - # And call the method to draw bonds (which will use self.draw_line3D) - collection = self.get_collection("Bonds") - super()._draw_bonds_3D(*args, line=line, collection=collection, **kwargs) - - def _draw_cell_3D_box(self, *args, width=None, **kwargs): - width = width or 0.1 - # This method is only defined to provide a better default for the width in blender - # otherwise it looks gigantic, as the bonds - collection = self.get_collection("Unit cell") - super()._draw_cell_3D_box(*args, width=width, collection=collection, **kwargs) - -GeometryPlot.backends.register("blender", BlenderGeometryBackend) diff --git a/src/sisl/viz/backends/blender/_plots/grid.py b/src/sisl/viz/backends/blender/_plots/grid.py deleted file mode 100644 index 1c6f05940c..0000000000 --- a/src/sisl/viz/backends/blender/_plots/grid.py +++ /dev/null @@ -1,43 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 
2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import bpy - -from ....plots.grid import GridPlot -from ...templates import GridBackend -from ..backend import BlenderBackend - - -class BlenderGridBackend(BlenderBackend, GridBackend): - - def draw_3D(self, backend_info, **kwargs): - - col = self.get_collection("Grid") - - for isosurf in backend_info["isosurfaces"]: - - mesh = bpy.data.meshes.new(isosurf["name"]) - - obj = bpy.data.objects.new(mesh.name, mesh) - - col.objects.link(obj) - #bpy.context.view_layer.objects.active = obj - - edges = [] - mesh.from_pydata(isosurf["vertices"], edges, isosurf["faces"].tolist()) - - self._color_obj(obj, isosurf["color"], isosurf['opacity']) - - # mat = bpy.data.materials.new("material") - # mat.use_nodes = True - - # color = self._to_rgb_color(isosurf["color"]) - - # if color is not None: - # mat.node_tree.nodes["Principled BSDF"].inputs[0].default_value = (*color, 1) - - # mat.node_tree.nodes["Principled BSDF"].inputs[19].default_value = isosurf["opacity"] - - # mesh.materials.append(mat) - -GridPlot.backends.register("blender", BlenderGridBackend) diff --git a/src/sisl/viz/backends/blender/backend.py b/src/sisl/viz/backends/blender/backend.py deleted file mode 100644 index c4cd1cc75a..0000000000 --- a/src/sisl/viz/backends/blender/backend.py +++ /dev/null @@ -1,249 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import bpy -import numpy as np - -from ...plot import Animation, MultiplePlot -from ..templates.backend import AnimationBackend, Backend, MultiplePlotBackend - - -def add_line_frame(ani_objects, child_objects, frame): - """Creates the frames for a child plot lines. 
- - Given the objects of the lines collection in the animation, it uses - the corresponding lines in the child to set keyframes. - - Parameters - ----------- - ani_objects: CollectionObjects - the objects of the Atoms collection in the animation. - child_objects: CollectionObjects - the objects of the Atoms collection in the child plot. - frame: int - the frame number to which the keyframe values should be set. - """ - # Loop through all objects in the collections - for ani_obj, child_obj in zip(ani_objects, child_objects): - # Each curve object has multiple splines - for ani_spline, child_spline in zip(ani_obj.data.splines, child_obj.data.splines): - # And each spline has multiple points - for ani_point, child_point in zip(ani_spline.bezier_points, child_spline.bezier_points): - # Set the position of that point - ani_point.co = child_point.co - ani_point.keyframe_insert(data_path="co", frame=frame) - - # Loop through all the materials that the object might have associated - for ani_material, child_material in zip(ani_obj.data.materials, child_obj.data.materials): - ani_mat_inputs = ani_material.node_tree.nodes["Principled BSDF"].inputs - child_mat_inputs = child_material.node_tree.nodes["Principled BSDF"].inputs - - for input_key in ("Base Color", "Alpha"): - ani_mat_inputs[input_key].default_value = child_mat_inputs[input_key].default_value - ani_mat_inputs[input_key].keyframe_insert(data_path="default_value", frame=frame) - - -class BlenderBackend(Backend): - """Generic backend for the blender framework. - - This is the first experiment with it, so it is quite simple. - - Everything is drawn in the same scene. On initialization, a collections - dictionary is started. The keys should be the local name of a collection - in the backend environment and the values are the actual collections. - Plots should try to organize the items they draw in collections. However, - as said before, this is just a proof of concept. 
- """ - - _animatable_collections = { - "Lines": {"add_frame": add_line_frame}, - } - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # This is the collection that will store everything related to the plot. - self._collection = bpy.data.collections.new(f"sislplot_{self._plot.id}") - self._collections = {} - - def draw_on(self, figure): - self._plot.get_figure(backend=self._backend_name, clear_fig=False) - - def clear(self): - """ Clears the blender scene so that data can be reset""" - - for key, collection in self._collections.items(): - - for obj in collection.objects: - bpy.data.objects.remove(obj, do_unlink=True) - - bpy.data.collections.remove(collection) - - self._collections = {} - - def get_collection(self, key): - if key not in self._collections: - self._collections[key] = bpy.data.collections.new(key) - self._collection.children.link(self._collections[key]) - - return self._collections[key] - - def draw_line3D(self, x, y, z, line={}, name="", collection=None, **kwargs): - """Draws a line using a bezier curve.""" - if collection is None: - collection = self.get_collection("Lines") - # First, generate the curve object - bpy.ops.curve.primitive_bezier_curve_add() - # Then get it from the context - curve_obj = bpy.context.object - # And give it a name - if name is None: - name = "" - curve_obj.name = name - - # Link the curve to our collection (remove it from the context one) - context_col = bpy.context.collection - if context_col is not collection: - context_col.objects.unlink(curve_obj) - collection.objects.link(curve_obj) - - # Retrieve the curve from the object - curve = curve_obj.data - # And modify some attributes to make it look cylindric - curve.dimensions = '3D' - curve.fill_mode = 'FULL' - width = line.get("width") - curve.bevel_depth = width if width is not None else 0.1 - curve.bevel_resolution = 10 - # Clear all existing splines from the curve, as we are going to add them - curve.splines.clear() - - xyz = 
np.array([x, y, z], dtype=float).T - - # To be compatible with other frameworks such as plotly and matplotlib, - # we allow x, y and z to contain None values that indicate discontinuities - # E.g.: x=[0, 1, None, 2, 3] means we should draw a line from 0 to 1 and another - # from 2 to 3. - # Here, we get the breakpoints (i.e. indices where there is a None). We add - # -1 and None at the sides o facilitate iterating. - breakpoint_indices = [-1, *np.where(np.isnan(xyz).any(axis=1))[0], None] - - # Now loop through all segments using the known breakpoints - for start_i, end_i in zip(breakpoint_indices, breakpoint_indices[1:]): - # Get the coordinates of the segment - segment_xyz = xyz[start_i+1: end_i] - - # If there is nothing to draw, go to next segment - if len(segment_xyz) == 0: - continue - - # Create a new spline (within the curve, we are not creating a new object!) - segment = curve.splines.new("BEZIER") - # Splines by default have only 1 point, add as many as we need - segment.bezier_points.add(len(segment_xyz) - 1) - # Assign the coordinates to each point - segment.bezier_points.foreach_set('co', np.ravel(segment_xyz)) - - # We want linear interpolation between points. If we wanted cubic interpolation, - # we would set this parameter to 3, for example. - segment.resolution_u = 1 - - # Give a color to our new curve object if it needs to be colored. - self._color_obj(curve_obj, line.get("color", None), line.get("opacity", 1)) - - return self - - @staticmethod - def _to_rgb_color(color): - - if isinstance(color, str): - try: - import matplotlib.colors - - color = matplotlib.colors.to_rgb(color) - except ModuleNotFoundError: - raise ValueError("Blender does not understand string colors."+ - "Please provide the color in rgb (tuple of length 3, values from 0 to 1) or install matplotlib so that we can convert it." - ) - - return color - - @classmethod - def _color_obj(cls, obj, color, opacity=1): - """Utiity method to quickly color a given object. 
- - Parameters - ----------- - obj: blender Object - object to be colored - color: str or array-like of shape (3,) - color, it is converted to rgb using `matplotlib.colors.to_rgb` - opacity: - the opacity that should be given to the object. It doesn't - work currently. - """ - color = cls._to_rgb_color(color) - - if color is not None: - mat = bpy.data.materials.new("material") - mat.use_nodes = True - - BSDF_inputs = mat.node_tree.nodes["Principled BSDF"].inputs - - BSDF_inputs["Base Color"].default_value = (*color, 1) - BSDF_inputs["Alpha"].default_value = opacity - - obj.active_material = mat - - def show(self, *args, **kwargs): - bpy.context.scene.collection.children.link(self._collection) - - -class BlenderMultiplePlotBackend(MultiplePlotBackend, BlenderBackend): - - def draw(self, backend_info): - children = backend_info["children"] - # Start assigning each plot to a position of the layout - for child in children: - self._draw_child_in_scene(child) - - def _draw_child_in_ax(self, child): - child.get_figure(clear_fig=False) - - -class BlenderAnimationBackend(BlenderBackend, AnimationBackend): - - def draw(self, backend_info): - - # Get the collections that make sense to implement. This property is defined - # in each backend. See for example BlenderGeometryBackend - animatable_collections = backend_info["children"][0]._animatable_collections - # Get the number of frames that should be interpolated between two animation frames. - interpolated_frames = backend_info["interpolated_frames"] - - # Iterate over all collections - for key, animate_config in animatable_collections.items(): - - # Get the collection in the animation's instance - collection = self.get_collection(key) - # Copy all the objects from first child's collection - for obj in backend_info["children"][0].get_collection(key).objects: - new_obj = obj.copy() - new_obj.data = obj.data.copy() - # Some objects don't have materials associated. 
- try: - new_obj.data.materials[0] = obj.data.materials[0].copy() - except Exception: - pass - collection.objects.link(new_obj) - - # Loop over all child plots - for i_plot, plot in enumerate(backend_info["children"]): - # Calculate the frame number - frame = i_plot * interpolated_frames - # Ask the provided function to build the keyframes. - animate_config["add_frame"](collection.objects, plot.get_collection(key).objects, frame=frame) - - -Animation.backends.register("blender", BlenderAnimationBackend) -MultiplePlot.backends.register("blender", BlenderMultiplePlotBackend) diff --git a/src/sisl/viz/backends/matplotlib/__init__.py b/src/sisl/viz/backends/matplotlib/__init__.py deleted file mode 100644 index 6566d18bf8..0000000000 --- a/src/sisl/viz/backends/matplotlib/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -r"""Matplotlib -========== - -Implementations of the sisl-provided matplotlib backends. - -""" -import matplotlib - -from ._plots import * -from .backend import ( - MatplotlibBackend, - MatplotlibMultiplePlotBackend, - MatplotlibSubPlotsBackend, -) diff --git a/src/sisl/viz/backends/matplotlib/_plots/__init__.py b/src/sisl/viz/backends/matplotlib/_plots/__init__.py deleted file mode 100644 index c5e4646eac..0000000000 --- a/src/sisl/viz/backends/matplotlib/_plots/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-from .bands import MatplotlibBandsBackend -from .bond_length import MatplotlibBondLengthMapBackend -from .fatbands import MatplotlibFatbandsBackend -from .geometry import MatplotlibGeometryBackend -from .grid import MatplotlibGridBackend -from .pdos import MatplotlibPDOSBackend diff --git a/src/sisl/viz/backends/matplotlib/_plots/bands.py b/src/sisl/viz/backends/matplotlib/_plots/bands.py deleted file mode 100644 index 5c8099d879..0000000000 --- a/src/sisl/viz/backends/matplotlib/_plots/bands.py +++ /dev/null @@ -1,74 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import numpy as np -from matplotlib.collections import LineCollection -from matplotlib.pyplot import Normalize - -from ....plots import BandsPlot -from ...templates import BandsBackend -from ..backend import MatplotlibBackend - - -class MatplotlibBandsBackend(MatplotlibBackend, BandsBackend): - - _axes_defaults = { - 'xlabel': 'K', - 'ylabel': 'Energy [eV]' - } - - def _init_ax(self): - super()._init_ax() - self.axes.grid(axis="x") - - def draw_bands(self, filtered_bands, spin_texture, **kwargs): - - if spin_texture["show"]: - # Create the normalization for the colorscale of spin_moments. - self._spin_texture_norm = Normalize(spin_texture["values"].min(), spin_texture["values"].max()) - self._spin_texture_colorscale = spin_texture["colorscale"] - - super().draw_bands(filtered_bands=filtered_bands, spin_texture=spin_texture, **kwargs) - - if spin_texture["show"]: - # Add the colorbar for spin texture. 
- self.figure.colorbar(self._colorbar) - - # Add the ticks - tick_vals = getattr(filtered_bands, "ticks", None) - if tick_vals is not None: - self.axes.set_xticks(tick_vals) - tick_labels = getattr(filtered_bands, "ticklabels", None) - if tick_labels is not None: - self.axes.set_xticklabels(tick_labels) - # Set the limits - self.axes.set_xlim(*filtered_bands.k.values[[0, -1]]) - self.axes.set_ylim(filtered_bands.min(), filtered_bands.max()) - - def _draw_spin_textured_band(self, x, y, spin_texture_vals=None, **kwargs): - # This is heavily based on - # https://matplotlib.org/stable/gallery/lines_bars_and_markers/multicolored_line.html - - points = np.array([x, y]).T.reshape(-1, 1, 2) - segments = np.concatenate([points[:-1], points[1:]], axis=1) - - lc = LineCollection(segments, cmap=self._spin_texture_colorscale, norm=self._spin_texture_norm) - - # Set the values used for colormapping - lc.set_array(spin_texture_vals) - lc.set_linewidth(kwargs["line"].get("width", 1)) - self._colorbar = self.axes.add_collection(lc) - - def draw_gap(self, ks, Es, color, name, **kwargs): - - name = f"{name} ({Es[1] - Es[0]:.2f} eV)" - gap = self.axes.plot( - ks, Es, color=color, marker=".", label=name - ) - - self.axes.legend(gap, [name]) - - def _test_is_gap_drawn(self): - return self.axes.lines[-1].get_label().startswith("Gap") - -BandsPlot.backends.register("matplotlib", MatplotlibBandsBackend) diff --git a/src/sisl/viz/backends/matplotlib/_plots/bond_length.py b/src/sisl/viz/backends/matplotlib/_plots/bond_length.py deleted file mode 100644 index 393769c5d0..0000000000 --- a/src/sisl/viz/backends/matplotlib/_plots/bond_length.py +++ /dev/null @@ -1,22 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-from ....plots import BondLengthMap -from ...templates import BondLengthMapBackend -from .geometry import MatplotlibGeometryBackend - - -class MatplotlibBondLengthMapBackend(BondLengthMapBackend, MatplotlibGeometryBackend): - - def draw_2D(self, backend_info, **kwargs): - self._colorscale = None - if "bonds_coloraxis" in backend_info: - self._colorscale = backend_info["bonds_coloraxis"]["colorscale"] - - super().draw_2D(backend_info, **kwargs) - - def _draw_bonds_2D_multi_color_size(self, *args, **kwargs): - kwargs["colorscale"] = self._colorscale - super()._draw_bonds_2D_multi_color_size(*args, **kwargs) - -BondLengthMap.backends.register("matplotlib", MatplotlibBondLengthMapBackend) diff --git a/src/sisl/viz/backends/matplotlib/_plots/fatbands.py b/src/sisl/viz/backends/matplotlib/_plots/fatbands.py deleted file mode 100644 index c2759e8815..0000000000 --- a/src/sisl/viz/backends/matplotlib/_plots/fatbands.py +++ /dev/null @@ -1,18 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from ....plots import FatbandsPlot -from ...templates import FatbandsBackend -from .bands import MatplotlibBandsBackend - - -class MatplotlibFatbandsBackend(MatplotlibBandsBackend, FatbandsBackend): - - def _draw_band_weights(self, x, y, weights, name, color, is_group_first): - - self.axes.fill_between( - x, y + weights, y - weights, - color=color, label=name - ) - -FatbandsPlot.backends.register("matplotlib", MatplotlibFatbandsBackend) diff --git a/src/sisl/viz/backends/matplotlib/_plots/geometry.py b/src/sisl/viz/backends/matplotlib/_plots/geometry.py deleted file mode 100644 index 182796dbf7..0000000000 --- a/src/sisl/viz/backends/matplotlib/_plots/geometry.py +++ /dev/null @@ -1,35 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. 
If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from collections.abc import Iterable - -import numpy as np - -from ....plots import GeometryPlot -from ...templates import GeometryBackend -from ..backend import MatplotlibBackend - - -class MatplotlibGeometryBackend(MatplotlibBackend, GeometryBackend): - - def draw_1D(self, backend_info, **kwargs): - super().draw_1D(backend_info, **kwargs) - - self.axes.set_xlabel(backend_info["axes_titles"]["xaxis"]) - self.axes.set_ylabel(backend_info["axes_titles"]["yaxis"]) - - def draw_2D(self, backend_info, **kwargs): - super().draw_2D(backend_info, **kwargs) - - self.axes.set_xlabel(backend_info["axes_titles"]["xaxis"]) - self.axes.set_ylabel(backend_info["axes_titles"]["yaxis"]) - self.axes.axis("equal") - - def _draw_atoms_2D_scatter(self, *args, **kwargs): - kwargs["zorder"] = 2.1 - super()._draw_atoms_2D_scatter(*args, **kwargs) - - def draw_3D(self, backend_info): - return NotImplementedError(f"3D geometry plots are not implemented by {self.__class__.__name__}") - -GeometryPlot.backends.register("matplotlib", MatplotlibGeometryBackend) diff --git a/src/sisl/viz/backends/matplotlib/_plots/grid.py b/src/sisl/viz/backends/matplotlib/_plots/grid.py deleted file mode 100644 index bd0ca6a2c9..0000000000 --- a/src/sisl/viz/backends/matplotlib/_plots/grid.py +++ /dev/null @@ -1,62 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-import matplotlib.pyplot as plt - -from ....plots.grid import GridPlot -from ...templates import GridBackend -from ..backend import MatplotlibBackend - - -class MatplotlibGridBackend(MatplotlibBackend, GridBackend): - - def draw_1D(self, backend_info, **kwargs): - super().draw_1D(backend_info, **kwargs) - - self.axes.set_xlabel(backend_info["axes_titles"]["xaxis"]) - self.axes.set_ylabel(backend_info["axes_titles"]["yaxis"]) - - def draw_2D(self, backend_info, **kwargs): - - # Define the axes values - x = backend_info["x"] - y = backend_info["y"] - - extent = [x[0], x[-1], y[0], y[-1]] - - # Draw the values of the grid - self.axes.imshow( - backend_info["values"], vmin=backend_info["cmin"], vmax=backend_info["cmax"], - label=backend_info["name"], cmap=backend_info["colorscale"], extent=extent, - origin="lower" - ) - - # Draw the isocontours - for contour in backend_info["contours"]: - self.axes.plot( - contour["x"], contour["y"], - color=contour["color"], - alpha=contour["opacity"], - label=contour["name"] - ) - - self.axes.set_xlabel(backend_info["axes_titles"]["xaxis"]) - self.axes.set_ylabel(backend_info["axes_titles"]["yaxis"]) - - def draw_3D(self, backend_info, **kwargs): - # This will basically raise the NotImplementedError - super().draw_3D(backend_info, **kwargs) - - # The following code is just here as reference of how this MIGHT - # be done in matplotlib. 
- self.figure = plt.figure() - self.axes = self.figure.add_subplot(projection="3d") - - for isosurf in backend_info["isosurfaces"]: - - x, y, z = isosurf["vertices"].T - I, J, K = isosurf["faces"].T - - self.axes.plot_trisurf(x, y, z, linewidth=0, antialiased=True) - -GridPlot.backends.register("matplotlib", MatplotlibGridBackend) diff --git a/src/sisl/viz/backends/matplotlib/_plots/pdos.py b/src/sisl/viz/backends/matplotlib/_plots/pdos.py deleted file mode 100644 index 7cebdbebb8..0000000000 --- a/src/sisl/viz/backends/matplotlib/_plots/pdos.py +++ /dev/null @@ -1,23 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from ....plots import PdosPlot -from ...templates import PdosBackend -from ..backend import MatplotlibBackend - - -class MatplotlibPDOSBackend(MatplotlibBackend, PdosBackend): - - _axes_defaults = { - 'xlabel': 'Density of states [1/eV]', - 'ylabel': 'Energy [eV]' - } - - def draw_PDOS_lines(self, backend_info): - super().draw_PDOS_lines(backend_info) - - Es = backend_info["Es"] - self.axes.set_ylim(min(Es), max(Es)) - - -PdosPlot.backends.register("matplotlib", MatplotlibPDOSBackend) diff --git a/src/sisl/viz/backends/matplotlib/backend.py b/src/sisl/viz/backends/matplotlib/backend.py deleted file mode 100644 index 7bcd037083..0000000000 --- a/src/sisl/viz/backends/matplotlib/backend.py +++ /dev/null @@ -1,149 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-import itertools - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.axes import Axes - -from sisl.messages import warn - -from ...plot import MultiplePlot, Plot, SubPlots -from ..templates.backend import Backend, MultiplePlotBackend, SubPlotsBackend - - -class MatplotlibBackend(Backend): - """Generic backend for the matplotlib framework. - - On initialization, `matplotlib.pyplot.subplots` is called and the figure and and - axes obtained are stored under `self.figure` and `self.axeses`, respectively. - If an attribute is not found on the backend, it is looked for - in the axes. - - On initialization, we also take the class attribute `_axes_defaults` (a dictionary) - and run `self.axes.update` with those parameters. Therefore this parameter can be used - to provide default parameters for the axes. - """ - - _axes_defaults = {} - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.figure, self.axes = self._init_figure() - self._init_axes() - - def draw_on(self, axes, axes_indices=None): - """Draws this plot in a different figure. - - Parameters - ----------- - axes: Plot, PlotlyBackend or matplotlib.axes.Axes - The axes to draw this plot in. - """ - if isinstance(axes, Plot): - axes = axes._backend.axes - elif isinstance(axes, MatplotlibBackend): - axes = axes.axes - - if axes_indices is not None: - axes = axes[axes_indices] - - if not isinstance(axes, Axes): - raise TypeError(f"{self.__class__.__name__} was provided a {axes.__class__.__name__} to draw on.") - - self_axes = self.axes - self.axes = axes - self._init_axes() - self._plot.get_figure(backend=self._backend_name, clear_fig=False) - self.axes = self_axes - - def _init_figure(self): - """Initializes the matplotlib figure and axes - - Returns - -------- - Figure: - the matplotlib figure of this plot. - Axes: - the matplotlib axes of this plot. 
- """ - return plt.subplots() - - def _init_axes(self): - """Does some initial modification on the axes.""" - self.axes.update(self._axes_defaults) - - def __getattr__(self, key): - if key != "axes": - return getattr(self.axes, key) - raise AttributeError(key) - - def clear(self, layout=False): - """ Clears the plot canvas so that data can be reset - - Parameters - -------- - layout: boolean, optional - whether layout should also be deleted - """ - if layout: - self.axes.clear() - - for artist in self.axes.lines + self.axes.collections: - artist.remove() - - return self - - def get_ipywidget(self): - return self.figure - - def show(self, *args, **kwargs): - return self.figure.show(*args, **kwargs) - - # Methods for testing - def _test_number_of_items_drawn(self): - return len(self.axes.lines + self.axes.collections) - - def draw_line(self, x, y, name=None, line={}, marker={}, text=None, **kwargs): - return self.axes.plot(x, y, color=line.get("color"), linewidth=line.get("width", 1), markersize=marker.get("size"), label=name) - - def draw_scatter(self, x, y, name=None, marker={}, text=None, **kwargs): - try: - return self.axes.scatter(x, y, c=marker.get("color"), s=marker.get("size", 1), cmap=marker.get("colorscale"), alpha=marker.get("opacity"), label=name, **kwargs) - except TypeError as e: - if str(e) == "alpha must be a float or None": - warn(f"Your matplotlib version doesn't support multiple opacity values, please upgrade to >=3.4 if you want to use opacity.") - return self.axes.scatter(x, y, c=marker.get("color"), s=marker.get("size", 1), cmap=marker.get("colorscale"), label=name, **kwargs) - else: - raise e - - -class MatplotlibMultiplePlotBackend(MatplotlibBackend, MultiplePlotBackend): - pass - - -class MatplotlibSubPlotsBackend(MatplotlibMultiplePlotBackend, SubPlotsBackend): - - def draw(self, backend_info): - children = backend_info["children"] - rows, cols = backend_info["rows"], backend_info["cols"] - - self.figure, self.axes = plt.subplots(rows, cols) 
- - # Normalize the axes array to have two dimensions - if rows == 1 and cols == 1: - self.axes = np.array([[self.axes]]) - elif rows == 1: - self.axes = np.expand_dims(self.axes, axis=0) - elif cols == 1: - self.axes = np.expand_dims(self.axes, axis=1) - - indices = itertools.product(range(rows), range(cols)) - # Start assigning each plot to a position of the layout - for (row, col), child in zip(indices, children): - self.draw_other_plot(child, axes_indices=(row, col)) - -MultiplePlot.backends.register("matplotlib", MatplotlibMultiplePlotBackend) -SubPlots.backends.register("matplotlib", MatplotlibSubPlotsBackend) diff --git a/src/sisl/viz/backends/plotly/__init__.py b/src/sisl/viz/backends/plotly/__init__.py deleted file mode 100644 index 715faa5fe8..0000000000 --- a/src/sisl/viz/backends/plotly/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -r"""Plotly -========== - -Plotly is a backend that provides expert plotting utilities using `plotly`. -It features a rich set of settings enabling fine-tuning of many parameters. - - GeometryPlot - BandsPlot - FatbandsPlot - PdosPlot - BondLengthMap - ForcesPlot - GridPlot - WavefunctionPlot - -""" -import plotly - -from ._plots import * -from ._templates import * -from .backend import ( - PlotlyAnimationBackend, - PlotlyBackend, - PlotlyMultiplePlotBackend, - PlotlySubPlotsBackend, -) diff --git a/src/sisl/viz/backends/plotly/_express.py b/src/sisl/viz/backends/plotly/_express.py deleted file mode 100644 index 4276d19f8b..0000000000 --- a/src/sisl/viz/backends/plotly/_express.py +++ /dev/null @@ -1,60 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-""" This file implements a smooth interface between sisl and plotly express, to make visualization of sisl objects even easier - -This goes hand by hand with the implementation of dataframe extraction in sisl -objects, which is not already implemented (https://github.com/zerothi/sisl/issues/220) -""" -from functools import wraps - -import plotly.express as px - -from sisl._dispatcher import AbstractDispatch - -__all__ = ["sx"] - - -class WithSislManagement(AbstractDispatch): - def __init__(self, px): - self._obj = px - - def dispatch(self, method): - """ Wraps the methods of the object to preprocess the inputs """ - @wraps(method) - def with_sisl_support(*args, **kwargs): - - if args: - # Try to generate the dataframe for this object. - if hasattr(args[0], 'to_df'): - args[0] = args[0].to_df() - else: - # Otherwise, we are just going to interpret it as if the user wants to get the attributes - # of the object. We will support deep attribute getting here using points as separators. - # (I don't know if this makes sense because there's probably hardly any attributes that are - # ready to be plotted, i.e. they are 1d arrays) - obj = args.pop(0) - for key, val in kwargs.items(): - if isinstance(val, str): - attrs = val.split('.') - search_obj = obj - - # We try to recursively get the attributes - for attr in attrs: - newval = getattr(obj, attr, None) - if newval is None: - break - search_obj = newval - - else: - # If we've gotten to the end of the loop, it is because we've found the attribute. 
- val = newval - - # Replace the provided string by the actual value of the attribute - kwargs[key] = val - - return method(*args, **kwargs) - - return with_sisl_support - -sx = WithSislManagement(px) diff --git a/src/sisl/viz/backends/plotly/_plots/__init__.py b/src/sisl/viz/backends/plotly/_plots/__init__.py deleted file mode 100644 index 7f88d2e315..0000000000 --- a/src/sisl/viz/backends/plotly/_plots/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from .bands import PlotlyBandsBackend -from .bond_length import PlotlyBondLengthMapBackend -from .fatbands import PlotlyFatbandsBackend -from .geometry import PlotlyGeometryBackend -from .grid import PlotlyGridBackend -from .pdos import PlotlyPDOSBackend diff --git a/src/sisl/viz/backends/plotly/_plots/bands.py b/src/sisl/viz/backends/plotly/_plots/bands.py deleted file mode 100644 index 01aaeadbef..0000000000 --- a/src/sisl/viz/backends/plotly/_plots/bands.py +++ /dev/null @@ -1,72 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-import numpy as np - -from ....plots import BandsPlot -from ...templates import BandsBackend -from ..backend import PlotlyBackend - - -class PlotlyBandsBackend(PlotlyBackend, BandsBackend): - - _layout_defaults = { - 'xaxis_title': 'K', - 'xaxis_mirror': True, - 'yaxis_mirror': True, - 'xaxis_showgrid': True, - 'yaxis_title': 'Energy [eV]' - } - - def draw_bands(self, filtered_bands, spin_texture, **kwargs): - super().draw_bands(filtered_bands=filtered_bands, spin_texture=spin_texture, **kwargs) - - # Add the ticks - tickvals = getattr(filtered_bands, "ticks", None) - # We need to convert tick values to a list, otherwise sometimes plotly fails to display them - self.figure.layout.xaxis.tickvals = list(tickvals) if tickvals is not None else None - self.figure.layout.xaxis.ticktext = getattr(filtered_bands, "ticklabels", None) - self.figure.layout.yaxis.range = [filtered_bands.min(), filtered_bands.max()] - self.figure.layout.xaxis.range = filtered_bands.k.values[[0, -1]] - - # If we are showing spin textured bands, customize the colorbar - if spin_texture["show"]: - self.layout.coloraxis.colorbar = {"title": f"Spin texture ({str(spin_texture['values'].axis.item())})"} - self.update_layout(coloraxis = {"cmin": spin_texture["values"].min().item(), "cmax": spin_texture["values"].max().item(), "colorscale": spin_texture["colorscale"]}) - - def _draw_band(self, x, y, *args, **kwargs): - kwargs = { - "hovertemplate": '%{y:.2f} eV', - "hoverinfo": "name", - **kwargs - } - return super()._draw_band(x, y, *args, **kwargs) - - def _draw_spin_textured_band(self, *args, spin_texture_vals=None, **kwargs): - kwargs.update({ - "mode": "markers", - "marker": {"color": spin_texture_vals, "size": kwargs["line"]["width"], "showscale": True, "coloraxis": "coloraxis"}, - "hovertemplate": '%{y:.2f} eV (spin moment: %{marker.color:.2f})', - "showlegend": False - }) - return self._draw_band(*args, **kwargs) - - def draw_gap(self, ks, Es, color, name, **kwargs): - - self.add_trace({ - 
'type': 'scatter', - 'mode': 'lines+markers+text', - 'x': ks, - 'y': Es, - 'text': [f'Gap: {Es[1] - Es[0]:.3f} eV', ''], - 'marker': {'color': color}, - 'line': {'color': color}, - 'name': name, - 'textposition': 'top right', - **kwargs - }) - - def _test_is_gap_drawn(self): - return len([True for trace in self.figure.data if trace.name == "Gap"]) > 0 - -BandsPlot.backends.register("plotly", PlotlyBandsBackend) diff --git a/src/sisl/viz/backends/plotly/_plots/bond_length.py b/src/sisl/viz/backends/plotly/_plots/bond_length.py deleted file mode 100644 index 0f4b1ca5cb..0000000000 --- a/src/sisl/viz/backends/plotly/_plots/bond_length.py +++ /dev/null @@ -1,25 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from ....plots import BondLengthMap -from ...templates import BondLengthMapBackend -from .geometry import PlotlyGeometryBackend - - -class PlotlyBondLengthMapBackend(BondLengthMapBackend, PlotlyGeometryBackend): - - def draw_2D(self, backend_info, **kwargs): - super().draw_2D(backend_info, **kwargs) - self._setup_coloraxis(backend_info) - - def draw_3D(self, backend_info, **kwargs): - super().draw_3D(backend_info, **kwargs) - self._setup_coloraxis(backend_info) - - def _setup_coloraxis(self, backend_info): - if "bonds_coloraxis" in backend_info: - self.update_layout(coloraxis=backend_info["bonds_coloraxis"]) - - self.update_layout(legend_orientation='h') - -BondLengthMap.backends.register("plotly", PlotlyBondLengthMapBackend) diff --git a/src/sisl/viz/backends/plotly/_plots/fatbands.py b/src/sisl/viz/backends/plotly/_plots/fatbands.py deleted file mode 100644 index 462219cd49..0000000000 --- a/src/sisl/viz/backends/plotly/_plots/fatbands.py +++ /dev/null @@ -1,39 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. 
If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import numpy as np - -from ....plots import FatbandsPlot -from ...templates import FatbandsBackend -from .bands import PlotlyBandsBackend - - -class PlotlyFatbandsBackend(PlotlyBandsBackend, FatbandsBackend): - - def draw(self, backend_info): - super().draw(backend_info) - - if backend_info["draw_bands"]["spin_texture"]["show"]: - self.update_layout(legend_orientation="h") - - def _draw_band_weights(self, x, y, weights, name, color, is_group_first): - - for i_chunk, chunk in enumerate(self._yield_band_chunks(x, y, weights)): - - # Removing the parts of the band where y is nan handles bands that - # flow outside the plot. - chunk_x, chunk_y, chunk_weights = chunk[:, ~np.isnan(chunk[1])] - - self.add_trace({ - "type": "scatter", - "mode": "lines", - "x": [*chunk_x, *reversed(chunk_x)], - "y": [*(chunk_y + chunk_weights), *reversed(chunk_y - chunk_weights)], - "line": {"width": 0, "color": color}, - "showlegend": is_group_first and i_chunk == 0, - "name": name, - "legendgroup": name, - "fill": "toself" - }) - -FatbandsPlot.backends.register("plotly", PlotlyFatbandsBackend) diff --git a/src/sisl/viz/backends/plotly/_plots/geometry.py b/src/sisl/viz/backends/plotly/_plots/geometry.py deleted file mode 100644 index 25e6d19ad8..0000000000 --- a/src/sisl/viz/backends/plotly/_plots/geometry.py +++ /dev/null @@ -1,80 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-import numpy as np - -from ....plots import GeometryPlot -from ...templates import GeometryBackend -from ..backend import PlotlyBackend - - -class PlotlyGeometryBackend(PlotlyBackend, GeometryBackend): - - _layout_defaults = { - 'xaxis_showgrid': False, - 'xaxis_zeroline': False, - 'yaxis_showgrid': False, - 'yaxis_zeroline': False, - } - - def draw_1D(self, backend_info, **kwargs): - super().draw_1D(backend_info, **kwargs) - - self.update_layout(**{f"{k}_title": v for k, v in backend_info["axes_titles"].items()}) - - def draw_2D(self, backend_info, **kwargs): - super().draw_2D(backend_info, **kwargs) - - self.update_layout(**{f"{k}_title": v for k, v in backend_info["axes_titles"].items()}) - - self.layout.yaxis.scaleanchor = "x" - self.layout.yaxis.scaleratio = 1 - - def draw_3D(self, backend_info): - self._one_atom_trace = False - - super().draw_3D(backend_info) - - self.layout.scene.aspectmode = 'data' - - def _draw_bonds_3D(self, *args, line={}, bonds_labels=None, x_labels=None, y_labels=None, z_labels=None, **kwargs): - if "hoverinfo" not in kwargs: - kwargs["hoverinfo"] = None - super()._draw_bonds_3D(*args, line=line, **kwargs) - - if bonds_labels: - self.add_trace({ - 'type': 'scatter3d', 'mode': 'markers', - 'x': x_labels, 'y': y_labels, 'z': z_labels, - 'text': bonds_labels, 'hoverinfo': 'text', - 'marker': {'size': line["width"]*3, "color": "rgba(255,255,255,0)"}, - "showlegend": False - }) - - def _draw_single_atom_3D(self, xyz, size, color="gray", name=None, group="Atoms", vertices=15, **kwargs): - - self.add_trace({ - 'type': 'mesh3d', - **{key: np.ravel(val) for key, val in GeometryPlot._sphere(xyz, size, vertices=vertices).items()}, - 'showlegend': not self._one_atom_trace, - 'alphahull': 0, - 'color': color, - 'showscale': False, - 'legendgroup': group, - 'name': name, - 'meta': ['({:.2f}, {:.2f}, {:.2f})'.format(*xyz)], - 'hovertemplate': '%{meta[0]}', - **kwargs - }) - - self._one_atom_trace = True - - def _draw_single_bond_3D(self, *args, 
group=None, showlegend=False, line_kwargs={}, **kwargs): - kwargs["legendgroup"] = group - kwargs["showlegend"] = showlegend - super()._draw_single_bond_3D(*args, **kwargs) - - def _draw_cell_3D_axes(self, cell, geometry, **kwargs): - return super()._draw_cell_3D_axes(cell, geometry, mode="lines+markers", **kwargs) - -GeometryPlot.backends.register("plotly", PlotlyGeometryBackend) diff --git a/src/sisl/viz/backends/plotly/_plots/grid.py b/src/sisl/viz/backends/plotly/_plots/grid.py deleted file mode 100644 index d6af14f03f..0000000000 --- a/src/sisl/viz/backends/plotly/_plots/grid.py +++ /dev/null @@ -1,75 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import plotly.graph_objects as go - -from ....plots.grid import GridPlot -from ...templates import GridBackend -from ..backend import PlotlyBackend - - -class PlotlyGridBackend(PlotlyBackend, GridBackend): - - def draw_1D(self, backend_info, **kwargs): - self.figure.layout.yaxis.scaleanchor = None - self.figure.layout.yaxis.scaleratio = None - - super().draw_1D(backend_info, **kwargs) - - self.update_layout(**{f"{k}_title": v for k, v in backend_info["axes_titles"].items()}) - - def draw_2D(self, backend_info, **kwargs): - - # Draw the heatmap - self.add_trace({ - 'type': 'heatmap', - 'name': backend_info["name"], - 'z': backend_info["values"], - 'x': backend_info["x"], - 'y': backend_info["y"], - 'zsmooth': backend_info["zsmooth"], - 'zmin': backend_info["cmin"], - 'zmax': backend_info["cmax"], - 'zmid': backend_info["cmid"], - 'colorscale': backend_info["colorscale"], - **kwargs - }) - - # Draw the isocontours - for contour in backend_info["contours"]: - self.add_scatter( - x=contour["x"], y=contour["y"], - marker_color=contour["color"], line_color=contour["color"], - opacity=contour["opacity"], - name=contour["name"] - ) - - 
self.update_layout(**{f"{k}_title": v for k, v in backend_info["axes_titles"].items()}) - - self.figure.layout.yaxis.scaleanchor = "x" - self.figure.layout.yaxis.scaleratio = 1 - - def draw_3D(self, backend_info, **kwargs): - - for isosurf in backend_info["isosurfaces"]: - - x, y, z = isosurf["vertices"].T - I, J, K = isosurf["faces"].T - - self.add_trace(go.Mesh3d( - x=x, y=y, z=z, - i=I, j=J, k=K, - color=isosurf["color"], - opacity=isosurf["opacity"], - name=isosurf["name"], - showlegend=True, - **kwargs - )) - - self.layout.scene = {'aspectmode': 'data'} - self.update_layout(**{f"scene_{k}_title": v for k, v in backend_info["axes_titles"].items()}) - - def _after_get_figure(self): - self.update_layout(legend_orientation='h') - -GridPlot.backends.register("plotly", PlotlyGridBackend) diff --git a/src/sisl/viz/backends/plotly/_plots/pdos.py b/src/sisl/viz/backends/plotly/_plots/pdos.py deleted file mode 100644 index 44a3a88e2a..0000000000 --- a/src/sisl/viz/backends/plotly/_plots/pdos.py +++ /dev/null @@ -1,26 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-from ....plots import PdosPlot -from ...templates import PdosBackend -from ..backend import PlotlyBackend - - -class PlotlyPDOSBackend(PlotlyBackend, PdosBackend): - - _layout_defaults = { - 'xaxis_title': 'Density of states [1/eV]', - 'xaxis_mirror': True, - 'yaxis_mirror': True, - 'yaxis_title': 'Energy [eV]', - 'showlegend': True - } - - def draw_PDOS_lines(self, drawer_info): - super().draw_PDOS_lines(drawer_info) - - Es = drawer_info["Es"] - self.update_layout(yaxis_range=[min(Es), max(Es)]) - - -PdosPlot.backends.register("plotly", PlotlyPDOSBackend) diff --git a/src/sisl/viz/backends/plotly/_templates.py b/src/sisl/viz/backends/plotly/_templates.py deleted file mode 100644 index a5059ee770..0000000000 --- a/src/sisl/viz/backends/plotly/_templates.py +++ /dev/null @@ -1,145 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -""" -Plotly templates should be defined in this file -""" -import itertools - -import plotly.graph_objs as go -import plotly.io as pio - -__all__ = ["get_plotly_template", "add_plotly_template", - "set_default_plotly_template", "available_plotly_templates"] - - -def get_plotly_template(name): - """ - Gets a plotly template from plotly global space. - - Doing `get_plotly_template(name)` is equivalent to - `plotly.io.templates[name]`. - - Parameters - ---------- - name: str - the name of the plotly template - """ - return pio.templates[name] - - -def add_plotly_template(name, template, default=False): - """ - Adds a plotly template to plotly's register. - - In this way the visualization module can profit from it. - - Parameters - ----------- - name: str - the name of the plotly template that you want to add - template: dict or plotly.graph_objs.layout.Template - the template that you want to add. - See https://plotly.com/python/templates/ to understand how they work. 
- default: bool, optional - whether this template should be set as the default during this runtime. - - If you want a permanent default, consider using the opportunity that 'user_customs' - gives you to customize the sisl visualization package by acting every time the - package is imported. - """ - pio.templates[name] = template - - if default: - set_default_plotly_template(name) - - return - - -def set_default_plotly_template(name): - """ - Sets a template as the default during this runtime. - - If you want a permanent default, consider using the opportunity that 'user_customs' - gives you to customize the sisl visualization package by acting every time the - package is imported. - - Parameters - ----------- - name: str - the name of the template that you want to use as default - """ - pio.templates.default = name - - -def available_plotly_templates(): - """ - Gets a list of the plotly templates that are currently available. - - Returns - --------- - list - the list with all the template's names. 
- """ - list(pio.templates.keys()) - -pio.templates["sisl"] = go.layout.Template( - layout={ - "plot_bgcolor": "white", - "paper_bgcolor": "white", - **{f"{ax}_{key}": val for ax, (key, val) in itertools.product( - ("xaxis", "yaxis"), - (("visible", True), ("showline", True), ("linewidth", 1), ("mirror", True), - ("color", "black"), ("showgrid", False), ("gridcolor", "#ccc"), ("gridwidth", 1), - ("zeroline", False), ("zerolinecolor", "#ccc"), ("zerolinewidth", 1), - ("ticks", "outside"), ("ticklen", 5), ("ticksuffix", " ")) - )}, - "hovermode": "closest", - "scene": { - **{f"{ax}_{key}": val for ax, (key, val) in itertools.product( - ("xaxis", "yaxis", "zaxis"), - (("visible", True), ("showline", True), ("linewidth", 1), ("mirror", True), - ("color", "black"), ("showgrid", - False), ("gridcolor", "#ccc"), ("gridwidth", 1), - ("zeroline", False), ("zerolinecolor", - "#ccc"), ("zerolinewidth", 1), - ("ticks", "outside"), ("ticklen", 5), ("ticksuffix", " ")) - )}, - } - #"editrevision": True - #"title": {"xref": "paper", "x": 0.5, "text": "Whhhhhhhat up", "pad": {"b": 0}} - }, -) - -pio.templates["sisl_dark"] = go.layout.Template( - layout={ - "plot_bgcolor": "black", - "paper_bgcolor": "black", - **{f"{ax}_{key}": val for ax, (key, val) in itertools.product( - ("xaxis", "yaxis"), - (("visible", True), ("showline", True), ("linewidth", 1), ("mirror", True), - ("color", "white"), ("showgrid", - False), ("gridcolor", "#ccc"), ("gridwidth", 1), - ("zeroline", False), ("zerolinecolor", "#ccc"), ("zerolinewidth", 1), - ("ticks", "outside"), ("ticklen", 5), ("ticksuffix", " ")) - )}, - "font": {'color': 'white'}, - "hovermode": "closest", - "scene": { - **{f"{ax}_{key}": val for ax, (key, val) in itertools.product( - ("xaxis", "yaxis", "zaxis"), - (("visible", True), ("showline", True), ("linewidth", 1), ("mirror", True), - ("color", "white"), ("showgrid", - False), ("gridcolor", "#ccc"), ("gridwidth", 1), - ("zeroline", False), ("zerolinecolor", - "#ccc"), 
("zerolinewidth", 1), - ("ticks", "outside"), ("ticklen", 5), ("ticksuffix", " ")) - )}, - } - #"editrevision": True - #"title": {"xref": "paper", "x": 0.5, "text": "Whhhhhhhat up", "pad": {"b": 0}} - }, -) - -# This will be the default one for the sisl.viz.plotly module -pio.templates.default = "sisl" diff --git a/src/sisl/viz/backends/plotly/_user_customs.py b/src/sisl/viz/backends/plotly/_user_customs.py deleted file mode 100644 index 38384096b2..0000000000 --- a/src/sisl/viz/backends/plotly/_user_customs.py +++ /dev/null @@ -1,120 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import importlib -import os -import sys -from pathlib import Path - -from sisl._environ import get_environ_variable -from sisl.messages import warn - -__all__ = ["import_user_presets", "import_user_plots", - "import_user_sessions", "import_user_plugins"] - -USER_CUSTOM_FOLDER = get_environ_variable("SISL_CONFIGDIR") / "viz" / "plotly" - -# Here we let python know that there are importable files -# in USER_CUSTOM_FOLDER -sys.path.append(str(USER_CUSTOM_FOLDER.resolve())) - - -def import_user_extension(extension_file): - """ - Basis for importing users extensions. - - Parameters - ------------ - extension_file: str - the name of the file that you want to import (NOT THE FULL PATH). - """ - try: - return importlib.import_module(str(extension_file).replace(".py", "")) - except ModuleNotFoundError: - return None - -#-------------------------------------- -# Presets -#-------------------------------------- -# File where the user's presets will be searched -PRESETS_FILE_NAME = "presets.py" -PRESETS_FILE = USER_CUSTOM_FOLDER / PRESETS_FILE_NAME -# We will look for presets under this variable -PRESETS_VARIABLE = "presets" - - -def import_user_presets(): - """ - Imports the users presets. 
- - All the presets that the user wants to import into sisl - should be in the 'presets' variable as a dict in the 'user_presets.py' - file. Then, this method will add them to the global dictionary of presets. - """ - from ._presets import add_presets - - module = import_user_extension(PRESETS_FILE_NAME) - - # Add these presets - if module is not None: - if PRESETS_VARIABLE in vars(module): - add_presets(**vars(module)[PRESETS_VARIABLE]) - else: - warn(f"We found the custom presets file ({PRESETS_FILE}) but no '{PRESETS_VARIABLE}' variable was found.\n Please put your presets as a dict under this variable.") - - return module - -#-------------------------------------- -# Plots -#-------------------------------------- -# File where the user's plots will be searched -PLOTS_FILE_NAME = "plots.py" -PLOTS_FILE = USER_CUSTOM_FOLDER / PLOTS_FILE_NAME - - -def import_user_plots(): - """ - Imports the user's plots. - - We don't need to do anything here because all plots available - are tracked by checking the subclasses of `Plot`. - Therefore, the user only needs to make sure that their plot classes - are defined. - """ - return import_user_extension(PLOTS_FILE_NAME) - -#-------------------------------------- -# Sessions -#-------------------------------------- -# File where the user's sessions will be searched -SESSION_FILE_NAME = "sessions.py" -SESSION_FILE = USER_CUSTOM_FOLDER / SESSION_FILE_NAME - - -def import_user_sessions(): - """ - Imports the user's sessions. - - We don't need to do anything here because all sessions available - are tracked by checking the subclasses of `Session`. - Therefore, the user only needs to make sure that their session classes - are defined. 
- """ - return import_user_extension(SESSION_FILE_NAME) - - -#---------------------------------------- -# Plugins -#--------------------------------------- -# This is a general file that the user can have for convenience so that everytime -# that sisl is imported, it can automatically import all their utilities that they -# developed to work with sisl -PLUGINS_FILE_NAME = "plugins.py" - - -def import_user_plugins(): - """ - This imports an extra file where the user can do really anything - that they want to finish customizing the package. - """ - return import_user_extension(PLUGINS_FILE_NAME) diff --git a/src/sisl/viz/backends/plotly/backend.py b/src/sisl/viz/backends/plotly/backend.py deleted file mode 100644 index 4f34ed2cb2..0000000000 --- a/src/sisl/viz/backends/plotly/backend.py +++ /dev/null @@ -1,656 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import itertools -from collections import defaultdict -from functools import partial - -import numpy as np -import plotly.graph_objects as go -from plotly.subplots import make_subplots - -from ...plot import Animation, MultiplePlot, Plot, SubPlots -from ..templates.backend import ( - AnimationBackend, - Backend, - MultiplePlotBackend, - SubPlotsBackend, -) - - -class PlotlyBackend(Backend): - """Generic backend for the plotly framework. - - On initialization, a plotly.graph_objs.Figure object is created and stored - under `self.figure`. If an attribute is not found on the backend, it is looked for - in the figure. Therefore, you can apply all the methods that are appliable to a plotly - figure! - - On initialization, we also take the class attribute `_layout_defaults` (a dictionary) - and run `update_layout` with those parameters. 
- """ - - _layout_defaults = {} - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.figure = go.Figure() - self.update_layout(**self._layout_defaults) - - def __getattr__(self, key): - if key != "figure": - return getattr(self.figure, key) - raise AttributeError(key) - - def show(self, *args, **kwargs): - return self.figure.show(*args, **kwargs) - - def draw_on(self, figure): - """Draws this plot in a different figure. - - Parameters - ----------- - figure: Plot, PlotlyBackend or plotly.graph_objs.Figure - The figure to draw this plot in. - """ - if isinstance(figure, Plot): - figure = figure._backend.figure - elif isinstance(figure, PlotlyBackend): - figure = figure.figure - - if not isinstance(figure, go.Figure): - raise TypeError(f"{self.__class__.__name__} was provided a {figure.__class__.__name__} to draw on.") - - self_fig = self.figure - self.figure = figure - self._plot.get_figure(backend=self._backend_name, clear_fig=False) - self.figure = self_fig - - def clear(self, frames=True, layout=False): - """ Clears the plot canvas so that data can be reset - - Parameters - -------- - frames: boolean, optional - whether frames should also be deleted - layout: boolean, optional - whether layout should also be deleted - """ - self.figure.data = [] - - if frames: - self.figure.frames = [] - - if layout: - self.figure.layout = {} - - return self - - def get_ipywidget(self): - return go.FigureWidget(self.figure, ) - - def _update_ipywidget(self, fig_widget): - """ Updates a figure widget so that it is in sync with this plot's data - - Parameters - ---------- - fig_widget: plotly.graph_objs.FigureWidget - The figure widget that we need to extend. 
- """ - fig_widget.data = [] - fig_widget.add_traces(self.data) - fig_widget.layout = self.layout - fig_widget.update(frames=self.frames) - - #------------------------------------------- - # PLOT MANIPULATION METHODS - #------------------------------------------- - - def group_legend(self, by=None, names=None, show_all=False, extra_updates=None, **kwargs): - """ Joins plot traces in groups in the legend - - As the result of this method, plot traces end up with a legendgroup attribute. - You can use that for selecting traces in further processing of your plot. - - This also provides the ability to toggle the whole group from the legend, which is nice. - - Parameters - --------- - by: str or function, optional - it defines what are the criteria to group the traces. - - If it's a string: - It is the name of the trace attribute. Remember that plotly allows you to - lookup for nested attributes using underscores. E.g: "line_color" gets {line: {color: THIS VALUE}} - If it's a function: - It will recieve each trace and needs to decide which group to put it in by returning the group value. - Note that the value will also be used as the group name if `names` is not provided, so you can save yourself - some code and directly return the group's name. - If not provided: - All traces will be put in the same group - names: array-like, dict or function, optional - it defines what the names of the generated groups will be. - - If it's an array: - When a new group is found, the name will be taken from this array (order can be very arbitrary) - If it's a dict: - When a new group is found, the value of the group will be used as a key to get the name from this dictionary. - If the key is not found, the name will just be the value. - E.g.: If grouping by `line_color` and `blue` is found, the name will be `names.get('blue', 'blue')` - If it's a function: - It will recieve the group value and the trace and needs to return the name of the TRACE. 
- NOTE: If `show_all` is set to `True` all traces will appear in the legend, so it would be nice - to give them different names. Otherwise, you can just return the group's name. - If you provided a grouping function and `show_all` is False you don't need this, as you can return - directly the group name from there. - If not provided: - the values will be used as names. - show_all: boolean, optional - whether all the items of the group should be displayed in the legend. - If `False`, only one item per group will be displayed. - If `True`, all the items of the group will be displayed. - extra_updates: dict, optional - A dict stating extra updates that you want to do for each group. - - E.g.: `{"blue": {"line_width": 4}}` - - would also convert the lines with a group VALUE (not name) of "blue" to a width of 4. - - This is just for convenience so that you can run other methods after this one. - Note that you can always do something like this by doing - - ``` - plot.update_traces( - selector={"line_width": "blue"}, # Selects the traces that you will update - line_width=4, - ) - ``` - - If you use a function to return the group values, there is probably no point on using this - argument. Since you recieve the trace, you can run `trace.update(...)` inside your function. 
- **kwargs: - like extra_updates but they are passed to all groups without distinction - """ - unique_values = [] - - # Normalize the "by" parameter to a function - if by is None: - if show_all: - name = names[0] if names is not None else "Group" - self.figure.update_traces(showlegend=True, legendgroup=name, name=name) - return self - else: - func = lambda trace: 0 - if isinstance(by, str): - def func(trace): - try: - return trace[by] - except Exception: - return None - else: - func = by - - # Normalize also the names parameter to a function - if names is None: - def get_name(val, trace): - return str(val) if not show_all else f'{val}: {trace.name}' - elif callable(names): - get_name = names - elif isinstance(names, dict): - def get_name(val, trace): - name = names.get(val, val) - return str(name) if not show_all else f'{name}: {trace.name}' - else: - def get_name(val, trace): - name = names[len(unique_values) - 1] - return str(name) if not show_all else f'{name}: {trace.name}' - - # And finally normalize the extra updates - if extra_updates is None: - get_extra_updates = lambda *args, **kwargs: {} - elif isinstance(extra_updates, dict): - get_extra_updates = lambda val, trace: extra_updates.get(val, {}) - elif callable(extra_updates): - get_extra_updates = extra_updates - - # Build the function that will apply the change - def check_and_apply(trace): - - val = func(trace) - - if isinstance(val, np.ndarray): - val = val.tolist() - if isinstance(val, list): - val = ", ".join([str(item) for item in val]) - - if val in unique_values: - showlegend = show_all - else: - unique_values.append(val) - showlegend = True - - customdata = trace.customdata if trace.customdata is not None else [{}] - - trace.update( - showlegend=showlegend, - legendgroup=str(val), - name=get_name(val, trace=trace), - customdata=[{**customdata[0], "name": trace.name}, *customdata[1:]], - **get_extra_updates(val, trace=trace), - **kwargs - ) - - # And finally apply all the changes - 
self.figure.for_each_trace( - lambda trace: check_and_apply(trace) - ) - - return self - - def ungroup_legend(self): - """ Ungroups traces if a legend contains groups """ - self.figure.for_each_trace( - lambda trace: trace.update( - legendgroup=None, - showlegend=True, - name=trace.customdata[0]["name"] - ) - ) - - return self - - def normalize(self, min_val=0, max_val=1, axis="y", **kwargs): - """ Normalizes traces to a given range along an axis - - Parameters - ----------- - min_val: float, optional - The lower bound of the range. - max_val: float, optional - The upper part of the range - axis: {"x", "y", "z"}, optional - The axis along which we want to normalize. - **kwargs: - keyword arguments that are passed directly to plotly's Figure `for_each_trace` - method. You can check its documentation. One important thing is that you can pass a - 'selector', which will choose if the trace is updated or not. - """ - from ...plotutils import normalize_trace - - self.for_each_trace(partial(normalize_trace, min_val=min_val, max_val=max_val, axis=axis), **kwargs) - - return self - - def swap_axes(self, ax1='x', ax2='y', **kwargs): - """ Swaps two axes in the plot - - Parameters - ----------- - ax1, ax2: str, {'x', 'x*', 'y', 'y*', 'z', 'z*'} - The names of the axes that you want to swap. - **kwargs: - keyword arguments that are passed directly to plotly's Figure `for_each_trace` - method. You can check its documentation. One important thing is that you can pass a - 'selector', which will choose if the trace is updated or not. 
- """ - from ...plotutils import swap_trace_axes - - # Swap the traces - self.for_each_trace(partial(swap_trace_axes, ax1=ax1, ax2=ax2), **kwargs) - - # Try to also swap the axes - try: - self.update_layout({ - f'{ax1}axis': self.layout[f'{ax2}axis'].to_plotly_json(), - f'{ax2}axis': self.layout[f'{ax1}axis'].to_plotly_json(), - }, overwrite=True) - except Exception: - pass - - return self - - def shift(self, shift, axis="y", **kwargs): - """ Shifts the traces of the plot by a given value in the given axis - - Parameters - ----------- - shift: float or array-like - If it's a float, it will be a solid shift (i.e. all points moved equally). - If it's an array, an element-wise sum will be performed - axis: {"x","y","z"}, optional - The axis along which we want to shift the traces. - **kwargs: - keyword arguments that are passed directly to plotly's Figure `for_each_trace` - method. You can check its documentation. One important thing is that you can pass a - 'selector', which will choose if the trace is updated or not. - """ - from ...plotutils import shift_trace - - self.for_each_trace(partial(shift_trace, shift=shift, axis=axis), **kwargs) - - return self - - # ----------------------------- - # SOME OTHER METHODS - # ----------------------------- - - def to_chart_studio(self, *args, **kwargs): - """ Sends the plot to chart studio if it is possible - - For it to work, the user should have their credentials correctly set up. 
- - It is a shortcut for chart_studio.plotly.plot(self.figure, ...etc) so you can pass any extra arguments as if - you were using `py.plot` - """ - import chart_studio.plotly as py - - return py.plot(self.figure, *args, **kwargs) - - # ----------------------------- - # METHODS FOR TESTING - # ----------------------------- - - def _test_number_of_items_drawn(self): - return len(self.figure.data) - - # -------------------------------- - # METHODS TO STANDARIZE BACKENDS - # -------------------------------- - - def draw_line(self, x, y, name=None, line={}, **kwargs): - """Draws a line in the current plot.""" - opacity = kwargs.get("opacity", line.get("opacity", 1)) - self.add_trace({ - 'type': 'scatter', - 'x': x, - 'y': y, - 'mode': 'lines', - 'name': name, - 'line': {k: v for k, v in line.items() if k != "opacity"}, - 'opacity': opacity, - **kwargs, - }) - - def draw_scatter(self, x, y, name=None, marker={}, **kwargs): - self.draw_line(x, y, name, marker=marker, mode="markers", **kwargs) - - def draw_line3D(self, x, y, z, **kwargs): - self.draw_line(x, y, type="scatter3d", z=z, **kwargs) - - def draw_scatter3D(self, *args, **kwargs): - self.draw_line3D(*args, mode="markers", **kwargs) - - def draw_arrows3D(self, xyz, dxyz, arrowhead_angle=20, arrowhead_scale=0.3, **kwargs): - """Draws 3D arrows in plotly using a combination of a scatter3D and a Cone trace.""" - final_xyz = xyz + dxyz - - color = kwargs.get("line", {}).get("color") - if color is None: - color = "red" - - name = kwargs.get("name", "Arrows") - - arrows_coords = np.empty((xyz.shape[0]*3, 3), dtype=np.float64) - - arrows_coords[0::3] = xyz - arrows_coords[1::3] = final_xyz - arrows_coords[2::3] = np.nan - - conebase_xyz = xyz + (1 - arrowhead_scale) * dxyz - - self.figure.add_traces([{ - "x": arrows_coords[:, 0], - "y": arrows_coords[:, 1], - "z": arrows_coords[:, 2], - "mode": "lines", - "type": "scatter3d", - "hoverinfo": "none", - "line": {**kwargs.get("line"), "color": color, }, - "legendgroup": name, 
- "name": f"{name} lines", - "showlegend": False, - }, - { - "type": "cone", - "x": conebase_xyz[:, 0], - "y": conebase_xyz[:, 1], - "z": conebase_xyz[:, 2], - "u": arrowhead_scale * dxyz[:, 0], - "v": arrowhead_scale * dxyz[:, 1], - "w": arrowhead_scale * dxyz[:, 2], - "hovertemplate": "[%{u}, %{v}, %{w}]", - "sizemode": "absolute", - "sizeref": arrowhead_scale * np.linalg.norm(dxyz, axis=1).max() / 2, - "colorscale": [[0, color], [1, color]], - "showscale": False, - "legendgroup": name, - "name": name, - "showlegend": True, - }]) - - -class PlotlyMultiplePlotBackend(PlotlyBackend, MultiplePlotBackend): - pass - - -class PlotlySubPlotsBackend(PlotlyBackend, SubPlotsBackend): - - def draw(self, backend_info): - children = backend_info["children"] - rows, cols = backend_info["rows"], backend_info["cols"] - - # Check if all childplots have the same xaxis or yaxis titles. - axes_titles = defaultdict(list) - for child_plot in children: - axes_titles["x"].append(child_plot.layout.xaxis.title.text) - axes_titles["y"].append(child_plot.layout.yaxis.title.text) - - # If so, we will set the subplots figure x_title and/or y_title so that it looks cleaner. - # See how we remove the titles from the axis layout below when we allocate each plot. 
- axes_titles = {f"{key}_title": val[0] for key, val in axes_titles.items() if len(set(val)) == 1} - - self.figure = make_subplots(**{ - "rows": rows, "cols": cols, - **axes_titles, - **backend_info["make_subplots_kwargs"] - }) - - # Start assigning each plot to a position of the layout - for (row, col), plot in zip(itertools.product(range(1, rows + 1), range(1, cols + 1)), children): - - ntraces = len(plot.data) - - self.add_traces(plot.data, rows=[row]*ntraces, cols=[col]*ntraces) - - for ax in "x", "y": - ax_layout = getattr(plot.layout, f"{ax}axis").to_plotly_json() - - # If we have set a global title for this axis, just remove it from the plot - if axes_titles.get(f"{ax}_title"): - ax_layout["title"] = None - - update_axis = getattr(self, f"update_{ax}axes") - - update_axis(ax_layout, row=row, col=col) - - # Since we have directly copied the layouts of the child plots, there may be some references - # between axes that we need to fix. E.g.: if yaxis was set to follow xaxis in the second child plot, - # since the second child plot is put in (xaxes2, yaxes2) the reference will be now to the first child - # plot xaxis, not itself. 
This is best understood by printing the figure of a subplot :) - new_layouts = {} - for ax, layout in self.figure.layout.to_plotly_json().items(): - if "axis" in ax: - ax_name, ax_num = ax.split("axis") - - # Go over all possible problematic keys - for key in ["anchor", "scaleanchor"]: - val = layout.get(key) - if val in ["x", "y"]: - layout[key] = f"{val}{ax_num}" - - new_layouts[ax] = layout - - self.update_layout(**new_layouts) - - -class PlotlyAnimationBackend(PlotlyBackend, AnimationBackend): - - def draw(self, backend_info): - children = backend_info["children"] - frame_names = backend_info["frame_names"] - frames_layout = self._build_frames(children, None, frame_names) - self.update_layout(**frames_layout) - - def _build_frames(self, children, ani_method, frame_names): - """ Builds the frames of the plotly figure from the child plots' data - - It actually sets the frames of the figure. - - Returns - ----------- - dict - keys and values that need to be added to the layout - in order for frames to work. 
- """ - if ani_method is None: - same_traces = np.unique( - [len(plot.data) for plot in children] - ).shape[0] == 1 - - ani_method = "animate" if same_traces else "update" - - # Choose the method that we need to run in order to get the figure - if ani_method == "animate": - figure_builder = self._figure_animate_method - elif ani_method == "update": - figure_builder = self._figure_update_method - - steps, updatemenus = figure_builder(children, frame_names) - - frames_layout = { - - "sliders": [ - { - "active": 0, - "yanchor": "top", - "xanchor": "left", - "currentvalue": { - "font": {"size": 20}, - #"prefix": "Bands file:", - "visible": True, - "xanchor": "right" - }, - #"transition": {"duration": 300, "easing": "cubic-in-out"}, - "pad": {"b": 10, "t": 50}, - "len": 0.9, - "x": 0.1, - "y": 0, - "steps": steps - } - ], - - "updatemenus": updatemenus - } - - return frames_layout - - def _figure_update_method(self, children, frame_names): - """ - In the update method, we give all the traces to data, and we are just going to toggle - their visibility depending on which 'frame' needs to be displayed. 
- """ - # Add all the traces - for i, (frame_name, plot) in enumerate(zip(frame_names, children)): - - visible = i == 0 - - self.add_traces([{ - **trace.to_plotly_json(), - 'customdata': [{'frame': frame_name, "iFrame": i}], - 'visible': visible - } for trace in plot.data]) - - # Generate the steps - steps = [] - for i, frame_name in enumerate(frame_names): - - steps.append({ - "label": frame_name, - "method": "restyle", - "args": [{"visible": [trace.customdata[0]["iFrame"] == i for trace in self.data]}] - }) - - # WE SHOULD DEFINE PLAY AND PAUSE BUTTONS TO BE RENDERED IN JUPYTER'S NOTEBOOK HERE - # IT IS IMPOSSIBLE TO PASS CONDITIONS TO DECIDE WHAT TO DISPLAY USING PLOTLY JSON - self.animate_widgets = [] - - return steps, [] - - def _figure_animate_method(self, children, frame_names): - """ - In the animate method, we explicitly define frames, And the transition from one to the other - will be animated - """ - # Here are some things that were settings - frame_duration = 500 - redraw = True - - # Data will actually only be the first frame - self.figure.update(data=children[0].data) - - frames = [] - - maxN = np.max([len(plot.data) for plot in children]) - for frame_name, plot in zip(frame_names, children): - - data = plot.data - nTraces = len(data) - if nTraces < maxN: - nAddTraces = maxN - nTraces - data = [ - *data, *np.full(nAddTraces, {"type": "scatter", "x": [0], "y": [0], "visible": False})] - - frames = [ - *frames, {'name': frame_name, 'data': data, "layout": plot.get_settings_group("layout")}] - - self.figure.update(frames=frames) - - steps = [ - {"args": [ - [frame["name"]], - {"frame": {"duration": int(frame_duration), "redraw": redraw}, - "mode": "immediate", - "transition": {"duration": 300}} - ], - "label": frame["name"], - "method": "animate"} for frame in self.figure.frames - ] - - updatemenus = [ - - {'type': 'buttons', - 'buttons': [ - { - 'label': '▶', - 'method': 'animate', - 'args': [None, {"frame": {"duration": int(frame_duration), "redraw": 
True}, - "fromcurrent": True, "transition": {"duration": 100, - "easing": "quadratic-in-out"}}], - }, - - { - 'label': '⏸', - 'method': 'animate', - 'args': [[None], {"frame": {"duration": 0}, "redraw": True, - 'mode': 'immediate', - "transition": {"duration": 0}}], - } - ]} - ] - - return steps, updatemenus - -Animation.backends.register("plotly", PlotlyAnimationBackend) -MultiplePlot.backends.register("plotly", PlotlyMultiplePlotBackend) -SubPlots.backends.register("plotly", PlotlySubPlotsBackend) diff --git a/src/sisl/viz/backends/templates/__init__.py b/src/sisl/viz/backends/templates/__init__.py deleted file mode 100644 index cb48c03a9e..0000000000 --- a/src/sisl/viz/backends/templates/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from ._plots import * -from .backend import AnimationBackend, Backend, MultiplePlotBackend, SubPlotsBackend diff --git a/src/sisl/viz/backends/templates/_plots/__init__.py b/src/sisl/viz/backends/templates/_plots/__init__.py deleted file mode 100644 index 218df31c50..0000000000 --- a/src/sisl/viz/backends/templates/_plots/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-from .bands import BandsBackend -from .bond_length import BondLengthMapBackend -from .fatbands import FatbandsBackend -from .geometry import GeometryBackend -from .grid import GridBackend -from .pdos import PdosBackend diff --git a/src/sisl/viz/backends/templates/_plots/bands.py b/src/sisl/viz/backends/templates/_plots/bands.py deleted file mode 100644 index 94df66eabb..0000000000 --- a/src/sisl/viz/backends/templates/_plots/bands.py +++ /dev/null @@ -1,120 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from abc import abstractmethod - -from ....plots import BandsPlot -from ..backend import Backend - - -class BandsBackend(Backend): - """Draws the bands provided by a `BandsPlot` - - The workflow implemented by it is as follows: - First, `self.draw_bands` draws all bands like: - for band in bands: - if (spin texture needs to be drawn): - `self._draw_spin_textured_band()`, NO GENERIC IMPLEMENTATION (optional) - else: - `self._draw_band()`, generic implementation that calls `self._draw_line` - Once all bands are drawn, `self.draw_gaps` loops through all the gaps to be drawn: - for gap in gaps: - `self.draw_gap()`, MUST BE IMPLEMENTED! - """ - - def draw(self, backend_info): - self.draw_bands(**backend_info["draw_bands"]) - - self._draw_gaps(backend_info["gaps"]) - - def draw_bands(self, filtered_bands, line, spindown_line, spin, spin_texture, add_band_data): - """ - Manages the flow of drawing all the bands - - Parameters - ----------- - filtered_bands: xarray.DataArray - The bands values, with only those bands that need to be plotted. - line: dict - The line style of the bands, as with plotly standards. - spindown_line: dict - Special styles for spin down bands. All styles not specified will be taken - from `line`. 
- spin: Spin - The spin class associated to the bands calculation - spin_texture: dict - Containing the keys: - - "show": bool, whether spin texture needs to be displayed - - "values": xarray.DataArray, the spin texture values that have to be displayed. - - "colorscale": str, the colorscale to use for the spin texture values. - """ - if spin_texture["show"]: - draw_band_func = self._draw_spin_textured_band - spin_moments = spin_texture["values"] - else: - draw_band_func = self._draw_band - - if "spin" not in filtered_bands.coords: - filtered_bands = filtered_bands.expand_dims("spin") - - # Now loop through all bands to draw them - for spin_bands in filtered_bands.transpose('spin', 'band', 'k'): - ispin = int(spin_bands.spin) if "spin" in spin_bands.coords else 0 - line_style = line - if ispin == 1: - line_style.update(spindown_line) - for band in spin_bands: - # Get the xy values for the band - x = band.k.values - y = band.values - kwargs = { - "name": "{} spin {}".format(band.band.values, ["up", "down"][ispin]) if spin.is_polarized else str(band.band.values), - "line": line_style, - **add_band_data(band, self._plot) - } - - # And plot it differently depending on whether we need to display spin texture or not. - if not spin_texture["show"]: - draw_band_func(x, y, **kwargs) - else: - spin_texture_vals = spin_moments.sel(band=band.band.values).values - draw_band_func(x, y, spin_texture_vals=spin_texture_vals, **kwargs) - - def _draw_band(self, *args, **kwargs): - return self.draw_line(*args, **kwargs) - - def _draw_spin_textured_band(self, *args, **kwargs): - return NotImplementedError(f"{self.__class__.__name__} doesn't implement plotting spin_textured bands.") - - def _draw_gaps(self, gaps_info): - """Iterates over all gaps to draw them""" - for gap_info in gaps_info: - self.draw_gap(**gap_info) - - @abstractmethod - def draw_gap(self, ks, Es, color, name, **kwargs): - """This method should draw a gap, given the k and E coordinates. 
- - The color of the line should be determined by `color`, and `name` should be used for labeling. - - Parameters - ----------- - ks: numpy array of shape (2,) - The two k coordinates of the gap. - Es: numpy array of shape (2,) - The two E coordinates of the gap, sorted from minor to major. - color: str - Color with which the gap should be drawn. - name: str - Label that should be asigned to the gap. - """ - - # Methods needed for testing - - def _test_is_gap_drawn(self): - """ - Should return `True` if the gap is currently drawn, otherwise `False`. - """ - raise NotImplementedError - -BandsPlot.backends.register_template(BandsBackend) diff --git a/src/sisl/viz/backends/templates/_plots/bond_length.py b/src/sisl/viz/backends/templates/_plots/bond_length.py deleted file mode 100644 index 094a0f7bcf..0000000000 --- a/src/sisl/viz/backends/templates/_plots/bond_length.py +++ /dev/null @@ -1,25 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from ....plots import BondLengthMap -from .geometry import GeometryBackend - - -class BondLengthMapBackend(GeometryBackend): - """Draws a bond length map provided by `BondLengthMap` - - The flow is exactly the same as `GeometryPlot`, in fact this class might only be extended to - manipulate color bars or things like that. 
Otherwise, if you already have a `MyGeometryBackend`, - you can just create a bond length map backend like - - ``` - class MyBondLengthMapBackend(BondLengthMapBackend, MyGeometryBackend): - pass - ``` - - """ - - def draw_1D(self, backend_info, **kwargs): - return NotImplementedError("1D representations of bond length maps are not implemented") - -BondLengthMap.backends.register_template(BondLengthMapBackend) diff --git a/src/sisl/viz/backends/templates/_plots/fatbands.py b/src/sisl/viz/backends/templates/_plots/fatbands.py deleted file mode 100644 index 0953507f53..0000000000 --- a/src/sisl/viz/backends/templates/_plots/fatbands.py +++ /dev/null @@ -1,133 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import numpy as np - -from ....plots import FatbandsPlot -from .bands import BandsBackend - - -class FatbandsBackend(BandsBackend): - """Draws fatbands provided by `FatbandsPlot` - - The flow implemented by it is as follows: - First, it draws all the weights, iterating through all the requests: - for weight_request in weight_requests: - Call `self.draw_group_weights`, which loops through all the bands to be drawn: - for band in bands: - `self._draw_band_weights` - Then it just calls the `draw` method of `BandsBackend`, which takes care of the rest. See - the documentation of `BandsBackend` to understand the rest of the workflow. - - """ - - def draw(self, backend_info): - """Controls the flow for drawing Fatbands. - - It draws first all the weights and then the bands. 
- """ - - groups_weights = backend_info["groups_weights"] - groups_metadata = backend_info["groups_metadata"] - filtered_bands = backend_info["draw_bands"]["filtered_bands"] - - x = filtered_bands.k.values - - for group_name in groups_weights: - self.draw_group_weights( - weights=groups_weights[group_name], metadata=groups_metadata[group_name], - name=group_name, bands=filtered_bands, x=x - ) - - super().draw(backend_info) - - def draw_group_weights(self, weights, metadata, name, bands, x): - """Draws all weights for a group - - It will iterate over all the bands that need to be drawn for a certain group - and ask the backend to draw them. The backend should implement `_draw_band_weights` - as specified below. - - Parameters - ----------- - weights: xarray.DataArray with indices (spin, band, k) - Contains all the weights to be drawn. - metadata: dict - Contains extra data specifying how the contributions of a group must be drawn. - name: str - The name of the group to which the weights correspond - bands: xarray.DataArray with indices (spin, band, k) - Contains all the eigenvalues of the band structure. - """ - if "spin" not in bands.coords: - bands = bands.expand_dims("spin") - - # Find where the discontinuities are - self._discontinuities = np.where(np.isnan(x))[0] - - # Loop over spin - for ispin, spin_weights in enumerate(weights.transpose("spin", "band", "k")): - # Loop over bands - for i, band_weights in enumerate(spin_weights): - # For each band, draw the fatband - band_values = bands.sel(band=band_weights.band, spin=ispin) - - self._draw_band_weights( - x=x, y=band_values, weights=band_weights.values, - color=metadata["style"]["line"]["color"], name=name, - is_group_first=i==0 and ispin == 0 - ) - - def _draw_band_weights(self, x, y, weights, color, name, is_group_first): - """Default implementation to draw a fatband. - - It uses a scatter plot, where the size of each scatter point is proportional - to the weight. 
- - Parameters - ----------- - x: np.ndarray of shape (nk,) - Contains the k coordinates of the band - y: np.ndarray of shape (nk,) - Contains the energy values of the band - weights: np.ndarray of shape (nk,) - Contains the weight values for each k coordinate - color: str - The color with which the contribution must be drawn. - name: str - The name of the group to which these band weights correspond - is_group_first: bool - Whether this is the first fatband plotted for a given request. This might - be useful for grouping items drawn, for example. - """ - size = weights - size[np.isnan(size)] = 0 - self.draw_scatter(x, y, name=name, marker={"color": color, "size": weights}) - - def _yield_band_chunks(self, *arrays): - """For backends that can not handle fatbands discontinuities - out of the box, this method can be used to yield continuous chunks - to be drawn. - - Parameters - ----------- - *arrays: - All the arrays that must be divided into chunks. - They must all have the same datatype. - - Yields - ----------- - np.ndarray of shape (chunk_nk, n_arrays) - A continous chunk of data of the potentially discontinuous band. - """ - # If there are discontinuities, we need to split - chunks = np.split(np.array([*arrays]).T, self._discontinuities) - - for i_chunk, chunk in enumerate(chunks): - # Chunks other than the first one begin with np.nan (which was the signal for a discontinuity) - if i_chunk > 0: - chunk = chunk[1:] - - yield chunk.T - -FatbandsPlot.backends.register_template(FatbandsBackend) diff --git a/src/sisl/viz/backends/templates/_plots/geometry.py b/src/sisl/viz/backends/templates/_plots/geometry.py deleted file mode 100644 index 0cd59ca842..0000000000 --- a/src/sisl/viz/backends/templates/_plots/geometry.py +++ /dev/null @@ -1,334 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-from collections.abc import Iterable - -import numpy as np - -from sisl.messages import warn - -from ....plots import GeometryPlot -from ..backend import Backend - - -class GeometryBackend(Backend): - """Draws the geometry as provided by `GeometryPlot`. - - Checks the dimensionality of the geometry and then calls: - - 1D case: `self.draw_1D` - - 2D case: `self.draw_2D` - - 3D case: `self.draw_3D` - - These 3 functions contain generic implementations, although some parts may need - a method to be implemented. Here are more details of each case: - - 1D workflow (`self.draw_1D`): - `self._draw_atoms_2D_scatter()`, generic implementation that calls `self.draw_scatter` - - 2D workflow (`self.draw_2D`): - if (bonds need to be drawn): - Calls `self._draw_bonds_2D()` which may call: - if (all bonds are same size and same color): - `self._draw_bonds_2D_single_color_size`, generic implementation that calls `self.draw_line` - else: - `self._draw_bonds_2D_multi_color_size`, generic implementation that calls `self.draw_scatter` - Then call `self._draw_atoms_2D_scatter()` to draw the atoms, generic implementation that calls `self.draw_scatter` - And finally draw the cell: - if (cell to be drawn as axes): - `self._draw_cell_2D_axes()`: - for axis in axes: - `self._draw_axis_2D()`, generic implementation that calls `self.draw_line` - elif (cell to be drawn as a box): - `self._draw_cell_2D_box()`, generic implementation that calls `self.draw_line` - - 3D workflow (`self.draw_3D`): - if (bonds need to be drawn): - if (all bonds are same size and same color): - `self._bonds_3D_scatter()`: - Manages all arguments and then calls `self._draw_bonds_3D`, generic implementation that uses `self.draw_line3D`. - else: - for bond in bonds: - `self._draw_single_bond_3D()`, generic implementation that uses `self.draw_line3D`. 
- if (atoms need to be drawn): - for atom in atoms: - `self._draw_single_atom_3D`, NOT IMPLEMENTED (optional) - And finally draw the cell: - if (cell to be drawn as axes): - `self._draw_cell_3D_axes()`, generic implementation that calls `self.draw_line3D` for each axis. - elif (cell to be drawn as a box): - `self._draw_cell_3D_box()`, generic implementation that calls `self.draw_line3D` - """ - - def draw(self, backend_info): - drawing_func = getattr(self, f"draw_{backend_info['ndim']}D") - - drawing_func(backend_info) - - def draw_1D(self, backend_info, **kwargs): - # Add the atoms - if len(backend_info["atoms_props"]["xy"]) > 0: - self._draw_atoms_2D_scatter(**{k: v for k, v in backend_info["atoms_props"].items() if k != "arrow"}) - # Now draw the arrows - for arrow_spec in backend_info["arrows"]: - is_arrow = ~np.isnan(arrow_spec["data"]).any(axis=1) - arrow_data = np.array([arrow_spec["data"][:, 0], np.zeros_like(arrow_spec["data"][:, 0])]).T - self.draw_arrows( - xy=backend_info["atoms_props"]["xy"][is_arrow], dxy=arrow_data[is_arrow]*arrow_spec["scale"], - arrowhead_angle=arrow_spec["arrowhead_angle"], arrowhead_scale=arrow_spec["arrowhead_scale"], - line={k: arrow_spec.get(k) for k in ("color", "width", "dash")}, - name=arrow_spec["name"] - ) - - def draw_2D(self, backend_info, **kwargs): - geometry = backend_info["geometry"] - xaxis = backend_info["xaxis"] - yaxis = backend_info["yaxis"] - bonds_props = backend_info["bonds_props"] - - # If there are bonds to draw, draw them - if len(bonds_props) > 0: - bonds_kwargs = {} - for k in bonds_props[0]: - if k == "xys": - new_k = k - else: - new_k = f"bonds_{k}" - bonds_kwargs[new_k] = [x[k] for x in bonds_props] - - self._draw_bonds_2D(**bonds_kwargs, points_per_bond=backend_info["points_per_bond"]) - - # Add the atoms scatter - if len(backend_info["atoms_props"]["xy"]) > 0: - self._draw_atoms_2D_scatter(**{k: v for k, v in backend_info["atoms_props"].items() if k != "arrow"}) - # Now draw the arrows from the 
atoms - for arrow_spec in backend_info["arrows"]: - is_arrow = ~np.isnan(arrow_spec["data"]).any(axis=1) - self.draw_arrows( - xy=backend_info["atoms_props"]["xy"][is_arrow], dxy=arrow_spec["data"][is_arrow]*arrow_spec["scale"], - arrowhead_angle=arrow_spec["arrowhead_angle"], arrowhead_scale=arrow_spec["arrowhead_scale"], - line={k: arrow_spec.get(k) for k in ("color", "width", "dash")}, - name=arrow_spec["name"] - ) - - # And finally draw the unit cell - show_cell = backend_info["show_cell"] - cell = geometry.cell - if show_cell == "axes": - self._draw_cell_2D_axes(geometry=geometry, cell=cell, xaxis=xaxis, yaxis=yaxis, line=backend_info["cell_style"]) - elif show_cell == "box": - self._draw_cell_2D_box( - geometry=geometry, cell=cell, - xaxis=xaxis, yaxis=yaxis, - line=backend_info["cell_style"] - ) - - def _draw_atoms_2D_scatter(self, xy, color="gray", size=10, name='atoms', marker_colorscale=None, opacity=None, **kwargs): - self.draw_scatter(xy[:, 0], xy[:, 1], name=name, marker={'size': size, 'color': color, 'colorscale': marker_colorscale, "opacity": opacity}, **kwargs) - - def _draw_bonds_2D(self, xys, points_per_bond=5, force_bonds_as_points=False, - bonds_color='#cccccc', bonds_width=3, bonds_opacity=1, bonds_name=None, name="bonds", **kwargs): - """ - Cheaper than _bond_trace2D because it draws all bonds in a single trace. - It is also more flexible, since it allows providing bond colors as floats that all - relate to the same colorscale. - However, the bonds are represented as dots between the two atoms (if you use enough - points per bond it almost looks like a line). - """ - # Check if we have a single style for all bonds or not. 
- bonds_style = {"color": bonds_color, "width": bonds_width, "opacity": bonds_opacity} - single_style = True - for k, style in bonds_style.items(): - if isinstance(style, Iterable) and not isinstance(style, str): - if np.unique(style).shape[0] == 1: - bonds_style[k] = style[0] - else: - bonds_style[k] = np.repeat(style, points_per_bond) - single_style = False - - x = [] - y = [] - text = [] - if single_style and not force_bonds_as_points: - # Then we can display this trace as lines! :) - for i, ((x1, y1), (x2, y2)) in enumerate(xys): - - x = [*x, x1, x2, None] - y = [*y, y1, y2, None] - - if bonds_name: - text = np.repeat(bonds_name, 3) - - draw_bonds_func = self._draw_bonds_2D_single_color_size - - else: - # Otherwise we will need to draw points in between atoms - # representing the bonds - for i, ((x1, y1), (x2, y2)) in enumerate(xys): - - x = [*x, *np.linspace(x1, x2, points_per_bond)] - y = [*y, *np.linspace(y1, y2, points_per_bond)] - - draw_bonds_func = self._draw_bonds_2D_multi_color_size - if bonds_name: - text = np.repeat(bonds_name, points_per_bond) - - draw_bonds_func(x, y, **bonds_style, name=name, text=text if len(text) != 0 else None, **kwargs) - - def _draw_bonds_2D_single_color_size(self, x, y, color, width, opacity, name, text, **kwargs): - self.draw_line( - x, y, name=name, line={"color": color, "width": width, "opacity": opacity}, - text=text, **kwargs - ) - - def _draw_bonds_2D_multi_color_size(self, x, y, color, width, opacity, name, text, coloraxis="coloraxis", colorscale=None, **kwargs): - self.draw_scatter( - x, y, name=name, - marker={"color": color, "size": width, "opacity": opacity, "coloraxis": coloraxis, "colorscale": colorscale}, - text=text, **kwargs - ) - - def _draw_cell_2D_axes(self, geometry, cell, xaxis="x", yaxis="y", **kwargs): - cell_xy = GeometryPlot._projected_2Dcoords(geometry, xyz=cell, xaxis=xaxis, yaxis=yaxis) - origo_xy = GeometryPlot._projected_2Dcoords(geometry, xyz=geometry.origin, xaxis=xaxis, yaxis=yaxis) - - for 
i, vec in enumerate(cell_xy): - x = np.array([0, vec[0]]) + origo_xy[0] - y = np.array([0, vec[1]]) + origo_xy[1] - name = f'Axis {i}' - self._draw_axis_2D(x, y, name=name, **kwargs) - - def _draw_axis_2D(self, x, y, name, **kwargs): - self.draw_line(x, y, name=name, **kwargs) - - def _draw_cell_2D_box(self, cell, geometry, xaxis="x", yaxis="y", **kwargs): - - cell_corners = GeometryPlot._get_cell_corners(cell) + geometry.origin - x, y = GeometryPlot._projected_2Dcoords(geometry, xyz=cell_corners, xaxis=xaxis, yaxis=yaxis).T - - self.draw_line(x, y, name="Unit cell", **kwargs) - - def draw_3D(self, backend_info): - - geometry = backend_info["geometry"] - bonds_props = backend_info["bonds_props"] - - # If there are bonds to draw, draw them - if len(bonds_props) > 0: - # Unless we have different bond sizes, we want to plot all bonds in the same trace - different_bond_sizes = False - if "width" in bonds_props[0]: - first_size = bonds_props[0].get("width") - for bond_prop in bonds_props: - if bond_prop.get("width") != first_size: - different_bond_sizes = True - break - - if different_bond_sizes: - for bond_props in bonds_props: - self._draw_single_bond_3D(**bond_props) - else: - bonds_kwargs = {} - for k in bonds_props[0]: - if k == "r": - v = bonds_props[0][k] - else: - v = [x[k] for x in bonds_props] - bonds_kwargs[k] = v - - draw_bonds_kwargs = self._get_draw_bonds_3D_kwargs(**bonds_kwargs) - self._draw_bonds_3D(**draw_bonds_kwargs) - - # Now draw the atoms - for i, _ in enumerate(backend_info["atoms_props"]["xyz"]): - self._draw_single_atom_3D(**{k: v[i] for k, v in backend_info["atoms_props"].items()}) - # Draw the arrows - for arrow_spec in backend_info["arrows"]: - is_arrow = ~np.isnan(arrow_spec["data"]).any(axis=1) - try: - self.draw_arrows3D( - xyz=backend_info["atoms_props"]["xyz"][is_arrow], dxyz=arrow_spec["data"][is_arrow]*arrow_spec["scale"], - arrowhead_angle=arrow_spec["arrowhead_angle"], arrowhead_scale=arrow_spec["arrowhead_scale"], - line={k: 
arrow_spec.get(k) for k in ("color", "width", "dash")}, - name=arrow_spec["name"] - ) - except NotImplementedError as e: - # If the arrows can not be drawn in 3D, we will just warn the user and not draw them - warn(str(e)) - break - - # And finally draw the unit cell - show_cell = backend_info["show_cell"] - cell = geometry.cell - if show_cell == "axes": - self._draw_cell_3D_axes(cell=cell, geometry=geometry, line=backend_info["cell_style"]) - elif show_cell == "box": - self._draw_cell_3D_box(cell=cell, geometry=geometry, line=backend_info["cell_style"]) - - def _get_draw_bonds_3D_kwargs(self, xyz1, xyz2, width=10, color='gray', opacity=1, - name=None, line_name=None, coloraxis='coloraxis', **kwargs): - """Generates the arguments for the bond drawing function""" - xyz1 = np.array(xyz1) - bonds_labels = name - if not line_name: - line_name = 'Bonds' - - # Check if we have a single style for all bonds or not. - bonds_style = {"color": color, "width": width, "opacity": opacity} - for k, style in bonds_style.items(): - if isinstance(style, Iterable) and not isinstance(style, str): - if np.unique(style).shape[0] == 1: - bonds_style[k] = style[0] - else: - bonds_style[k] = np.repeat(style, 3) - bonds_style[k][2::3] = 0 - - bonds_xyz = np.full((3*xyz1.shape[0], 3), np.nan) - bonds_xyz[0::3] = xyz1 - bonds_xyz[1::3] = xyz2 - x, y, z = bonds_xyz.T - - x_labels, y_labels, z_labels = None, None, None - if bonds_labels: - x_labels, y_labels, z_labels = ((xyz1 + xyz2) / 2).T - - return dict( - x=x, y=y, z=z, name=line_name, - line={**bonds_style, 'coloraxis': coloraxis}, - bonds_labels=bonds_labels, x_labels=x_labels, y_labels=y_labels, z_labels=z_labels, - **kwargs - ) - - def _draw_bonds_3D(self, x, y, z, name=None, line={}, marker={}, bonds_labels=None, x_labels=None, y_labels=None, z_labels=None, **kwargs): - """Draws all bonds in a single line in 3D - - This method should be overwritten to implement: - - show_markers=True -> Draw markers as well - - Write bonds_labels - 
""" - self.draw_line3D(x, y, z, line=line, marker=marker, name=name, **kwargs) - - def _draw_single_atom_3D(self, xyz, size, color="gray", name=None, group=None, showlegend=False, vertices=15, **kwargs): - raise NotImplementedError(f"{self.__class__.__name__} does not implement a method to draw a single atom in 3D") - - def _draw_single_bond_3D(self, xyz1, xyz2, width=0.3, color="#ccc", name=None, group=None, showlegend=False, line_kwargs={}, **kwargs): - x, y, z = np.array([xyz1, xyz2]).T - - self.draw_line3D(x, y, z, line={'width': width, 'color': color, **line_kwargs}, name=name, **kwargs) - - def _draw_cell_3D_axes(self, cell, geometry, **kwargs): - - for i, vec in enumerate(cell): - self.draw_line3D( - x=np.array([0, vec[0]]) + geometry.origin[0], - y=np.array([0, vec[1]]) + geometry.origin[1], - z=np.array([0, vec[2]]) + geometry.origin[2], - name=f'Axis {i}', - **kwargs - ) - - def _draw_cell_3D_box(self, cell, geometry, **kwargs): - x, y, z = (GeometryPlot._get_cell_corners(cell) + geometry.origin).T - - self.draw_line3D(x, y, z, name="Unit cell", **kwargs) - -GeometryPlot.backends.register_template(GeometryBackend) diff --git a/src/sisl/viz/backends/templates/_plots/grid.py b/src/sisl/viz/backends/templates/_plots/grid.py deleted file mode 100644 index 2c5b679c21..0000000000 --- a/src/sisl/viz/backends/templates/_plots/grid.py +++ /dev/null @@ -1,42 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from ....plots.grid import GridPlot -from ..backend import Backend - - -class GridBackend(Backend): - """Draws a grid as provided by `GridPlot`. 
- - Checks the dimensionality of the grid and then calls: - - 1D case: `self.draw_1D`, generic implementation that uses `self.draw_line` - - 2D case: `self.draw_2D`, NOT IMPLEMENTED (optional) - - 3D case: `self.draw_3D`, NOT IMPLEMENTED (optional) - - Then, if the geometry needs to be plotted, it plots the geometry. This will use - the `GeometryBackend` with the same name as your grid backend, so make sure it is implemented - if you want to allow showing geometries along with the grid. - """ - - def draw(self, backend_info): - # Choose which function we need to use to plot - drawing_func = getattr(self, f"draw_{backend_info['ndim']}D") - - drawing_func(backend_info) - - if backend_info["geom_plot"] is not None: - self.draw_other_plot(backend_info["geom_plot"]) - - def draw_1D(self, backend_info, **kwargs): - """Draws the grid in 1D""" - self.draw_line(backend_info["ax_range"], backend_info["values"], name=backend_info["name"], **kwargs) - - def draw_2D(self, backend_info, **kwargs): - """Should draw the grid in 2D, and draw contours if requested.""" - raise NotImplementedError(f"{self.__class__.__name__} does not implement displaying grids in 2D") - - def draw_3D(self, backend_info, **kwargs): - """Should draw all the isosurfaces of the grid in 3D""" - raise NotImplementedError(f"{self.__class__.__name__} does not implement displaying grids in 3D") - -GridPlot.backends.register_template(GridBackend) diff --git a/src/sisl/viz/backends/templates/_plots/pdos.py b/src/sisl/viz/backends/templates/_plots/pdos.py deleted file mode 100644 index a705287773..0000000000 --- a/src/sisl/viz/backends/templates/_plots/pdos.py +++ /dev/null @@ -1,33 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-from abc import abstractmethod - -from ....plots import PdosPlot -from ..backend import Backend - - -class PdosBackend(Backend): - """It draws the PDOS values provided by a `PdosPlot` - - The workflow implemented by it is as follows: - for line in PDOS_lines: - `self.draw_PDOS_line()`, generic implementation that calls `self._draw_line`. - """ - - def draw(self, backend_info): - self.draw_PDOS_lines(backend_info) - - def draw_PDOS_lines(self, backend_info): - lines = backend_info["PDOS_values"] - Es = backend_info["Es"] - - for name, values in lines.items(): - self.draw_PDOS_line(Es, values, backend_info["request_metadata"][name], name) - - def draw_PDOS_line(self, Es, values, request_metadata, name): - line_style = request_metadata["style"]["line"] - - self.draw_line(x=values, y=Es, name=name, line=line_style) - -PdosPlot.backends.register_template(PdosBackend) diff --git a/src/sisl/viz/backends/templates/backend.py b/src/sisl/viz/backends/templates/backend.py deleted file mode 100644 index f57669939c..0000000000 --- a/src/sisl/viz/backends/templates/backend.py +++ /dev/null @@ -1,359 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from abc import ABC, abstractmethod - -import numpy as np - -from ...plot import Animation, MultiplePlot, SubPlots - - -class Backend(ABC): - """Base backend class that all backends should inherit from. - - This class contains various methods that need to be implemented by its subclasses. - - Methods that MUST be implemented are marked as abstract methods, therefore you won't - even be able to use the class if you don't implement them. On the other hand, there are - methods that are not absolutely essential to the general workings of the framework. - These are written in this class to raise a NotImplementedError. 
Therefore, the backend - will be instantiable but errors may happen during the plotting process. - - Below are all methods that need to be implemented by... - - (1) the generic backend of the framework: - - `clear`, MUST - - `draw_on`, optional (highly recommended, otherwise no multiple plot functionality) - - `draw_line`, optional (highly recommended for 2D) - - `draw_scatter`, optional (highly recommended for 2D) - - `draw_line3D`, optional - - `draw_scatter3D`, optional - - `draw_arrows3D`, optional - - `show`, optional - - (2) specific backend of a plot: - - `draw`, MUST - - Also, you probably need to write an `__init__` method to initialize the state of the plot. - Usually drawing methods will add to the state and finally on `show` you display the full - plot. - """ - - def __init__(self, plot): - # Let's store our parent plot, we might need it. - self._plot = plot - - @abstractmethod - def draw(self, backend_info): - """Draws the plot, given the info passed by it. - - This is plot specific and is implemented in the templates, you don't need to worry about it! - For example: if you inherit from `BandsBackend`, this class already contains a draw method that - manages the flow of drawing the bands. - """ - - def draw_other_plot(self, plot, backend=None, **kwargs): - """Method that draws a different plot in the current canvas. - - Note that the other plot might have a different active backend, which might be incompatible. - We take care of it in this method. - - This method will be used by `MultiplePlotBackend`, but it's also used in some cases by regular plots. - - NOTE: This needs the `draw_on` method, which is specific to each framework. See below. - - Parameters - ------------ - plot: Plot - The plot we want to draw in the current canvas - backend: str, optional - The name of the backend that we want to force on the plot to be drawn. If not provided, we use - the name of the current backend. 
- **kwargs: - passed directly to `draw_on` - """ - backend_name = backend or self._backend_name - - # Get the current backend of the plot that we have to draw - plot_backend = getattr(plot, "_backend", None) - - # If the current backend of the plot is incompatible with this backend, we are going to - # setup a compatible backend. Note that here we assume a backend to be compatible if its - # prefixed with the name of the current backend. I.e. if the current backend is "plotly" - # "plotly_*" backends are assumed to be compatible. - if plot_backend is None or not plot_backend._backend_name.startswith(backend_name): - plot.backends.setup(plot, backend_name) - - # Make the plot draw in this backend instance - plot.draw_on(self, **kwargs) - - # Restore the initial backend of the plot, so that it doesn't feel affected - plot._backend = plot_backend - - def draw_on(self, figure, **kwargs): - """Should draw the method in another instance of a compatible backend. - - Parameters - ----------- - figure: - The types of objects accepted by this argument are dependent on each backend. - However, it should always be able to accept a compatible backend. See `PlotlyBackend` - or `MatplotlibBackend` as examples. - """ - raise NotImplementedError(f"{self.__class__.__name__} does not implement a 'draw_on' method and therefore doesn't know"+ - "how to draw outside its own instance.") - - @abstractmethod - def clear(self): - """Clears the figure so that we can draw again.""" - - def show(self): - pass - - # Methods needed for testing - def _test_number_of_items_drawn(self): - """Returns the number of items drawn currently in the plot.""" - raise NotImplementedError - - def draw_line(self, x, y, name=None, line={}, marker={}, text=None, **kwargs): - """Should draw a line satisfying the specifications - - Parameters - ----------- - x: array-like - the coordinates of the points along the X axis. - y: array-like - the coordinates of the points along the Y axis. 
- name: str, optional - the name of the line - line: dict, optional - specifications for the line style, following plotly standards. The backend - should at least be able to implement `line["color"]` and `line["width"]` - marker: dict, optional - specifications for the markers style, following plotly standards. The backend - should at least be able to implement `marker["color"]` and `marker["size"]` - text: str, optional - contains the text asigned to each marker. On plotly this is seen on hover, - other options could be annotating. However, it is not necessary that this - argument is supported. - **kwargs: - should allow other keyword arguments to be passed directly to the creation of - the line. This will of course be framework specific - """ - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_line method.") - - def draw_scatter(self, x, y, name=None, marker={}, text=None, **kwargs): - """Should draw a scatter satisfying the specifications - - Parameters - ----------- - x: array-like - the coordinates of the points along the X axis. - y: array-like - the coordinates of the points along the Y axis. - name: str, optional - the name of the scatter - marker: dict, optional - specifications for the markers style, following plotly standards. The backend - should at least be able to implement `marker["color"]` and `marker["size"]`, but - it is very advisable that it supports also `marker["opacity"]` and `marker["colorscale"]` - text: str, optional - contains the text asigned to each marker. On plotly this is seen on hover, - other options could be annotating. However, it is not necessary that this - argument is supported. - **kwargs: - should allow other keyword arguments to be passed directly to the creation of - the scatter. 
This will of course be framework specific - """ - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_scatter method.") - - def draw_arrows(self, xy, dxy, arrowhead_scale=0.2, arrowhead_angle=20, **kwargs): - """Draws multiple arrows using the generic draw_line method. - - Parameters - ----------- - xy: np.ndarray of shape (n_arrows, 2) - the positions where the atoms start. - dxy: np.ndarray of shape (n_arrows, 2) - the arrow vector. - arrow_head_scale: float, optional - how big is the arrow head in comparison to the arrow vector. - arrowhead_angle: angle - the angle that the arrow head forms with the direction of the arrow (in degrees). - """ - # Get the destination of the arrows - final_xy = xy + dxy - - # Convert from degrees to radians. - arrowhead_angle = np.radians(arrowhead_angle) - - # Get the rotation matrices to get the tips of the arrowheads - rot_matrix = np.array([[np.cos(arrowhead_angle), -np.sin(arrowhead_angle)], [np.sin(arrowhead_angle), np.cos(arrowhead_angle)]]) - inv_rot = np.linalg.inv(rot_matrix) - - # Calculate the tips of the arrow heads - arrowhead_tips1 = final_xy - (dxy*arrowhead_scale).dot(rot_matrix) - arrowhead_tips2 = final_xy - (dxy*arrowhead_scale).dot(inv_rot) - - # Now build an array with all the information to draw the arrows - # This has shape (n_arrows * 7, 2). The information to draw an arrow - # occupies 7 rows and the columns are the x and y coordinates. 
- arrows = np.empty((xy.shape[0]*7, xy.shape[1]), dtype=np.float64) - - arrows[0::7] = xy - arrows[1::7] = final_xy - arrows[2::7] = np.nan - arrows[3::7] = arrowhead_tips1 - arrows[4::7] = final_xy - arrows[5::7] = arrowhead_tips2 - arrows[6::7] = np.nan - - return self.draw_line(arrows[:, 0], arrows[:, 1], **kwargs) - - def draw_line3D(self, x, y, z, name=None, line={}, marker={}, text=None, **kwargs): - """Should draw a 3D line satisfying the specifications - - Parameters - ----------- - x: array-like - the coordinates of the points along the X axis. - y: array-like - the coordinates of the points along the Y axis. - z: array-like - the coordinates of the points along the Z axis. - name: str, optional - the name of the line - line: dict, optional - specifications for the line style, following plotly standards. The backend - should at least be able to implement `line["color"]` and `line["width"]` - marker: dict, optional - specifications for the markers style, following plotly standards. The backend - should at least be able to implement `marker["color"]` and `marker["size"]` - text: str, optional - contains the text asigned to each marker. On plotly this is seen on hover, - other options could be annotating. However, it is not necessary that this - argument is supported. - **kwargs: - should allow other keyword arguments to be passed directly to the creation of - the line. This will of course be framework specific - """ - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_line3D method.") - - def draw_scatter3D(self, x, y, z, name=None, marker={}, text=None, **kwargs): - """Should draw a 3D scatter satisfying the specifications - - Parameters - ----------- - x: array-like - the coordinates of the points along the X axis. - y: array-like - the coordinates of the points along the Y axis. - z: array-like - the coordinates of the points along the Z axis. 
- name: str, optional - the name of the scatter - marker: dict, optional - specifications for the markers style, following plotly standards. The backend - should at least be able to implement `marker["color"]` and `marker["size"]` - text: str, optional - contains the text asigned to each marker. On plotly this is seen on hover, - other options could be annotating. However, it is not necessary that this - argument is supported. - **kwargs: - should allow other keyword arguments to be passed directly to the creation of - the scatter. This will of course be framework specific - """ - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_scatter3D method.") - - def draw_arrows3D(self, xyz, dxyz, arrowhead_scale=0.3, arrowhead_angle=15, **kwargs): - """Draws multiple arrows using the generic draw_line method. - - Parameters - ----------- - xy: np.ndarray of shape (n_arrows, 2) - the positions where the atoms start. - dxy: np.ndarray of shape (n_arrows, 2) - the arrow vector. - arrow_head_scale: float, optional - how big is the arrow head in comparison to the arrow vector. - arrowhead_angle: angle - the angle that the arrow head forms with the direction of the arrow (in degrees). - """ - # Get the destination of the arrows - final_xyz = xyz + dxyz - - # Convert from degrees to radians. - arrowhead_angle = np.radians(arrowhead_angle) - - # Calculate the arrowhead positions. This is a bit more complex than the 2D case, - # since there's no unique plane to rotate all vectors. - # First, we get a unitary vector that is perpendicular to the direction of the arrow in xy. - dxy_norm = np.linalg.norm(dxyz[:, :2], axis=1) - # Some vectors might be only in the Z direction, which will result in dxy_norm being 0. 
- # We avoid problems by dividinc - dx_p = np.divide(dxyz[:, 1], dxy_norm, where=dxy_norm != 0, out=np.zeros(dxyz.shape[0], dtype=np.float64)) - dy_p = np.divide(-dxyz[:, 0], dxy_norm, where=dxy_norm != 0, out=np.ones(dxyz.shape[0], dtype=np.float64)) - - # And then we build the rotation matrices. Since each arrow needs a unique rotation matrix, - # we will have n 3x3 matrices, where n is the number of arrows, for each arrowhead tip. - c = np.cos(arrowhead_angle) - s = np.sin(arrowhead_angle) - - # Rotation matrix to build the first arrowhead tip positions. - rot_matrices = np.array( - [[c + (dx_p ** 2) * (1 - c), dx_p * dy_p * (1 - c), dy_p * s], - [dy_p * dx_p * (1 - c), c + (dy_p ** 2) * (1 - c), -dx_p * s], - [-dy_p * s, dx_p * s, np.full_like(dx_p, c)]]) - - # The opposite rotation matrix, to get the other arrowhead's tip positions. - inv_rots = rot_matrices.copy() - inv_rots[[0, 1, 2, 2], [2, 2, 0, 1]] *= -1 - - # Calculate the tips of the arrow heads. - arrowhead_tips1 = final_xyz - np.einsum("ij...,...j->...i", rot_matrices, dxyz * arrowhead_scale) - arrowhead_tips2 = final_xyz - np.einsum("ij...,...j->...i", inv_rots, dxyz * arrowhead_scale) - - # Now build an array with all the information to draw the arrows - # This has shape (n_arrows * 7, 3). The information to draw an arrow - # occupies 7 rows and the columns are the x and y coordinates. 
- arrows = np.empty((xyz.shape[0]*7, 3)) - - arrows[0::7] = xyz - arrows[1::7] = final_xyz - arrows[2::7] = np.nan - arrows[3::7] = arrowhead_tips1 - arrows[4::7] = final_xyz - arrows[5::7] = arrowhead_tips2 - arrows[6::7] = np.nan - - return self.draw_line3D(arrows[:, 0], arrows[:, 1], arrows[:, 2], **kwargs) - - -class MultiplePlotBackend(Backend): - - def draw(self, backend_info): - """Recieves the child plots and is responsible for drawing all of them in the same canvas""" - for child in backend_info["children"]: - self.draw_other_plot(child) - - -class SubPlotsBackend(Backend): - - @abstractmethod - def draw(self, backend_info): - """Draws the subplots layout - - It must use `rows` and `cols`, and draw the children row by row. - """ - - -class AnimationBackend(Backend): - - @abstractmethod - def draw(self, backend_info): - """Generates an animation out of the child plots. - """ - -MultiplePlot.backends.register_template(MultiplePlotBackend) -SubPlots.backends.register_template(SubPlotsBackend) -Animation.backends.register_template(AnimationBackend) diff --git a/src/sisl/viz/configurable.py b/src/sisl/viz/configurable.py deleted file mode 100644 index c94125c16c..0000000000 --- a/src/sisl/viz/configurable.py +++ /dev/null @@ -1,975 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import inspect -from collections import defaultdict, deque -from collections.abc import Iterable -from copy import copy, deepcopy -from functools import wraps - -import numpy as np - -from sisl.messages import info - -from ._presets import get_preset -from .plotutils import get_configurable_docstring - -__all__ = ["Configurable", "vizplotly_settings"] - - -class NamedHistory: - """ Useful for tracking and modifying the history of named parameters - - This is useful to keep track of how a dict changes, for example. 
- - Parameters - ---------- - init_params: dict - The initial values for the parameters. - - If defaults are not provided, this will be treated as "defaults" (only if keep_defaults is true!) - defaults: dict, optional - the default values for each parameter. In case some parameter is missing in `init_params` - it will be initialized with the default value. - - This will also be used to restore settings to defaults. - history_len: int, optional - how much steps of history should be recorded - keep_defaults: boolean, optional - whether the defaults should be kept in case you want to restore them. - - Attributes - ---------- - current : dict - the current values for the parameters - """ - - def __init__(self, init_params, defaults=None, history_len=20, keep_defaults=True): - self._defaults_kept = keep_defaults - - if defaults is not None: - if keep_defaults: - self._defaults = defaults - - # This makes it easier to restore the parameters - if hasattr(self, "_defaults"): - init_params = {**self._defaults, **init_params} - - # Vals will contain the unique values for each parameter - self._vals = {key: [val] for key, val in init_params.items()} - - # And _hist will just hold params - self._hist = {key: deque([0], maxlen=history_len) for key in init_params} - - def __str__(self): - """ str of the object """ - return self.__class__.__name__ + f"{{history: {self._hist}, parameters={list(self._vals.keys())}}}" - - @property - def current(self): - """ The current state of the history """ - return self.step(-1) - - def step(self, i): - """ Retrieves a given step of the history - - Parameters - ----------- - i: int - the index of the step that you want. It can be negative. - - Returns - ----------- - dict - The key-value pairs for the given step - """ - return {key: self._vals[key][hist[i]] for key, hist in self._hist.items()} - - def __len__(self): - """ Returns the number of steps stored in the history """ - # TODO is this really what you want? 
- # Needs clarification, in any case len(self._hist.values()[0]) would - # clarify that you don't really care...? - for _, hist in self._hist.items(): - return len(hist) - - def __getitem__(self, item): - if isinstance(item, int): - # If an int is provided, we get that step of the history - return self.step(item) - elif isinstance(item, str): - # If a string is provided, we get the full history of a given setting - return [self._vals[item][i] for i in self._hist[item]] - elif isinstance(item, Iterable): - try: - # If it's an array-like of strings, we will get the full history of each key - return {key: np.array(self._vals[key])[self._hist[key]] for key in item} - except Exception: - # Otherwise, we just map the array with __getitem__() - return [self.__getitem__(i) for i in item] - elif isinstance(item, slice): - # Finally, if it's a slice, we will get the steps that are within that slice. - return {key: np.array(self._vals[key])[hist][item] for key, hist in self._hist.items()} - - def __contains__(self, item): - """ Check if we are storing that named item """ - return item in self._vals - - def update(self, **new_settings): - """ Updates the history by appending a new step - - Parameters - ---------- - **new_settings: - all the settings that you want to change passed as keyword arguments. - - You don't need to provide the values for all parameters, only those that you - need to change. 
- - Returns - ------- - self - """ - for key in self._vals: - if key not in new_settings: - new_index = self._hist[key][-1] - - else: - # Check if we already have that value - val = new_settings[key] - is_nparray = isinstance(val, np.ndarray) - - # We have to do this because np.arrays don't like being compared :) - # Otherwise we would just do if val in self._vals[key] - for i, saved_val in enumerate(self._vals[key]): - if not isinstance(saved_val, np.ndarray) and not is_nparray: - try: - if val == saved_val: - new_index = i - break - except Exception: - # It is possible that the value itself is not a numpy array - # but contains one. This is very hard to handle - # Also we will assume that any other exception raised mean the - # values are not equal. - pass - else: - self._vals[key].append(val) - new_index = len(self._vals[key]) - 1 - - # Append the index to the history - self._hist[key].append(new_index) - - return self - - @property - def last_updated(self): - """ The names of the parameters that were changed in the last update """ - return self.updated_params(-1) - - def last_update_for(self, key): - """ Returns the index of the last update for a given parameter - - Parameters - ----------- - key: str - the parameter we want the last update for. - - Returns - ----------- - int or None - the index of the last update. - If the parameter was never updated it returns None. - """ - current = self._vals[key][self._hist[key][-1]] - - for i, val in enumerate(reversed(self._hist[key])): - if val != current: - return len(self._hist[key]) - (i+1) - - def updated_params(self, step): - """ Gets the keys of the parameters that were updated in a given step - - Parameters - ----------- - step: int - the index of the step that you want to check. 
- - Returns - ----------- - list of str - the list of parameters that were updated at that step - """ - return self.diff_keys(step, step - 1) - - def was_updated(self, key, step=-1): - """ Checks whether a given step updated the parameter's value - - Parameters - ----------- - key: str - the name of the parameter that you want to check. - step: int - the index of the step we want to check - - Returns - ----------- - bool - whether the parameter was updated or not - """ - return self.is_different(key, step1=step, step2=step-1) - - def is_different(self, key, step1, step2): - """ Checks if a parameter has a different value between two steps - - The steps DO NOT need to be consecutive. - - Parameters - ----------- - key:str - the name of the parameter. - step1 and step2: int - the indices of the two steps that you want to check. - - Returns - ----------- - bool - whether the value of the parameter is different in these - two steps - """ - hist = self._hist[key] - return hist[step1] != hist[step2] - - def diff_keys(self, step1, step2): - """ Gets the keys that are different between two steps of the history - - The steps DO NOT need to be consecutive. - - Parameters - ----------- - step1 and step2: int - the indices of the two steps that you want to check. - - Returns - ----------- - list of str - the names of the parameters that are different between these two steps. - """ - return [key for key in self._vals if self.is_different(key, step1, step2)] - - def delta(self, step_after=-1, step_before=None): - """ Gets a dictionary with the diferences between two steps - - Parameters - ----------- - step_after: int, optional - the step that is considered as "after" in the delta log. - step_before: int, optional - the step that is considered as "before" in the delta log. - - If not provided, it will just be one step previous to `step_after` - - Returns - ----------- - dict - a dictionary containing, for each CHANGED key, the values before - and after. 
- - The dict will not contain the parameters that were not changed. - """ - if step_before is None: - step_before = step_after -1 - - keys = self.diff_keys(step_before, step_after) - return { - key: { - "before": self[step_before][key], - "after": self[step_after][key], - } for key in keys - } - - @property - def last_delta(self): - """ A log with the last changes - - See `delta` for more information. - """ - return self.delta(-1, -2) - - def undo(self, steps=1): - """ Takes the history back a number of steps - - Currently, this is irreversible, there is no "redo". - - Parameters - ----------- - steps: int, optional - the number of steps that you want to move the history back. - - Returns - ----------- - self - """ - # Decide which keys to undo. - # This can not be done yet because all keys - # are supposed to have the same number of steps. - # if only is not None: - # keys = only - # else: - # keys = self._vals.keys() - # if exclude is not None: - # keys = [ key for key in keys if key not in exclude] - - for hist in self._hist.values(): - for _ in range(steps): - hist.pop() - - # Clear the unused values (for the moment we are setting them to None - # so that we don't need to change the indices of the history) - for key in self._vals: - if self._defaults_kept: - self._vals[key] = [val if i in hist or i==0 else None for i, val in enumerate(self._vals[key])] - else: - self._vals[key] = [val if i in hist else None for i, val in enumerate(self._vals[key])] - - return self - - def clear(self): - """ Clears the history. - - It sets all settings to `None`. If you want to restore the defaults, - use `restore_defaults` instead. 
- - """ - self.__init__(init_settings={key: None for key in self._vals}) - return self - - def restore_initial(self): - """ Restores the history to its initial values (the first step) """ - self.__init__(init_settings=self.step(0)) - return self - - def restore_defaults(self): - """ Restores the history to its defaults """ - if self._defaults_kept: - if hasattr(self, "_defaults"): - self.__init__({}) - else: - self.restore_initial() - else: - raise RuntimeError("Defaults were not kept! You need to use keep_defaults=True on initialization") - - return self - - @property - def defaults(self): - """ The default values for this history """ - if self._defaults_kept: - if hasattr(self, "_defaults"): - return self._defaults - else: - return self.step(0) - else: - raise RuntimeError("Defaults were not kept! You need to use keep_defaults=True on initialization") - - return self - - def is_default(self, key): - """ Checks if a parameter currently has the default value - - Parameters - ----------- - key: str - the parameter that you want to check. - - Returns - ----------- - bool - whether the parameter currently holds the default value. - """ - return self[key][-1] == self.defaults[key] - - -class ConfigurableMeta(type): - """ Metaclass used to build the Configurable class and its childs. - - This is used mainly for two reasons, and they both affect only subclasses of Configurable - not Configurable itself.: - - Make the class functions able to access settings through their arguments - (see the `_populate_with_settings` function in this same file) - - Set documentation to the `update_settings` method that is specific to the particular class - so that the user can check what each parameter does exactly. - """ - - def __new__(cls, name, bases, attrs): - """Prepares a subclass of Configurable, as explained in this class' docstring.""" - # If there are no bases, it is the Configurable class, and we don't need to modify its methods. 
- if bases: - # If this is a sub class, we add the parameters from its parents. - class_params = attrs.get("_parameters", []) - class_param_groups = list(attrs.get("_param_groups", [])) - for base in bases: - if "_parameters" in vars(base): - class_params = [*class_params, *deepcopy(base._parameters)] - if "_param_groups" in vars(base): - class_param_groups = [*deepcopy(base._param_groups), *class_param_groups] - - attrs["_parameters"] = class_params - attrs["_param_groups"] = [group for group in class_param_groups if group["key"] is not None] - attrs["_param_groups"].append({ - "key": None, - "name": "Other settings", - "icon": "settings", - "description": "Here are some unclassified settings. Even if they don't belong to any group, they might still be important. They may be here just because the developer was too lazy to categorize them or forgot to do so. If you are the developer and it's the first case, shame on you." - }) - - # If methods have arguments whose keys correspond to settings, they will receive the - # current value of the setting as the default. We can not do that if this is a staticmethod - # or a classmethod, because they are not aware of the instance. 
- for f_name, f in attrs.items(): - if callable(f) and (not f_name.startswith("__")) and (not isinstance(f, (staticmethod, classmethod))): - attrs[f_name] = _populate_with_settings(f, [param["key"] for param in class_params]) - - new_cls = super().__new__(cls, name, bases, attrs) - - new_cls._create_update_maps() - - if bases: - # Change the docs of the update_settings method to truly reflect - # the available kwargs for the plot class and provide more help to the user - def update_settings(self, *args, **kwargs): - return self._update_settings(*args, **kwargs) - - update_settings.__doc__ = f"Updates the settings of this plot.\n\nDocs for {new_cls.__name__}:\n\n{get_configurable_docstring(new_cls)}" - new_cls.update_settings = update_settings - - return new_cls - - -class Configurable(metaclass=ConfigurableMeta): - - def init_settings(self, presets=None, **kwargs): - """ - Initializes the settings for the object. - - Parameters - ----------- - presets: str or array-like of str - all the presets that you want to use. - Note that you can register new presets using `sisl.viz.plotly.add_preset` - **kwargs: - the values of the settings passed as keyword arguments. - - If a setting is not provided, the default value will be used. - """ - # If the class needs to overwrite some defaults of settings that has inherited, do it - overwrite_defaults = getattr(self, "_overwrite_defaults", {}) - for key, val in overwrite_defaults.items(): - if key not in kwargs: - kwargs[key] = val - - #Get the parameters of all the classes the object belongs to - self.params, self.param_groups = deepcopy(self._parameters), deepcopy(self._param_groups) - - if presets is not None: - if isinstance(presets, str): - presets = [presets] - - for preset in presets: - preset_settings = get_preset(preset) - kwargs = {**preset_settings, **kwargs} - - # Define the settings dictionary, taking the value of each parameter from kwargs if it is there or from the defaults otherwise. 
- # And initialize the settings history - defaults = {param.key: param.default for param in self.params} - self.settings_history = NamedHistory( - {key: kwargs.get(key, val) for key, val in defaults.items()}, - defaults=defaults, history_len=20, keep_defaults=True - ) - - return self - - @classmethod - def _create_update_maps(cls): - """ Generates a mapping from setting keys to functions that use them - - Therefore, this mapping (`cls._run_on_update`) contains information about - which functions need to be executed again when a setting is updated. - - The mapping generated here is used in `Configurable.run_updates` - """ - #Initialize the object where we are going to store what each setting needs to rerun when it is updated - if hasattr(cls, "_run_on_update"): - updates_dict = copy(cls._run_on_update) - else: - updates_dict = defaultdict(list) - - cls._run_on_update = updates_dict - - for name, f in inspect.getmembers(cls, predicate=inspect.isfunction): - for _, param in getattr(f, "_settings", []): - cls._run_on_update[param].append(f.__name__) - - @property - def settings(self): - """ The current settings of the object """ - return self.settings_history.current - - @classmethod - def _get_class_params(cls): - """ Returns all the parameters that can be tweaked for that class - - These are obtained from the `_parameters` class variable. - - Note that parameters are inherited even if you overwrite the `_parameters` - variable. - - Probably there should be a variable `_exclude_params` to avoid some parameters. - """ - - return cls._parameters, cls._param_groups - - def update_settings(self, *args, **kwargs): - """ This method will be overwritten for each class. 
See `_update_settings` """ - return self._update_settings(*args, **kwargs) - - def _update_settings(self, run_updates=True, **kwargs): - """ Updates the settings of the object - - Note that this is only private because we provide a public update_settings - with the specific kwargs for each class so that users can quickly know which - settings are available. You can see how we define this method in `__init_subclass__` - - Parameters - ------------ - run_updates: bool, optional - whether we should run updates after updating the settings. If not, the settings - will be updated, but you won't see any change in the object. - **kwargs: - the values of the settings that we want to update passed as keyword arguments. - """ - #Initialize the settings in case there are none yet - if not hasattr(self, "settings_history"): - return self.init_settings(**kwargs) - - # Otherwise, update them - updates = {key: val for key, val in kwargs.items() if key in self.settings_history} - if updates: - self.settings_history.update(**updates) - - #Do things after updating the settings - if len(self.settings_history.last_updated) > 0 and run_updates: - self._run_updates(self.settings_history.last_updated) - - return self - - def _run_updates(self, for_keys): - """ Runs the functions/methods that are supposed to be ran when given settings are updated - - It uses the `_run_on_update` dict, which contains what to run - in case each setting is updated. - - Parameters - ----------- - for_keys: array-like of str - the keys of the settings that have been updated. - """ - # Get the functions that need to be executed for each key that has been updated and - # put them in a list - func_names = [self._run_on_update.get(setting_key, []) for setting_key in for_keys] - - # Flatten that list (list comprehension) and take only the unique values (set) - func_names = set([f_name for sublist in func_names for f_name in sublist]) - - # Give the oportunity to parse the functions that need to be ran. 
See `Plot._parse_update_funcs` - # for an example - func_names = self._parse_update_funcs(func_names) - - # Execute the functions that we need to execute. - for f_name in func_names: - getattr(self, f_name)() - - return self - - def _parse_update_funcs(self, func_names): - """ Called on _run_updates as a final oportunity to decide what functions to run - - May be overwritten in child classes. - - Parameters - ----------- - func_names: set of str - the unique functions names that are to be executed unless you modify them. - - Returns - ----------- - array-like of str - the final list of functions that will be executed. - """ - return func_names - - def undo_settings(self, steps=1, run_updates=True): - """ Brings the settings back a number of steps - - Parameters - ------------ - steps: int, optional - the number of steps you want to go back. - run_updates: bool, optional - whether we should run updates after updating the settings. If not, the settings - will be updated, but you won't see any change in the object. - """ - try: - diff = self.settings_history.diff_keys(-1, -steps-1) - self.settings_history.undo(steps=steps) - if run_updates: - self._run_updates(diff) - except IndexError: - info(f"This instance of {self.__class__.__name__} does not " - f"contain earlier settings as requested ({steps} step(s) back)") - - return self - - def undo_setting(self, key): - """ Undoes only a particular setting and leaves the others unchanged - - At the moment it is a 'fake' undo function, since it actually updates the settings. - - Parameters - ----------- - key: str - the key of the setting that you want to undo. 
- """ - i = self.settings_history.last_update_for(key) - - if i is None: - info(f"key={key} was never changed; cannot undo nothing.") - - self.update_settings(key=self.settings_history[key][i]) - - return self - - def undo_settings_group(self, group): - """ Takes the desired group of settings one step back, but the rest of the settings remain unchanged - - At the moment it is a 'fake' undo function, since it actually updates the settings. - - Parameters - ----------- - group: str - the key of the settings group for which you want to undo its values. - """ - #Get the actual settings for that group - actualSettings = self.get_settings_group(group) - - #Try to find any different values for the settings - for i in range(len(self.settings_history)): - - previousSettings = self.get_settings_group(group, steps_back = i) - - if previousSettings != actualSettings: - - return self.update_settings(previousSettings) - else: - info(f"group={group} was never changed; cannot undo nothing.") - - return self - - def get_param(self, key, as_dict=False, params_extractor=False): - """ Gets the parameter for a given setting - - Arguments - --------- - key: str - The key of the desired parameter. - as_dict: bool, optional - If set to True, returns a dictionary instead of the actual parameter object. - params_extractor: function, optional - A function that accepts the object (self) and returns its params (NOT A COPY OF THEM!). - This will only be used in case this method is used outside the class, where objects - have a different structure (e.g. QueriesInput inputField) or if there is some nested params - field that the class is not aware of (although this second case is probably not advisable). - - Returns - ------- - param: dict or InputField - The parameter in the form specified by as_dict. 
- """ - for param in self.params if not params_extractor else params_extractor(self): - if param.key == key: - return param.__dict__ if as_dict else param - else: - raise KeyError(f"There is no parameter '{key}' in {self.__class__.__name__}") - - @classmethod - def get_class_param(cls, key, as_dict=False): - try: - return cls.get_param(cls, key, as_dict=as_dict, params_extractor=lambda cls: cls._parameters) - except KeyError: - raise KeyError(f"There is no parameter '{key}' in {cls.__name__}") - - def modify_param(self, key, *args, **kwargs): - """ Modifies a given parameter - - See *args to know how can it be used. - - This is a general schema of how an input field parameter looks internally, so that you - can know what do you want to change: - - (Note that it is very easy to modify nested values, more on this in *args explanation) - - { - "key": whatever, - "name": whatever, - "default": whatever, - . - . (keys that affect, let's say, the programmatic functionality of the parameter, - . they can be modified with Configurable.modify_param) - . - "inputField": { - "type": whatever, - "width": whatever, (keys that affect the inputField control that is displayed - "params": { they can be modified with Configurable.modifyInputField) - whatever - }, - "style": { - whatever - } - - } - } - - Arguments - -------- - key: str - The key of the parameter to be modified - *args: - Depending on what you pass the setting will be modified in different ways: - - Two arguments: - the first argument will be interpreted as the attribute that you want to change, - and the second one as the value that you want to set. - - Ex: obj.modify_param("length", "default", 3) - will set the default attribute of the parameter with key "length" to 3 - - Modifying nested keys is possible using dot notation. - - Ex: obj.modify_param("length", "inputField.width", 3) - will modify the width key inside inputField on the schema above. 
- - The last key, but only the last one, will be created if it does not exist. - - Ex: obj.modify_param("length", "inputField.width.inWinter.duringDay", 3) - will only work if all the path before duringDay exists and the value of inWinter is a dictionary. - - Otherwise you could go like this: obj.modify_param("length", "inputField.width.inWinter", {"duringDay": 3}) - - - One argument and it is a dictionary: - the keys will be interpreted as attributes that you want to change and the values - as the value that you want them to have. - - Each key-value pair in the dictionary will be updated in exactly the same way as - it is in the previous case. - - - One argument and it is a function: - - the function will recieve the parameter and can act on it in any way you like. - It doesn't need to return the parameter, just modify it. - In this function, you can call predefined methods of the parameter, for example. - - Ex: obj.modify_param("length", lambda param: param.incrementByOne() ) - - given that you know that this type of parameter has this method. - **kwargs: optional - They are passed directly to the Configurable.get_param method to retrieve the parameter. - - Returns - -------- - self: - The configurable object. - """ - self.get_param(key, as_dict = False, **kwargs).modify(*args) - - return self - - def get_setting(self, key, copy=True, parse=True): - """ Gets the value for a given setting - - Parameters - ------------ - key: str - The key of the setting we want to get - copy: boolean, optional - Whether you want a copy of the object or the actual object - parse: boolean, optional - whether the setting should be parsed before returning it. 
- """ - # Get the value of the setting and parse it using the parse method - # defined for the parameter - val = self.get_param(key).parse(self.settings[key]) - - return deepcopy(val) if copy else val - - def get_settings_group(self, group, steps_back=0): - """ Gets the subset of the settings that corresponds to a given group - - Arguments - --------- - group: str - The key of the settings group that we desire. - steps_back: optional, int - If you don't want the actual settings, but some point of the settings history, - use this argument to state how many steps back you want the settings' values. - - Returns - ------- - settings_group: dict - A subset of the settings with only those that belong to the asked group. - """ - if steps_back: - settings = self.settings_history[-steps_back] - else: - settings = self.settings - - return deepcopy({setting.key: settings[setting.key] for setting in self.params if getattr(setting, "group", None) == group}) - - def has_these_settings(self, settings={}, **kwargs): - """ Checks if the object settings match the provided settings - - Parameters - ---------- - settings: dict - dictionary containing the settings keys and values - **kwargs: - setting keys and values can also be passed as keyword arguments. - - You can use settings and **kwargs at the same time, they will be merged. - """ - settings = {**settings, **kwargs} - - for key, val in settings.items(): - if self.get_setting(key) != val: - return False - else: - return True - - -# DECORATOR TO USE WHEN DEFINING METHODS IN CLASSES THAT INHERIT FROM Configurable - -def vizplotly_settings(when='before', init=False): - """ Specifies how settings should be updated when running a method - - It can only decorate a method of a class that inherits from Configurable. - - Works by grabbing the kwargs from the method and taking the ones whose keys - represent settings. - - Parameters - ---------- - when: {'after', 'before'} - specifies when should the settings be updated. 
- - 'after': After the method has been ran. - 'before': Before running the method. - - init: boolean, optional - whether the settings should be initialized (restored). - - If `False`, the settings are just updated. - """ - extra_kwargs = {} - if init: - method_name = 'init_settings' - else: - method_name = '_update_settings' - extra_kwargs = {'from_decorator': True, 'run_updates': True} - - def decorator(method): - if when == 'before': - @wraps(method) - def func(obj, *args, **kwargs): - getattr(obj, method_name)(**kwargs, **extra_kwargs) - return method(obj, *args, **kwargs) - - elif when == 'after': - @wraps(method) - def func(obj, *args, **kwargs): - ret = method(obj, *args, **kwargs) - getattr(obj, method_name)(**kwargs, **extra_kwargs) - return ret - else: - raise ValueError("Incorrect decorator usage") - return func - return decorator - - -def _populate_with_settings(f, class_params): - """ Makes functions of a Configurable object able to access settings through arguments - - Parameters - ----------- - f: function - the function that you want to give this functionality - class_params: array-like of str - the keys of the parameters that this function will be able to access. Presumably these - are the keys of the parameters of the class where the function is defined. - - Returns - ------------ - function - in case the function has some arguments named like parameters that are available to it, - this will be a wrapped function that defaults the values of those arguments to the values - of the settings. - - Otherwise, it returns the same function. 
- - Examples - ----------- - - >>> class MyPlot(Configurable): - >>> _parameters = (TextInput(key="my_param", name=...)) - >>> - >>> def some_method(self, my_param): - >>> return my_param - - After `some_method` has been correctly passed through `_populate_with_settings`: - >>> plot = MyPlot(my_param=3) - >>> plot.some_method() # Returns 3 - >>> plot.some_method(5) # Returns 5 - >>> plot.some_method() # Returns 3 - """ - try: - # note that params takes `self` as argument - # So first actual argument has index 1 - params = inspect.signature(f).parameters - # Also, there is no need to use numpy if not needed - # In this case it was just an overhead. - idx_params = tuple(filter(lambda i_p: i_p[1] in class_params, - enumerate(params))) - except Exception: - return f - - if len(idx_params) == 0: - # no need to wrap it - return f - - # Tuples are immutable, so they should have a *slightly* lower overhead. - # Also, get rid of zip below - # The below gets called alot, I suspect. - # So it should probably be *fast* :) - f._settings = idx_params - - @wraps(f) - def f_default_setting_args(self, *args, **kwargs): - nargs = len(args) - for i, param in f._settings: - # nargs does not count `self` and then the above indices will fine - if i > nargs and param not in kwargs: - try: - kwargs[param] = self.get_setting(param, copy=False) - except KeyError: - pass - - return f(self, *args, **kwargs) - - return f_default_setting_args diff --git a/src/sisl/viz/data/__init__.py b/src/sisl/viz/data/__init__.py new file mode 100644 index 0000000000..38d64da2b2 --- /dev/null +++ b/src/sisl/viz/data/__init__.py @@ -0,0 +1,6 @@ +from .bands import BandsData +from .data import Data +from .eigenstate import EigenstateData +from .pdos import PDOSData +from .sisl_objs import GeometryData, GridData, HamiltonianData +from .xarray import XarrayData diff --git a/src/sisl/viz/data/bands.py b/src/sisl/viz/data/bands.py new file mode 100644 index 0000000000..ad261ffb48 --- /dev/null +++ 
b/src/sisl/viz/data/bands.py @@ -0,0 +1,619 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at https://mozilla.org/MPL/2.0/. +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Dict, Optional, Sequence, Union + +import numpy as np +import xarray as xr + +import sisl +from sisl.io import bandsSileSiesta, fdfSileSiesta, wfsxSileSiesta +from sisl.physics.brillouinzone import BrillouinZone +from sisl.physics.spin import Spin + +from .._single_dispatch import singledispatchmethod +from ..data_sources import FileDataSIESTA, HamiltonianDataSource +from .xarray import XarrayData + +try: + import pathos + _do_parallel_calc = True +except: + _do_parallel_calc = False + +try: + from aiida import orm + Aiida_node = orm.Node + AIIDA_AVAILABLE = True +except ModuleNotFoundError: + class Aiida_node: pass + AIIDA_AVAILABLE = False + +class BandsData(XarrayData): + + def sanity_check(self, + n_spin: Optional[int] = None, nk: Optional[int] = None, nbands: Optional[int] = None, + klabels: Optional[Sequence[str]] = None, kvals: Optional[Sequence[float]] = None + ): + """Check that the dataarray satisfies the requirements to be treated as PDOSData.""" + super().sanity_check() + + array = self._data + + for k in ("k", "band"): + assert k in array.dims, f"'{k}' dimension missing, existing dimensions: {array.dims}" + + spin = array.attrs['spin'] + assert isinstance(spin, Spin) + + if n_spin is not None: + if n_spin == 1: + assert spin.is_unpolarized, f"Spin in the data is {spin}, but n_spin=1 was expected" + elif n_spin == 2: + assert spin.is_polarized, f"Spin in the data is {spin}, but n_spin=2 was expected" + elif n_spin == 4: + assert not spin.is_diagonal, f"Spin in the data is {spin}, but n_spin=4 was expected" + + # Check if we have the correct number of spin channels + if spin.is_polarized: + assert 
"spin" in array.dims, f"'spin' dimension missing for polarized spin, existing dimensions: {array.dims}" + if n_spin is not None: + assert len(array.spin) == n_spin + else: + assert "spin" not in array.dims, f"'spin' dimension present for spin different than polarized, existing dimensions: {array.dims}" + assert "spin" not in array.coords, f"'spin' coordinate present for spin different than polarized, existing dimensions: {array.dims}" + + # Check shape of bands + if nk is not None: + assert len(array.k) == nk + if nbands is not None: + if not spin.is_diagonal: + assert len(array.band) == nbands * 2 + else: + assert len(array.band) == nbands + + # Check if k ticks match the expected ones + if klabels is not None: + assert "axis" in array.k.attrs, "No axis specification for the k dimension." + assert "ticktext" in array.k.attrs['axis'], "No ticks were found for the k dimension" + assert tuple(array.k.attrs['axis']['ticktext']) == tuple(klabels), f"Expected labels {klabels} but found {array.k.attrs['axis']['ticktext']}" + if kvals is not None: + assert "axis" in array.k.attrs, "No axis specification for the k dimension." 
+ assert "tickvals" in array.k.attrs['axis'], "No ticks were found for the k dimension" + assert np.allclose(array.k.attrs['axis']['tickvals'], kvals), f"Expected label values {kvals} but found {array.k.attrs['axis']['tickvals']}" + + @classmethod + def toy_example(cls, spin: Union[str, int, Spin] = "", n_states: int = 20, nk: int = 30, gap: Optional[float] = None): + """Creates a toy example of a bands data array""" + + spin = Spin(spin) + + n_bands = n_states if spin.is_diagonal else n_states * 2 + + if spin.is_polarized: + polynoms_shape = (2, n_bands) + dims = ("spin", "k", "band") + shift = np.tile(np.arange(0, n_bands), 2).reshape(2, -1) + else: + polynoms_shape = (n_bands, ) + dims = ("k", "band") + shift = np.arange(0, n_bands) + + # Create some random coefficients for degree 2 polynomials that will be used to generate the bands + random_polinomials = np.random.rand(*polynoms_shape, 3) + random_polinomials[..., 0] *= 10 # Bigger curvature + random_polinomials[..., :n_bands // 2, 0] *= -1 # Make the curvature negative below the gap + random_polinomials[..., 2] += shift # Shift each polynomial so that bands stack on top of each other + + # Compute bands + x = np.linspace(0, 1, nk) + y = np.outer(x ** 2, random_polinomials[..., 0]) + np.outer(x, random_polinomials[..., 1]) + random_polinomials[..., 2].ravel() + + y = y.reshape(nk, *polynoms_shape) + + if spin.is_polarized: + # Make sure that the top of the valence band and bottom of the conduction band + # are the same spin (to facilitate computation of the gap). 
+ VB_spin = y[..., :n_bands // 2].argmin() // (nk * n_bands) + CB_spin = y[..., n_bands // 2:].argmax() // (nk * n_bands) + + if VB_spin != CB_spin: + y[..., n_bands // 2:] = np.flip(y[..., n_bands // 2:], axis=0) + + y = y.transpose(1, 0, 2) + + # Compute gap limits + top_VB = y[..., :n_bands // 2 ].max() + bottom_CB = y[..., n_bands // 2:].min() + + # Correct the gap if some specific value was requested + generated_gap = bottom_CB - top_VB + if gap is not None: + add_shift = (gap - generated_gap) + y[..., n_bands // 2:] += add_shift + bottom_CB += add_shift + + # Compute fermi level + fermi = (top_VB + bottom_CB) / 2 + + # Create the dataarray + data = xr.DataArray( + y - fermi, + coords={ + "k": x, + "band": np.arange(0, n_bands), + }, + dims=dims, + ) + + data = xr.Dataset({"E": data}) + + # Add spin moments if the spin is not diagonal + if not spin.is_diagonal: + spin_moments = np.random.rand(nk, n_bands, 3) * 2 - 1 + data['spin_moments'] = xr.DataArray( + spin_moments, + coords={ + "k": x, + "band": np.arange(0, n_bands), + "axis": ["x", "y", "z"] + }, + dims=("k", "band", "axis") + ) + + # Add the spin class of the data + data.attrs['spin'] = spin + + # Inform of where to place the ticks + data.k.attrs["axis"] = { + "tickvals": [0, x[-1]], + "ticktext": ["Gamma", "X"], + } + + return cls.new(data) + + @singledispatchmethod + @classmethod + def new(cls, bands_data): + return cls(bands_data) + + @new.register + @classmethod + def from_dataset(cls, bands_data: xr.Dataset): + + old_attrs = bands_data.attrs + + # Check if there's a spin attribute + spin = old_attrs.get("spin", None) + + # If not, guess it + if spin is None: + if 'spin' not in bands_data: + spin = Spin(Spin.UNPOLARIZED) + else: + spin = { + 1: Spin.UNPOLARIZED, + 2: Spin.POLARIZED, + 4: Spin.NONCOLINEAR, + }[bands_data.spin.shape[0]] + + spin = Spin(spin) + + # Remove the spin coordinate if the data is not spin polarized + if 'spin' in bands_data and not spin.is_polarized: + bands_data = 
bands_data.isel(spin=0).drop_vars("spin") + + if spin.is_polarized: + spin_options = [0, 1] + bands_data['spin'] = ('spin', spin_options, bands_data.spin.attrs) + # elif not spin.is_diagonal: + # spin_options = get_spin_options(spin) + # bands_data['spin'] = ('spin', spin_options, bands_data.spin.attrs) + + # If the energy variable doesn't have units, set them as eV + if 'E' in bands_data and 'units' not in bands_data.E.attrs: + bands_data.E.attrs['units'] = 'eV' + # Same with the k coordinate, which we will assume are 1/Ang + if 'k' in bands_data and 'units' not in bands_data.k.attrs: + bands_data.k.attrs['units'] = '1/Ang' + # If there are ticks, show the grid. + if 'axis' in bands_data.k.attrs and bands_data.k.attrs['axis'].get('ticktext') is not None: + bands_data.k.attrs['axis'] = {"showgrid": True, **bands_data.k.attrs.get('axis', {})} + + bands_data.attrs = { + **old_attrs, "spin": spin + } + + if "geometry" not in bands_data.attrs: + if "parent" in bands_data.attrs: + parent = bands_data.attrs["parent"] + if hasattr(parent, "geometry"): + bands_data.attrs['geometry'] = parent.geometry + + return cls(bands_data) + + @new.register + @classmethod + def from_dataarray(cls, bands_data: xr.DataArray): + bands_data_ds = xr.Dataset({"E": bands_data}) + bands_data_ds.attrs.update(bands_data.attrs) + + return cls.new(bands_data_ds) + + @new.register + @classmethod + def from_path(cls, path: Path, *args, **kwargs): + """Creates a sile from the path and tries to read the PDOS from it.""" + return cls.new(sisl.get_sile(path), *args, **kwargs) + + @new.register + @classmethod + def from_string(cls, string: str, *args, **kwargs): + """Assumes the string is a path to a file""" + return cls.new(Path(string), *args, **kwargs) + + + @new.register + @classmethod + def from_fdf(cls, fdf: fdfSileSiesta, bands_file: Union[str, bandsSileSiesta, None] = None): + """Gets the bands data from a SIESTA .bands file""" + bands_file = FileDataSIESTA(fdf=fdf, path=bands_file, 
cls=sisl.io.bandsSileSiesta) + + assert isinstance(bands_file, bandsSileSiesta) + + return cls.new(bands_file) + + @new.register + @classmethod + def from_siesta_bands(cls, bands_file: bandsSileSiesta): + """Gets the bands data from a SIESTA .bands file""" + + bands_data = bands_file.read_data(as_dataarray=True) + bands_data.k.attrs['axis'] = { + 'tickvals': bands_data.attrs.pop('ticks'), + 'ticktext': bands_data.attrs.pop('ticklabels') + } + + return cls.new(bands_data) + + @new.register + @classmethod + def from_hamiltonian(cls, + bz: sisl.BrillouinZone, + H: Union[sisl.Hamiltonian, None] = None, + extra_vars: Sequence[Union[Dict, str]] = () + ): + """Uses a sisl's `BrillouinZone` object to calculate the bands.""" + if bz is None: + raise ValueError("No band structure (k points path) was provided") + + if not isinstance(getattr(bz, "parent", None), sisl.Hamiltonian): + H = HamiltonianDataSource(H=H) + bz.set_parent(H) + else: + H = bz.parent + + # Define the spin class of this calculation. + spin = H.spin + + if isinstance(bz, sisl.BandStructure): + ticks = bz.lineartick() + kticks = bz.lineark() + else: + ticks = (None, None) + kticks = np.arange(0, len(bz)) + + # Get the wrapper function that we should call on each eigenstate. + # This also returns the coordinates and names to build the final dataset. 
+ bands_wrapper, all_vars, coords_values = _get_eigenstate_wrapper( + kticks, spin, extra_vars=extra_vars + ) + + # Get a dataset with all values for all spin indices + spin_datasets = [] + coords = [var['coords'] for var in all_vars] + name = [var['name'] for var in all_vars] + for spin_index in coords_values['spin']: + + # Non collinear routines don't accept the keyword argument "spin" + spin_kwarg = {"spin": spin_index} + if not spin.is_diagonal: + spin_kwarg = {} + + with bz.apply(pool=_do_parallel_calc, zip=True) as parallel: + spin_bands = parallel.dataarray.eigenstate( + wrap=partial(bands_wrapper, spin_index=spin_index), + **spin_kwarg, + coords=coords, name=name, + ) + + spin_datasets.append(spin_bands) + + # Merge everything into a single dataset with a spin dimension + bands_data = xr.concat(spin_datasets, "spin").assign_coords(coords_values) + + # If the band structure contains discontinuities, we will copy the dataset + # adding the discontinuities. + if isinstance(bz, sisl.BandStructure) and len(bz._jump_idx) > 0: + + old_coords = bands_data.coords + coords = { + name: bz.insert_jump(old_coords[name]) if name == "k" else old_coords[name].values + for name in old_coords + } + + def _add_jump(array): + if "k" in array.coords: + array = array.transpose("k", ...) + return (array.dims, bz.insert_jump(array)) + else: + return array + + bands_data = xr.Dataset( + {name: _add_jump(bands_data[name]) for name in bands_data}, + coords=coords + ) + + # Add the spin class of the data + bands_data.attrs['spin'] = spin + + # Inform of where to place the ticks + bands_data.k.attrs["axis"] = { + "tickvals": ticks[0], + "ticktext": ticks[1], + } + + return cls.new(bands_data) + + @new.register + @classmethod + def from_wfsx(cls, wfsx_file: wfsxSileSiesta, fdf: str, extra_vars=(), need_H=False): + """Plots bands from the eigenvalues contained in a WFSX file. + + It also needs to get a geometry. 
+ """ + if need_H: + H = HamiltonianDataSource(H=fdf) + if H is None: + raise ValueError("Hamiltonian was not setup, and it is needed for the calculations") + parent = H + geometry = parent.geometry + else: + # Get the fdf sile + fdf = FileDataSIESTA(path=fdf) + # Read the geometry from the fdf sile + geometry = fdf.read_geometry(output=True) + parent = geometry + + # Get the wfsx file + wfsx_sile = FileDataSIESTA(fdf=fdf, path=wfsx_file, cls=sisl.io.wfsxSileSiesta, parent=parent) + + # Now read all the information of the k points from the WFSX file + k, weights, nwfs = wfsx_sile.read_info() + # Get the number of wavefunctions in the file while performing a quick check + nwf = np.unique(nwfs) + if len(nwf) > 1: + raise ValueError(f"File {wfsx_sile.file} contains different number of wavefunctions in some k points") + nwf = nwf[0] + # From the k values read in the file, build a brillouin zone object. + # We will use it just to get the linear k values for plotting. + bz = BrillouinZone(geometry, k=k, weight=weights) + + # Read the sizes of the file, which contain the number of spin channels + # and the number of orbitals and the number of k points. + nspin, nou, nk, _ = wfsx_sile.read_sizes() + + # Find out the spin class of the calculation. + spin = Spin({ + 1: Spin.UNPOLARIZED, 2: Spin.POLARIZED, + 4: Spin.NONCOLINEAR, 8: Spin.SPINORBIT + }[nspin]) + # Now find out how many spin channels we need. Note that if there is only + # one spin channel there will be no "spin" dimension on the final dataset. + nspin = 2 if spin.is_polarized else 1 + + # Determine whether spin moments will be calculated. 
+ spin_moments = False + if not spin.is_diagonal: + # We need to set the parent + try: + H = sisl.get_sile(fdf).read_hamiltonian() + if H is not None: + # We could read a hamiltonian, set it as the parent of the wfsx sile + wfsx_sile = FileDataSIESTA(path=wfsx_sile.file, kwargs=dict(parent=parent)) + spin_moments = True + except: + pass + + # Get the wrapper function that we should call on each eigenstate. + # This also returns the coordinates and names to build the final dataset. + bands_wrapper, all_vars, coords_values = _get_eigenstate_wrapper( + sisl.physics.linspace_bz(bz), extra_vars=extra_vars, + spin_moments=spin_moments, spin=spin + ) + # Make sure all coordinates have values so that we can assume the shape + # of arrays below. + coords_values['band'] = np.arange(0, nwf) + coords_values['orb'] = np.arange(0, nou) + + # Initialize all the arrays. For each quantity we will initialize + # an array of the needed shape. + arrays = {} + for var in all_vars: + # These are all the extra dimensions of the quantity. Note that a + # quantity does not need to have extra dimensions. + extra_shape = [len(coords_values[coord]) for coord in var['coords']] + # First two dimensions will always be the spin channel and the k index. + # Then add potential extra dimensions. + shape = (nspin, len(bz), *extra_shape) + # Initialize the array. + arrays[var['name']] = np.empty(shape, dtype=var.get('dtype', np.float64)) + + # Loop through eigenstates in the WFSX file and add their contribution to the bands + ik = -1 + for eigenstate in wfsx_sile.yield_eigenstate(): + i_spin = eigenstate.info.get("spin", 0) + # Every time we encounter spin 0, we are in a new k point. + if i_spin == 0: + ik +=1 + if ik == 0: + # If this is the first eigenstate we read, get the wavefunction + # indices. We will assume that ALL EIGENSTATES have the same indices. + # Note that we already checked previously that they all have the same + # number of wfs, so this is a fair assumption. 
+ coords_values['band'] = eigenstate.info['index'] + + # Get all the values for this eigenstate. + returns = bands_wrapper(eigenstate, spin_index=i_spin) + # And store them in the respective arrays. + for var, vals in zip(all_vars, returns): + arrays[var['name']][i_spin, ik] = vals + + # Now that we have all the values, just build the dataset. + bands_data = xr.Dataset( + data_vars={ + var['name']: (("spin", "k", *var['coords']), arrays[var['name']]) + for var in all_vars + } + ).assign_coords(coords_values) + + bands_data.attrs = {"parent": bz, "spin": spin, "geometry": geometry} + + return cls.new(bands_data) + + @new.register + @classmethod + def from_aiida(cls, aiida_bands: Aiida_node): + """ + Creates the bands plot reading from an aiida BandsData node. + """ + plot_data = aiida_bands._get_bandplot_data(cartesian=True) + bands = plot_data["y"] + + # Expand the bands array to have an extra dimension for spin + if bands.ndim == 2: + bands = np.expand_dims(bands, 0) + + # Get the info about where to put the labels + tick_info = defaultdict(list) + for tick, label in plot_data["labels"]: + tick_info["tickvals"].append(tick) + tick_info["ticktext"].append(label) + + # Construct the dataarray + data = xr.DataArray( + bands, + coords={ + "spin": np.arange(0, bands.shape[0]), + "k": ('k', plot_data["x"], {"axis": tick_info}), + "band": np.arange(0, bands.shape[2]), + }, + dims=("spin", "k", "band"), + ) + + return cls.new(data) + +def _get_eigenstate_wrapper(k_vals, spin, extra_vars: Sequence[Union[Dict, str]] = (), spin_moments: bool = True): + """Helper function to build the function to call on each eigenstate. + + Parameters + ---------- + k_vals: array_like of shape (nk,) + The (linear) values of the k points. This will be used for plotting + the bands. + extra_vars: array-like of dict, optional + This argument determines the extra quantities that should be included + in the final dataset of the bands. 
Energy and spin moments (if available) + are already included, so no need to pass them here. + Each item of the array defines a new quantity and should contain a dictionary + with the following keys: + - 'name', str: The name of the quantity. + - 'getter', callable: A function that gets 3 arguments: eigenstate, plot and + spin index, and returns the values of the quantity in a numpy array. This + function will be called for each eigenstate object separately. That is, once + for each (k-point, spin) combination. + - 'coords', tuple of str: The names of the dimensions of the returned array. + The number of coordinates should match the number of dimensions. + of + - 'coords_values', dict: If this variable introduces a new coordinate, you should + pass the values for that coordinate here. If the coordinates were already defined + by another variable, they will already have values. If you are unsure that the + coordinates are new, just pass the values for them, they will get overwritten. + spin_moments: bool, optional + Whether to add, if the spin is not diagonal, spin moments. + + Returns + -------- + function: + The function that should be called for each eigenstate and will return a tuple of size + n_vars with the values for each variable. + tuple of dicts: + A tuple containing the dictionaries that define all variables. Exactly the same as + the passed `extra_vars`, but with the added Energy and spin moment (if available) variables. + dict: + Dictionary containing the values for each coordinate involved in the dataset. + """ + # In case it is a non_colinear or spin-orbit calculation we will get the spin moments + if spin_moments and not spin.is_diagonal: + extra_vars = ("spin_moment", *extra_vars) + + # Define the available spin indices. Notice that at the end the spin dimension + # is removed from the dataset unless the calculation is spin polarized. So having + # spin_indices = [0] is just for convenience. 
+ spin_indices = [0] + if spin.is_polarized: + spin_indices = [0, 1] + + # Add a variable to get the eigenvalues. + all_vars = ({ + "coords": ("band",), "coords_values": {"spin": spin_indices, "k": k_vals}, + "name": "E", "getter": lambda eigenstate, spin, spin_index: eigenstate.eig}, + *extra_vars + ) + + # Convert known variable keys to actual variables. + all_vars = tuple( + _KNOWN_EIGENSTATE_VARS[var] if isinstance(var, str) else var for var in all_vars + ) + + # Now build the function that will be called for each eigenstate and will + # return the values for each variable. + def bands_wrapper(eigenstate, spin_index): + return tuple(var["getter"](eigenstate, spin, spin_index) for var in all_vars) + + # Finally get the values for all coordinates involved. + coords_values = {} + for var in all_vars: + coords_values.update(var.get("coords_values", {})) + + return bands_wrapper, all_vars, coords_values + +def _norm2_from_eigenstate(eigenstate, spin, spin_index): + + norm2 = eigenstate.norm2(sum=False) + + if not spin.is_diagonal: + # If it is a non-colinear or spin orbit calculation, we have two weights for each + # orbital (one for each spin component of the state), so we just pair them together + # and sum their contributions to get the weight of the orbital. 
+ norm2 = norm2.reshape(len(norm2), -1, 2).sum(2) + + return norm2.real + +def _spin_moment_getter(eigenstate, spin, spin_index): + return eigenstate.spin_moment().real + +_KNOWN_EIGENSTATE_VARS = { + "norm2": { + "coords": ("band", "orb"), + "name": "norm2", + "getter": _norm2_from_eigenstate + }, + "spin_moment": { + "coords": ("axis", "band"), "coords_values": dict(axis=["x", "y", "z"]), + "name": "spin_moments", "getter": _spin_moment_getter + } +} diff --git a/src/sisl/viz/data/data.py b/src/sisl/viz/data/data.py new file mode 100644 index 0000000000..ddd9d932c2 --- /dev/null +++ b/src/sisl/viz/data/data.py @@ -0,0 +1,39 @@ +from typing import Any, Union, get_args, get_origin, get_type_hints + + +class Data: + """Base data class""" + + _data: Any + + def __init__(self, data): + if isinstance(data, self.__class__): + data = data._data + + self._data = data + + def sanity_check(self): + + def is_valid(data, expected_type) -> bool: + if expected_type is Any: + return True + + return isinstance(data, expected_type) + + expected_type = get_type_hints(self.__class__)['_data'] + if get_origin(expected_type) is Union: + + valid = False + for valid_type in get_args(expected_type): + valid = valid | is_valid(self._data, valid_type) + + else: + valid = is_valid(self._data, expected_type) + + assert valid, f"Data must be of type {expected_type} but is {type(self._data).__name__}" + + def __getattr__(self, key): + return getattr(self._data, key) + + def __dir__(self): + return dir(self._data) diff --git a/src/sisl/viz/data/eigenstate.py b/src/sisl/viz/data/eigenstate.py new file mode 100644 index 0000000000..d53e2a4a4d --- /dev/null +++ b/src/sisl/viz/data/eigenstate.py @@ -0,0 +1,84 @@ +from pathlib import Path +from typing import Literal, Tuple + +import sisl +from sisl.io import fdfSileSiesta, wfsxSileSiesta + +from .._single_dispatch import singledispatchmethod +from ..data_sources import FileDataSIESTA +from .data import Data + + +class EigenstateData(Data): + 
"""Wavefunction data class""" + + @singledispatchmethod + @classmethod + def new(cls, data): + return cls(data) + + @new.register + @classmethod + def from_eigenstate(cls, eigenstate: sisl.EigenstateElectron): + return cls(eigenstate) + + @new.register + @classmethod + def from_path(cls, path: Path, *args, **kwargs): + """Creates a sile from the path and tries to read the PDOS from it.""" + return cls.new(sisl.get_sile(path), *args, **kwargs) + + @new.register + @classmethod + def from_string(cls, string: str, *args, **kwargs): + """Assumes the string is a path to a file""" + return cls.new(Path(string), *args, **kwargs) + + @new.register + @classmethod + def from_fdf(cls, + fdf: fdfSileSiesta, source: Literal["wfsx", "hamiltonian"] = "wfsx", + k: Tuple[float, float, float] = (0, 0, 0), spin: int = 0, + ): + if source == "wfsx": + sile = FileDataSIESTA(fdf=fdf, cls=wfsxSileSiesta) + + assert isinstance(sile, wfsxSileSiesta) + + geometry = fdf.read_geometry(output=True) + + return cls.new(sile, geometry=geometry, k=k, spin=spin) + elif source == "hamiltonian": + H = fdf.read_hamiltonian() + + return cls.new(H, k=k, spin=spin) + + @new.register + @classmethod + def from_siesta_wfsx(cls, wfsx_file: wfsxSileSiesta, geometry: sisl.Geometry, k: Tuple[float, float, float] = (0, 0, 0), spin: int = 0): + """Reads the wavefunction coefficients from a SIESTA WFSX file""" + # Get the WFSX file. If not provided, it is inferred from the fdf. + if not wfsx_file.file.is_file(): + raise ValueError(f"File '{wfsx_file.file}' does not exist.") + + sizes = wfsx_file.read_sizes() + H = sisl.Hamiltonian(geometry, dim=sizes.nspin) + + wfsx = sisl.get_sile(wfsx_file.file, parent=H) + + # Try to find the eigenstate that we need + eigenstate = wfsx.read_eigenstate(k=k, spin=spin) + if eigenstate is None: + # We have not found it. 
+ raise ValueError(f"A state with k={k} was not found in file {wfsx.file}.") + + return cls.new(eigenstate) + + @new.register + @classmethod + def from_hamiltonian(cls, H: sisl.Hamiltonian, k: Tuple[float, float, float] = (0, 0, 0), spin: int = 0): + """Calculates the eigenstates from a Hamiltonian and then generates the wavefunctions.""" + return cls.new(H.eigenstate(k, spin=spin)) + + def __getitem__(self, key): + return self._data[key] \ No newline at end of file diff --git a/src/sisl/viz/data/pdos.py b/src/sisl/viz/data/pdos.py new file mode 100644 index 0000000000..7971cf6902 --- /dev/null +++ b/src/sisl/viz/data/pdos.py @@ -0,0 +1,356 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at https://mozilla.org/MPL/2.0/. +from pathlib import Path +from typing import Literal, Optional, Sequence, Union + +import numpy as np +from xarray import DataArray + +import sisl +from sisl.geometry import Geometry +from sisl.io import fdfSileSiesta, pdosSileSiesta, tbtncSileTBtrans, wfsxSileSiesta +from sisl.physics import Hamiltonian, Spin +from sisl.physics.distribution import get_distribution + +from .._single_dispatch import singledispatchmethod +from ..data_sources import FileDataSIESTA +from ..processors.spin import get_spin_options +from .xarray import OrbitalData + +try: + import pathos + _do_parallel_calc = True +except: + _do_parallel_calc = False + +class PDOSData(OrbitalData): + """Holds PDOS Data in a custom xarray DataArray. + + The point of this class is to normalize the data coming from different sources + so that functions can use it without worrying where the data came from. 
+ """ + + def sanity_check(self, + na: Optional[int] = None, no: Optional[int] = None, n_spin: Optional[int] = None, + atom_tags: Optional[Sequence[str]] = None, + dos_checksum: Optional[float] = None + ): + """Check that the dataarray satisfies the requirements to be treated as PDOSData.""" + super().sanity_check() + + array = self._data + geometry = array.attrs["geometry"] + assert isinstance(geometry, Geometry) + + if na is not None: + assert geometry.na == na + if no is not None: + assert geometry.no == no + if atom_tags is not None: + assert len(set(atom_tags) - set([atom.tag for atom in geometry.atoms.atom])) == 0 + + for k in ("spin", "orb", "E"): + assert k in array.dims, f"'{k}' dimension missing, existing dimensions: {array.dims}" + + # Check if we have the correct number of spin channels + if n_spin is not None: + assert len(array.spin) == n_spin + # Check if we have the correct number of orbitals + assert len(array.orb) == geometry.no + + # Check if the checksum of the DOS is correct + if dos_checksum is not None: + this_dos_checksum = float(array.sum()) + assert np.allclose(this_dos_checksum, dos_checksum), f"Checksum of the DOS is incorrect. 
Expected {dos_checksum} but got {this_dos_checksum}" + + @classmethod + def toy_example(cls, geometry: Optional[Geometry] = None, spin: Union[str, int, Spin] = "", nE: int = 100): + """Creates a toy example of a bands data array""" + + if geometry is None: + orbitals = [ + sisl.AtomicOrbital("2s"), + sisl.AtomicOrbital("2px"), + sisl.AtomicOrbital("2py"), + sisl.AtomicOrbital("2pz") + ] + + geometry = sisl.geom.graphene(atoms=sisl.Atom(Z=6, orbitals=orbitals)) + + PDOS = np.random.rand(geometry.no, nE) + + spin = Spin(spin) + + if spin.is_polarized: + PDOS = np.array([PDOS / 2, PDOS / 2]) + elif not spin.is_diagonal: + PDOS = np.array([PDOS, PDOS, np.zeros_like(PDOS), np.zeros_like(PDOS)]) + + return cls.new(PDOS, geometry, np.arange(nE), spin=spin) + + @singledispatchmethod + @classmethod + def new(cls, data: DataArray) -> "PDOSData": + return cls(data) + + @new.register + @classmethod + def from_numpy(cls, + PDOS: np.ndarray, geometry: Geometry, E: Sequence[float], E_units: str = 'eV', + spin: Optional[Union[sisl.Spin, str, int]] = None, extra_attrs: dict = {} + ): + """ + + Parameters + ---------- + PDOS: numpy.ndarray of shape ([nSpin], nE, nOrb) + The Projected Density Of States, orbital resolved. The array can have 2 or 3 dimensions, + since the spin dimension is optional. The spin class of the calculation that produced the data + is inferred from the spin dimension: + If there is no spin dimension or nSpin == 1, the calculation is spin unpolarized. + If nSpin == 2, the calculation is spin polarized. It is expected that [total, z] is + provided, not [spin0, spin1]. + If nSpin == 4, the calculation is assumed to be with noncolinear spin. + geometry: sisl.Geometry + The geometry to which the data corresponds. It must have as many orbitals as the PDOS data. + E: numpy.ndarray of shape (nE,) + The energies to which the data corresponds. + E_units: str, optional + The units of the energy. Defaults to 'eV'. 
+ extra_attrs: dict + A dictionary of extra attributes to be added to the DataArray. + """ + # Understand what the spin class is for this data. + data_spin = sisl.Spin.UNPOLARIZED + if PDOS.squeeze().ndim == 3: + data_spin = { + 1: sisl.Spin.UNPOLARIZED, + 2: sisl.Spin.POLARIZED, + 4: sisl.Spin.NONCOLINEAR + }[PDOS.shape[0]] + data_spin = sisl.Spin(data_spin) + + # If no spin specification was passed, then assume the spin is what we inferred from the data. + # Otherwise, make sure the spin specification is consistent with the data. + if spin is None: + spin = data_spin + else: + spin = sisl.Spin(spin) + if data_spin.is_diagonal: + assert spin == data_spin + else: + assert not spin.is_diagonal + + if PDOS.ndim == 2: + # Add an extra axis for spin at the beginning if the array only has dimensions for orbitals and energy. + PDOS = PDOS[None, ...] + + # Check that the number of orbitals in the geometry and the data match. + orb_dim = PDOS.ndim - 2 + if geometry is not None: + if geometry.no != PDOS.shape[orb_dim]: + raise ValueError(f"The geometry provided contains {geometry.no} orbitals, while we have PDOS information of {PDOS.shape[orb_dim]}.") + + # Build the standardized dataarray, with everything needed to understand it.
+ E_units = extra_attrs.pop("E_units", "eV") + + if spin.is_polarized: + spin_coords = ['total', 'z'] + elif not spin.is_diagonal: + spin_coords = get_spin_options(spin) + else: + spin_coords = ["total"] + + coords = [("spin", spin_coords), ("orb", range(PDOS.shape[orb_dim])), ("E", E, {"units": E_units})] + + attrs = {"spin": spin, "geometry": geometry, "units": f"1/{E_units}", **extra_attrs} + + return cls.new(DataArray( + PDOS, + coords=coords, + name="PDOS", + attrs=attrs + )) + + @new.register + @classmethod + def from_path(cls, path: Path, *args, **kwargs): + """Creates a sile from the path and tries to read the PDOS from it.""" + return cls.new(sisl.get_sile(path), *args, **kwargs) + + @new.register + @classmethod + def from_string(cls, string: str, *args, **kwargs): + """Assumes the string is a path to a file""" + return cls.new(Path(string), *args, **kwargs) + + @new.register + @classmethod + def from_fdf(cls, + fdf: fdfSileSiesta, source: Literal["pdos", "tbtnc", "wfsx", "hamiltonian"] = "pdos", + **kwargs + ): + """Gets the PDOS from the fdf file. + + It uses the fdf file as the pivoting point to find the rest of files needed. + + Parameters + ---------- + fdf: fdfSileSiesta + The fdf file to read the PDOS from. + source: Literal["pdos", "tbtnc", "wfsx", "hamiltonian"], optional + The source to read the PDOS data from. + **kwargs + Extra arguments to be passed to the PDOSData constructor, which depends + on the source requested. + + Except for the hamiltonian source, no extra arguments are needed (and they + won't be used). See PDOSData.from_hamiltonian for the extra arguments accepted + by the hamiltonian data constructor. 
+ """ + if source == "pdos": + sile = FileDataSIESTA(fdf=fdf, cls=pdosSileSiesta) + + assert isinstance(sile, pdosSileSiesta) + + return cls.new(sile) + elif source == "tbtnc": + sile = FileDataSIESTA(fdf=fdf, cls=tbtncSileTBtrans) + + assert isinstance(sile, tbtncSileTBtrans) + + geometry = fdf.read_geometry(output=True) + + return cls.new(sile, geometry=geometry) + elif source == "wfsx": + sile = FileDataSIESTA(fdf=fdf, cls=wfsxSileSiesta) + + assert isinstance(sile, wfsxSileSiesta) + + geometry = fdf.read_geometry(output=True) + + return cls.new(sile, geometry=geometry) + elif source == "hamiltonian": + H = fdf.read_hamiltonian() + + return cls.new(H, **kwargs) + + @new.register + @classmethod + def from_siesta_pdos(cls, pdos_file: pdosSileSiesta): + """Gets the PDOS from a SIESTA PDOS file""" + # Get the info from the .PDOS file + geometry, E, PDOS = pdos_file.read_data() + + return cls.new(PDOS, geometry, E) + + @new.register + @classmethod + def from_tbtrans(cls, tbt_nc: tbtncSileTBtrans, geometry: Union[Geometry, None] = None): + """Reads the PDOS from a *.TBT.nc file coming from a TBtrans run.""" + PDOS = tbt_nc.DOS(sum=False).T + E = tbt_nc.E + + read_geometry_kwargs = {} + if geometry is not None: + read_geometry_kwargs["atom"] = geometry.atoms + + # Read the geometry from the TBT.nc file and get only the device part + geometry = tbt_nc.read_geometry(**read_geometry_kwargs).sub(tbt_nc.a_dev) + + return cls.new(PDOS, geometry, E) + + @new.register + @classmethod + def from_hamiltonian(cls, H: Hamiltonian, kgrid=None, kgrid_displ=(0, 0, 0), Erange=(-2, 2), + E0=0, nE=100, distribution=get_distribution("gaussian")): + """Calculates the PDOS from a sisl Hamiltonian.""" + + # Get the kgrid or generate a default grid by checking the interaction between cells + # This should probably take into account how big the cell is. 
+ kgrid = kgrid + if kgrid is None: + kgrid = [3 if nsc > 1 else 1 for nsc in H.geometry.nsc] + + Erange = Erange + if Erange is None: + raise ValueError('You need to provide an energy range to calculate the PDOS from the Hamiltonian') + + E = np.linspace(Erange[0], Erange[-1], nE) + E0 + + bz = sisl.MonkhorstPack(H, kgrid, kgrid_displ) + + # Define the available spins + spin_indices = [0] + if H.spin.is_polarized: + spin_indices = [0, 1] + + # Calculate the PDOS for all available spins + PDOS = [] + for spin in spin_indices: + with bz.apply(pool=_do_parallel_calc) as parallel: + spin_PDOS = parallel.average.eigenstate( + spin=spin, + wrap=lambda eig: eig.PDOS(E, distribution=distribution) + ) + + PDOS.append(spin_PDOS) + + if len(spin_indices) == 1: + PDOS = PDOS[0] + else: + # Convert from spin components to total and z contributions. + total = PDOS[0] + PDOS[1] + z = PDOS[0] - PDOS[1] + + PDOS = np.concatenate([total, z]) + + PDOS = np.array(PDOS) + + return cls.new(PDOS, H.geometry, E, spin=H.spin, extra_attrs={'bz': bz}) + + @new.register + @classmethod + def from_wfsx(cls, + wfsx_file: wfsxSileSiesta, + H: Hamiltonian, geometry: Union[Geometry, None] = None, + Erange=(-2, 2), nE: int = 100, E0: float = 0, distribution=get_distribution('gaussian') + ): + """Generates the PDOS values from a file containing eigenstates.""" + if geometry is None: + geometry = getattr(H, "geometry", None) + + # Get the wfsx file + wfsx_sile = FileDataSIESTA( + path=wfsx_file, cls=sisl.io.wfsxSileSiesta, parent=H + ) + + # Read the sizes of the file, which contain the number of spin channels + # and the number of orbitals and the number of k points. + sizes = wfsx_sile.read_sizes() + # Check that spin sizes of hamiltonian and wfsx file match + assert H.spin.size == sizes.nspin, \ + f"Hamiltonian has spin size {H.spin.size} while file has spin size {sizes.nspin}" + # Get the size of the spin channel. 
The size returned might be 8 if it is a spin-orbit + # calculation, but we need only 4 spin channels (total, x, y and z), same as with non-colinear + nspin = min(4, sizes.nspin) + + # Get the energies for which we need to calculate the PDOS. + Erange = Erange + E = np.linspace(Erange[0], Erange[-1], nE) + E0 + + # Initialize the PDOS array + PDOS = np.zeros((nspin, sizes.no_u, E.shape[0]), dtype=np.float64) + + # Loop through eigenstates in the WFSX file and add their contribution to the PDOS. + # Note that we pass the hamiltonian as the parent here so that the overlap matrix + # for each point can be calculated by eigenstate.PDOS() + for eigenstate in wfsx_sile.yield_eigenstate(): + spin = eigenstate.info.get("spin", 0) + if nspin == 4: + spin = slice(None) + + PDOS[spin] += eigenstate.PDOS(E, distribution=distribution) * eigenstate.info.get("weight", 1) + + return cls.new(PDOS, geometry, E, spin=H.spin) diff --git a/src/sisl/viz/data/sisl_objs.py b/src/sisl/viz/data/sisl_objs.py new file mode 100644 index 0000000000..dcc4b78126 --- /dev/null +++ b/src/sisl/viz/data/sisl_objs.py @@ -0,0 +1,28 @@ +from typing import Any, get_type_hints + +from sisl import Geometry, Grid, Hamiltonian + +from .data import Data + + +class SislObjData(Data): + """Base class for sisl objects""" + def __instancecheck__(self, instance: Any) -> bool: + expected_type = get_type_hints(self.__class__)['_data'] + return isinstance(instance, expected_type) + + def __subclasscheck__(self, subclass: Any) -> bool: + expected_type = get_type_hints(self.__class__)['_data'] + return issubclass(subclass, expected_type) + +class GeometryData(Data): + """Geometry data class""" + _data: Geometry + +class GridData(Data): + """Grid data class""" + _data: Grid + +class HamiltonianData(Data): + """Hamiltonian data class""" + _data: Hamiltonian \ No newline at end of file diff --git a/src/sisl/viz/data/tests/.coverage b/src/sisl/viz/data/tests/.coverage new file mode 100644 index 0000000000..1859d98b1b Binary 
files /dev/null and b/src/sisl/viz/data/tests/.coverage differ diff --git a/src/sisl/viz/data/tests/conftest.py b/src/sisl/viz/data/tests/conftest.py new file mode 100644 index 0000000000..a752c0a50b --- /dev/null +++ b/src/sisl/viz/data/tests/conftest.py @@ -0,0 +1,12 @@ +import os.path as osp + +import pytest + + +@pytest.fixture(scope="session") +def siesta_test_files(sisl_files): + + def _siesta_test_files(path): + return sisl_files(osp.join('sisl', 'io', 'siesta', path)) + + return _siesta_test_files \ No newline at end of file diff --git a/src/sisl/viz/data/tests/test_bands.py b/src/sisl/viz/data/tests/test_bands.py new file mode 100644 index 0000000000..70c1abe0b4 --- /dev/null +++ b/src/sisl/viz/data/tests/test_bands.py @@ -0,0 +1,71 @@ +import pytest + +import sisl +from sisl.viz.data import BandsData + + +@pytest.mark.parametrize("spin", ["unpolarized", "polarized", "noncolinear", "spinorbit"]) +def test_bands_from_sisl_H(spin): + gr = sisl.geom.graphene() + H = sisl.Hamiltonian(gr) + H.construct([(0.1, 1.44), (0, -2.7)]) + + n_spin, H = { + "unpolarized": (1, H), + "polarized": (2, H.transform(spin=sisl.Spin.POLARIZED)), + "noncolinear": (4, H.transform(spin=sisl.Spin.NONCOLINEAR)), + "spinorbit": (4, H.transform(spin=sisl.Spin.SPINORBIT)) + }[spin] + + bz = sisl.BandStructure(H, [[0, 0, 0], [2/3, 1/3, 0], [1/2, 0, 0]], 6, ["Gamma", "M", "K"]) + + data = BandsData.new(bz) + + data.sanity_check(n_spin=n_spin, nk=6, nbands=2, klabels=["Gamma", "M", "K"], kvals=[0., 1.70309799, 2.55464699]) + +@pytest.mark.parametrize("spin", ["unpolarized"]) +def test_bands_from_siesta_bands(spin, siesta_test_files): + + n_spin, filename = { + "unpolarized": (1, "SrTiO3.bands"), + }[spin] + + file = siesta_test_files(filename) + + data = BandsData.new(file) + + data.sanity_check(n_spin=n_spin, nk=150, nbands=72, klabels=('Gamma', 'X', 'M', 'Gamma', 'R', 'X'), kvals=[0.0, 0.429132, 0.858265, 1.465149, 2.208428, 2.815313]) + +@pytest.mark.parametrize("spin", ["noncolinear"]) 
+def test_bands_from_siesta_wfsx(spin, siesta_test_files): + + n_spin, filename = { + "noncolinear": (4, "bi2se3_3ql.bands.WFSX"), + }[spin] + + wfsx = sisl.get_sile(siesta_test_files(filename)) + fdf = siesta_test_files("bi2se3_3ql.fdf") + + data = BandsData.new(wfsx, fdf=fdf) + + data.sanity_check(n_spin=n_spin, nk=16, nbands=4) + +@pytest.mark.parametrize("spin", ["unpolarized", "polarized", "noncolinear", "spinorbit"]) +def test_toy_example(spin): + + nk = 15 + n_states = 28 + + data = BandsData.toy_example(spin=spin, nk=nk, n_states=n_states) + + n_spin = { + "unpolarized": 1, + "polarized": 2, + "noncolinear": 4, + "spinorbit": 4 + }[spin] + + data.sanity_check(n_spin=n_spin, nk=nk, nbands=n_states, klabels=["Gamma", "X"], kvals=[0, 1]) + + if n_spin == 4: + assert "spin_moments" in data.data_vars diff --git a/src/sisl/viz/data/tests/test_pdos.py b/src/sisl/viz/data/tests/test_pdos.py new file mode 100644 index 0000000000..b61e445e6a --- /dev/null +++ b/src/sisl/viz/data/tests/test_pdos.py @@ -0,0 +1,84 @@ +import pytest + +import sisl +from sisl.viz.data import PDOSData + + +@pytest.mark.parametrize("spin", ["unpolarized", "polarized", "noncolinear", "spinorbit"]) +def test_pdos_from_sisl_H(spin): + gr = sisl.geom.graphene() + H = sisl.Hamiltonian(gr) + H.construct([(0.1, 1.44), (0, -2.7)]) + + n_spin, H = { + "unpolarized": (1, H), + "polarized": (2, H.transform(spin=sisl.Spin.POLARIZED)), + "noncolinear": (4, H.transform(spin=sisl.Spin.NONCOLINEAR)), + "spinorbit": (4, H.transform(spin=sisl.Spin.SPINORBIT)) + }[spin] + + data = PDOSData.new(H, Erange=(-5, 5)) + + checksum = 17.599343960516066 + if n_spin > 1: + checksum = checksum * 2 + + data.sanity_check(na=2, no=2, n_spin=n_spin, atom_tags=('C',), dos_checksum=checksum) + +@pytest.mark.parametrize("spin", ["unpolarized", "polarized", "noncolinear"]) +def test_pdos_from_siesta_PDOS(spin, siesta_test_files): + + n_spin, filename = { + "unpolarized": (1, "SrTiO3.PDOS"), + "polarized": (2, 
"SrTiO3_polarized.PDOS"), + "noncolinear": (4, "SrTiO3_noncollinear.PDOS") + }[spin] + + file = siesta_test_files(filename) + + data = PDOSData.new(file) + + checksum = 1240.0012709612743 + if n_spin > 1: + checksum = checksum * 2 + + data.sanity_check(na=5, no=72, n_spin=n_spin, atom_tags=('Sr', 'Ti', 'O'), dos_checksum=checksum) + +@pytest.mark.parametrize("spin", ["noncolinear"]) +def test_pdos_from_siesta_wfsx(spin, siesta_test_files): + + n_spin, filename = { + "noncolinear": (4, "bi2se3_3ql.bands.WFSX"), + }[spin] + + # From a siesta .WFSX file + # Since there is no hamiltonian for bi2se3_3ql.fdf, we create a dummy one + wfsx = sisl.get_sile(siesta_test_files(filename)) + + geometry = sisl.get_sile(siesta_test_files("bi2se3_3ql.fdf")).read_geometry() + geometry = sisl.Geometry(geometry.xyz, atoms=wfsx.read_basis()) + + H = sisl.Hamiltonian(geometry, dim=4) + + data = PDOSData.new(wfsx, H=H) + + # For now, the checksum is 0 because we have no overlap matrix. + checksum = 0 + if n_spin > 1: + checksum = checksum * 2 + + data.sanity_check(na=15, no=195, n_spin=n_spin, atom_tags=('Bi', 'Se'), dos_checksum=checksum) + +@pytest.mark.parametrize("spin", ["unpolarized", "polarized", "noncolinear", "spinorbit"]) +def test_toy_example(spin): + + data = PDOSData.toy_example(spin=spin) + + n_spin = { + "unpolarized": 1, + "polarized": 2, + "noncolinear": 4, + "spinorbit": 4 + }[spin] + + data.sanity_check(n_spin=n_spin, na=data.geometry.na, no=data.geometry.no, atom_tags=["C"]) diff --git a/src/sisl/viz/data/xarray.py b/src/sisl/viz/data/xarray.py new file mode 100644 index 0000000000..6a8747d025 --- /dev/null +++ b/src/sisl/viz/data/xarray.py @@ -0,0 +1,28 @@ +from typing import Union + +from xarray import DataArray, Dataset + +from .data import Data + + +class XarrayData(Data): + + _data: Union[DataArray, Dataset] + + def __init__(self, data: Union[DataArray, Dataset]): + super().__init__(data) + + def __getattr__(self, key): + sisl_accessor = self._data.sisl + + if 
hasattr(sisl_accessor, key): + return getattr(sisl_accessor, key) + + return getattr(self._data, key) + + def __dir__(self): + return dir(self._data.sisl) + dir(self._data) + + +class OrbitalData(XarrayData): + pass \ No newline at end of file diff --git a/src/sisl/viz/data_sources/__init__.py b/src/sisl/viz/data_sources/__init__.py new file mode 100644 index 0000000000..91cb83fa6a --- /dev/null +++ b/src/sisl/viz/data_sources/__init__.py @@ -0,0 +1,7 @@ +from .atom_data import * +from .bond_data import * +from .data_source import * +from .eigenstate_data import * +from .file import * +from .hamiltonian_source import * +from .orbital_data import * diff --git a/src/sisl/viz/data_sources/atom_data.py b/src/sisl/viz/data_sources/atom_data.py new file mode 100644 index 0000000000..9163db0776 --- /dev/null +++ b/src/sisl/viz/data_sources/atom_data.py @@ -0,0 +1,88 @@ +import numpy as np + +from sisl.atom import AtomGhost, PeriodicTable + +from .data_source import DataSource + + +class AtomData(DataSource): + def function(self, geometry, atoms=None): + raise NotImplementedError("") + +@AtomData.from_func +def AtomCoords(geometry, atoms=None): + return geometry.xyz[atoms] + +@AtomData.from_func +def AtomX(geometry, atoms=None): + return geometry.xyz[atoms, 0] + +@AtomData.from_func +def AtomY(geometry, atoms=None): + return geometry.xyz[atoms, 1] + +@AtomData.from_func +def AtomZ(geometry, atoms=None): + return geometry.xyz[atoms, 2] + +@AtomData.from_func +def AtomFCoords(geometry, atoms=None): + return geometry.sub(atoms).fxyz + +@AtomData.from_func +def AtomFx(geometry, atoms=None): + return geometry.sub(atoms).fxyz[:, 0] + +@AtomData.from_func +def AtomFy(geometry, atoms=None): + return geometry.sub(atoms).fxyz[:, 1] + +@AtomData.from_func +def AtomFz(geometry, atoms=None): + return geometry.sub(atoms).fxyz[:, 2] + +@AtomData.from_func +def AtomR(geometry, atoms=None): + return geometry.sub(atoms).maxR(all=True) + +@AtomData.from_func +def AtomZ(geometry, atoms=None): 
+ return geometry.sub(atoms).atoms.Z + +@AtomData.from_func +def AtomNOrbitals(geometry, atoms=None): + return geometry.sub(atoms).orbitals + +class AtomDefaultColors(AtomData): + + _atoms_colors = { + "H": "#cccccc", + "O": "red", + "Cl": "green", + "N": "blue", + "C": "grey", + "S": "yellow", + "P": "orange", + "Au": "gold", + "else": "pink" + } + + def function(self, geometry, atoms=None): + return np.array([ + self._atoms_colors.get(atom.symbol, self._atoms_colors["else"]) + for atom in geometry.sub(atoms).atoms + ]) + +@AtomData.from_func +def AtomIsGhost(geometry, atoms=None, fill_true=True, fill_false=False): + return np.array([ + fill_true if isinstance(atom, AtomGhost) else fill_false + for atom in geometry.sub(atoms).atoms + ]) + +@AtomData.from_func +def AtomPeriodicTable(geometry, atoms=None, what=None, pt=PeriodicTable): + if not isinstance(pt, PeriodicTable): + pt = pt() + function = getattr(pt, what) + return function(geometry.sub(atoms).atoms.Z) \ No newline at end of file diff --git a/src/sisl/viz/data_sources/bond_data.py b/src/sisl/viz/data_sources/bond_data.py new file mode 100644 index 0000000000..b74d17a7a5 --- /dev/null +++ b/src/sisl/viz/data_sources/bond_data.py @@ -0,0 +1,63 @@ +from typing import Union + +import numpy as np + +import sisl +from sisl.utils.mathematics import fnorm + +from .data_source import DataSource + + +class BondData(DataSource): + + ndim: int + + @staticmethod + def function(geometry, bonds): + raise NotImplementedError("") + + pass + +def bond_lengths(geometry: sisl.Geometry, bonds: np.ndarray): + # Get an array with the coordinates defining the start and end of each bond. 
+ # The array will be of shape (nbonds, 2, 3) + coords = geometry[np.ravel(bonds)].reshape(-1, 2, 3) + # Take the diff between the end and start -> shape (nbonds, 1 , 3) + # And then the norm of each vector -> shape (nbonds, 1, 1) + # Finally, we just ravel it to an array of shape (nbonds, ) + return fnorm(np.diff(coords, axis=1), axis=-1).ravel() + +def bond_strains(ref_geometry: sisl.Geometry, geometry: sisl.Geometry, bonds: np.ndarray): + assert ref_geometry.na == geometry.na, (f"Geometry provided (na={geometry.na}) does not have the" + f" same number of atoms as the reference geometry (na={ref_geometry.na})") + + ref_bond_lengths = bond_lengths(ref_geometry, bonds) + this_bond_lengths = bond_lengths(geometry, bonds) + + return (this_bond_lengths - ref_bond_lengths) / ref_bond_lengths + +def bond_data_from_atom(atom_data: np.ndarray, geometry: sisl.Geometry, bonds: np.ndarray, fold_to_uc: bool = False): + + if fold_to_uc: + bonds = geometry.sc2uc(bonds) + + return atom_data[bonds[:, 0]] + +def bond_data_from_matrix(matrix, geometry: sisl.Geometry, bonds: np.ndarray, fold_to_uc: bool = False): + + if fold_to_uc: + bonds = geometry.sc2uc(bonds) + + return matrix[bonds[:, 0], bonds[:, 1]] + +def bond_random(geometry: sisl.Geometry, bonds: np.ndarray, seed: Union[int, None] = None): + if seed is not None: + np.random.seed(seed) + + return np.random.random(len(bonds)) + +BondLength = BondData.from_func(bond_lengths) +BondStrain = BondData.from_func(bond_strains) +BondDataFromAtom = BondData.from_func(bond_data_from_atom) +BondDataFromMatrix = BondData.from_func(bond_data_from_matrix) +BondRandom = BondData.from_func(bond_random) diff --git a/src/sisl/viz/data_sources/data_source.py b/src/sisl/viz/data_sources/data_source.py new file mode 100644 index 0000000000..dc49475832 --- /dev/null +++ b/src/sisl/viz/data_sources/data_source.py @@ -0,0 +1,20 @@ +from sisl.nodes import Node + + +class DataSource(Node): + """Generic class for data sources. 
+ + Data sources are a way of specifying and manipulating data without providing it explicitly. + Data sources can be passed to the settings of the plots as if they were arrays. + When the plot is being created, the data source receives the necessary inputs and is evaluated using + its ``get`` method. + + Therefore, passing a data source is like passing a function that will receive + inputs and calculate the values needed on the fly. However, it has some extra functionality. You can + perform operations with a data source. These operations will be evaluated lazily, that is, when + inputs are provided. That allows for very convenient manipulation of the data. + + Data sources are also useful for graphical interfaces, where the user is unable to explicitly + pass a function. + """ + pass diff --git a/src/sisl/viz/data_sources/eigenstate_data.py b/src/sisl/viz/data_sources/eigenstate_data.py new file mode 100644 index 0000000000..524290c214 --- /dev/null +++ b/src/sisl/viz/data_sources/eigenstate_data.py @@ -0,0 +1,21 @@ +from typing import Literal + +import xarray as xr + +from .data_source import DataSource + + +class EigenstateData(DataSource): + pass + +def spin_moments_from_dataset(axis: Literal['x', 'y', 'z'], data: xr.Dataset) -> xr.DataArray: + if "spin_moments" not in data: + raise ValueError("The dataset does not contain spin moments") + + spin_moms = data.spin_moments.sel(axis=axis) + spin_moms = spin_moms.rename(f'spin_moments_{axis}') + return spin_moms + +class SpinMoment(EigenstateData): + + function = staticmethod(spin_moments_from_dataset) \ No newline at end of file diff --git a/src/sisl/viz/data_sources/file/__init__.py b/src/sisl/viz/data_sources/file/__init__.py new file mode 100644 index 0000000000..ee984e3f57 --- /dev/null +++ b/src/sisl/viz/data_sources/file/__init__.py @@ -0,0 +1,2 @@ +from .file_source import * +from .siesta import * \ No newline at end of file diff --git a/src/sisl/viz/data_sources/file/file_source.py 
b/src/sisl/viz/data_sources/file/file_source.py new file mode 100644 index 0000000000..fdb7728938 --- /dev/null +++ b/src/sisl/viz/data_sources/file/file_source.py @@ -0,0 +1,30 @@ +from pathlib import Path + +import sisl + +from ..data_source import DataSource + + +class FileData(DataSource): + """ Generic data source for reading data from a file. + + The aim of this class is twofold: + - Standardize the way data sources read files. + - Provide automatic updating features when the read files are updated. + """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._files_to_read = [] + + def follow_file(self, path): + self._files_to_read.append(Path(path).resolve()) + + def get_sile(self, path, **kwargs): + """ A wrapper around get_sile so that the reading of the file is registered""" + self.follow_file(path) + return sisl.get_sile(path, **kwargs) + + def function(self, **kwargs): + if isinstance(kwargs.get('path'), sisl.io.BaseSile): + kwargs['path'] = kwargs['path'].file + return self.get_sile(**kwargs) \ No newline at end of file diff --git a/src/sisl/viz/data_sources/file/siesta.py b/src/sisl/viz/data_sources/file/siesta.py new file mode 100644 index 0000000000..c64972258a --- /dev/null +++ b/src/sisl/viz/data_sources/file/siesta.py @@ -0,0 +1,49 @@ +from pathlib import Path + +import sisl + +from .file_source import FileData + + +def get_sile(path=None, fdf=None, cls=None, **kwargs): + """ Wrapper around FileData.get_sile that infers files from the root fdf + + Parameters + ---------- + path : str or Path, optional + the path to the file to be read. + cls : sisl.io.SileSiesta, optional + if `path` is not provided, we try to infer it from the root fdf file, + looking for files that fulfill this class' rules. + + Returns + --------- + Sile: + The sile object. 
+ """ + if fdf is not None and isinstance(fdf, (str, Path)): + fdf = get_sile(path=fdf) + + if path is None: + if cls is None: + raise ValueError(f"Either a path or a class must be provided to get_sile") + if fdf is None: + raise ValueError(f"We can not look for files of a sile type without a root fdf file.") + + for rule in sisl.get_sile_rules(cls=cls): + filename = fdf.get('SystemLabel', default='siesta') + f'.{rule.suffix}' + try: + path = fdf.dir_file(filename) + return get_sile(path=path, **kwargs) + except: + pass + else: + raise FileNotFoundError(f"Tried to find a {cls} from the root fdf ({fdf.file}), " + f"but didn't find any.") + + return sisl.get_sile(path, **kwargs) + +def FileDataSIESTA(path=None, fdf=None, cls=None, **kwargs): + if isinstance(path, sisl.io.BaseSile): + path = path.file + return get_sile(path=path, fdf=fdf, cls=cls, **kwargs) \ No newline at end of file diff --git a/src/sisl/viz/data_sources/hamiltonian_source.py b/src/sisl/viz/data_sources/hamiltonian_source.py new file mode 100644 index 0000000000..a72cf92e41 --- /dev/null +++ b/src/sisl/viz/data_sources/hamiltonian_source.py @@ -0,0 +1,34 @@ +from pathlib import Path + +import sisl + +from .data_source import DataSource +from .file.siesta import FileData + + +class HamiltonianDataSource(DataSource): + + def __init__(self, H=None, kwargs={}): + super().__init__(H=H, kwargs=kwargs) + + def get_hamiltonian(self, H, **kwargs): + """ Setup the Hamiltonian object. + + Parameters + ---------- + H : sisl.Hamiltonian + The Hamiltonian object to be setup. 
+ """ + + if isinstance(H, (str, Path)): + H = FileData(path=H) + if isinstance(H, (sisl.io.BaseSile)): + H = H.read_hamiltonian(**kwargs) + + if H is None: + raise ValueError("No hamiltonian found.") + + return H + + def function(self, H, kwargs): + return self.get_hamiltonian(H=H, **kwargs) diff --git a/src/sisl/viz/data_sources/orbital_data.py b/src/sisl/viz/data_sources/orbital_data.py new file mode 100644 index 0000000000..e4c7b8aa52 --- /dev/null +++ b/src/sisl/viz/data_sources/orbital_data.py @@ -0,0 +1,41 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at https://mozilla.org/MPL/2.0/. +from typing import Literal, Union + +import numpy as np + +from ..plotutils import random_color +from .data_source import DataSource + +#from ..processors.orbital import reduce_orbital_data, get_orbital_request_sanitizer + + +class OrbitalData(DataSource): + pass + +def style_fatbands(data, groups=[{}]): + + # Get the function that is going to convert our request to something that can actually + # select orbitals from the xarray object. 
+ _sanitize_request = get_orbital_request_sanitizer( + data, + gens={ + "color": lambda req: req.get("color") or random_color(), + } + ) + + styled = reduce_orbital_data( + data, groups, orb_dim="orb", spin_dim="spin", sanitize_group=_sanitize_request, + group_vars=('color', 'dash'), groups_dim="group", drop_empty=True, + spin_reduce=np.sum, + ) + + return styled#.color + +class FatbandsData(OrbitalData): + + function = staticmethod(style_fatbands) + + pass + diff --git a/src/sisl/viz/figure/__init__.py b/src/sisl/viz/figure/__init__.py new file mode 100644 index 0000000000..9bf7931f87 --- /dev/null +++ b/src/sisl/viz/figure/__init__.py @@ -0,0 +1,47 @@ +from .figure import BACKENDS, Figure, get_figure + + +class NotAvailableFigure(Figure): + + _package: str = "" + + def __init__(self, *args, **kwargs): + raise ModuleNotFoundError(f"{self.__class__.__name__} is not available because {self._package} is not installed.") + +try: + import plotly +except ModuleNotFoundError: + class PlotlyFigure(NotAvailableFigure): + _package = "plotly" +else: + from .plotly import PlotlyFigure + +try: + import matplotlib +except ModuleNotFoundError: + class MatplotlibFigure(NotAvailableFigure): + _package = "matplotlib" +else: + from .matplotlib import MatplotlibFigure + +try: + import py3Dmol +except ModuleNotFoundError: + class Py3DmolFigure(NotAvailableFigure): + _package = "py3Dmol" +else: + from .py3dmol import Py3DmolFigure + +try: + import bpy +except ModuleNotFoundError: + class BlenderFigure(NotAvailableFigure): + _package = "blender (bpy)" +else: + from .blender import BlenderFigure + + +BACKENDS["plotly"] = PlotlyFigure +BACKENDS["matplotlib"] = MatplotlibFigure +BACKENDS["py3dmol"] = Py3DmolFigure +BACKENDS["blender"] = BlenderFigure diff --git a/src/sisl/viz/figure/blender.py b/src/sisl/viz/figure/blender.py new file mode 100644 index 0000000000..75a42005aa --- /dev/null +++ b/src/sisl/viz/figure/blender.py @@ -0,0 +1,521 @@ +import collections +import itertools + +import 
bpy +import numpy as np + +from .figure import Figure + + +def add_line_frame(ani_objects, child_objects, frame): + """Creates the frames for a child plot lines. + + Given the objects of the lines collection in the animation, it uses + the corresponding lines in the child to set keyframes. + + Parameters + ----------- + ani_objects: CollectionObjects + the objects of the Atoms collection in the animation. + child_objects: CollectionObjects + the objects of the Atoms collection in the child plot. + frame: int + the frame number to which the keyframe values should be set. + """ + # Loop through all objects in the collections + for ani_obj, child_obj in zip(ani_objects, child_objects): + # Each curve object has multiple splines + for ani_spline, child_spline in zip(ani_obj.data.splines, child_obj.data.splines): + # And each spline has multiple points + for ani_point, child_point in zip(ani_spline.bezier_points, child_spline.bezier_points): + # Set the position of that point + ani_point.co = child_point.co + ani_point.keyframe_insert(data_path="co", frame=frame) + + # Loop through all the materials that the object might have associated + for ani_material, child_material in zip(ani_obj.data.materials, child_obj.data.materials): + ani_mat_inputs = ani_material.node_tree.nodes["Principled BSDF"].inputs + child_mat_inputs = child_material.node_tree.nodes["Principled BSDF"].inputs + + for input_key in ("Base Color", "Alpha"): + ani_mat_inputs[input_key].default_value = child_mat_inputs[input_key].default_value + ani_mat_inputs[input_key].keyframe_insert(data_path="default_value", frame=frame) + + +def add_atoms_frame(ani_objects, child_objects, frame): + """Creates the frames for a child plot atoms. + + Given the objects of the Atoms collection in the animation, it uses + the corresponding atoms in the child to set keyframes. + + Parameters + ----------- + ani_objects: CollectionObjects + the objects of the Atoms collection in the animation. 
+ child_objects: CollectionObjects + the objects of the Atoms collection in the child plot. + frame: int + the frame number to which the keyframe values should be set. + """ + # Loop through all objects in the collections + for ani_obj, child_obj in zip(ani_objects, child_objects): + # Set the atom position + ani_obj.location = child_obj.location + ani_obj.keyframe_insert(data_path="location", frame=frame) + + # Set the atom size + ani_obj.scale = child_obj.scale + ani_obj.keyframe_insert(data_path="scale", frame=frame) + + # Set the atom color and opacity + ani_mat_inputs = ani_obj.data.materials[0].node_tree.nodes["Principled BSDF"].inputs + child_mat_inputs = child_obj.data.materials[0].node_tree.nodes["Principled BSDF"].inputs + + for input_key in ("Base Color", "Alpha"): + ani_mat_inputs[input_key].default_value = child_mat_inputs[input_key].default_value + ani_mat_inputs[input_key].keyframe_insert(data_path="default_value", frame=frame) + + +class BlenderFigure(Figure): + """Generic canvas for the blender framework. + + This is the first experiment with it, so it is quite simple. + + Everything is drawn in the same scene. On initialization, a collections + dictionary is started. The keys should be the local name of a collection + in the canvas environment and the values are the actual collections. + """ + + # Experimental feature to adjust 2D plottings + #_2D_scale = (1, 1) + + _animatable_collections = { + "Lines": {"add_frame": add_line_frame}, + } + + def _init_figure(self, *args, **kwargs): + # This is the collection that will store everything related to the plot. 
# Methods of the Blender ``Figure`` backend. The class statement and the start
# of ``__init__`` lie before this chunk; ``__init__`` ends by creating the root
# collection and the per-key collection registry:
#     self._collection = bpy.data.collections.new(f"sislplot_{id(self)}")
#     self._collections = {}

def _init_figure_animated(self, interpolated_frames: int = 5, **kwargs):
    """Initialize an animated figure.

    Parameters
    ----------
    interpolated_frames:
        number of frames to interpolate between two consecutive animation steps.
    """
    self._animation_settings = {
        "interpolated_frames": interpolated_frames
    }
    return self._init_figure(**kwargs)

def _iter_animation(self, plot_actions, interpolated_frames=5):
    """Iterate over animation sections, tagging drawing actions with a frame number.

    The value stored by ``_init_figure_animated`` takes precedence; the
    ``interpolated_frames`` argument is only a fallback when no animation
    settings have been stored.
    """
    interpolated_frames = self._animation_settings.get("interpolated_frames", interpolated_frames)

    for i, section_actions in enumerate(plot_actions):
        frame = i * interpolated_frames

        sanitized_section_actions = []
        for action in section_actions:
            # Only drawing actions receive the frame keyword; everything else
            # passes through untouched.
            if action['method'].startswith("draw_"):
                action = {**action, "kwargs": {**action.get("kwargs", {}), "frame": frame}}

            sanitized_section_actions.append(action)

        yield sanitized_section_actions

def draw_on(self, figure):
    """Draw this plot's actions on another figure."""
    self._plot.get_figure(backend=self._backend_name, clear_fig=False)

def clear(self):
    """ Clears the blender scene so that data can be reset"""
    # Iterate over values only; the keys are not needed here.
    for collection in self._collections.values():
        self.clear_collection(collection)
        bpy.data.collections.remove(collection)

    self._collections = {}

def get_collection(self, key):
    """Return the collection associated to ``key``, creating and linking it if needed."""
    if key not in self._collections:
        self._collections[key] = bpy.data.collections.new(key)
        self._collection.children.link(self._collections[key])

    return self._collections[key]

def remove_collection(self, key):
    """Delete the collection associated to ``key`` together with all its objects."""
    if key in self._collections:
        collection = self._collections[key]
        self.clear_collection(collection)
        bpy.data.collections.remove(collection)
        del self._collections[key]

def clear_collection(self, collection):
    """Remove (and unlink) every object inside ``collection``."""
    for obj in collection.objects:
        bpy.data.objects.remove(obj, do_unlink=True)

def draw_line(self, x, y, name="", line={}, marker={}, text=None, row=None, col=None, **kwargs):
    """Draw a 2D line by placing it on the z=0 plane and delegating to the 3D method."""
    z = np.full_like(x, 0)
    return self.draw_line_3D(x, y, z, name=name, line=line, marker=marker, text=text, row=row, col=col, **kwargs)

def draw_scatter(self, x, y, name=None, marker={}, text=None, row=None, col=None, **kwargs):
    """Draw a 2D scatter by placing it on the z=0 plane and delegating to the 3D method."""
    z = np.full_like(x, 0)
    return self.draw_scatter_3D(x, y, z, name=name, marker=marker, text=text, row=row, col=col, **kwargs)

def draw_line_3D(self, x, y, z, line={}, name="", collection=None, frame=None, **kwargs):
    """Draws a line using a bezier curve."""
    if frame is not None:
        return self._animate_line_3D(x, y, z, line=line, name=name, collection=collection, frame=frame, **kwargs)

    # Normalize the name BEFORE it is used as a collection key, so that a None
    # name does not register a collection under the None key.
    if name is None:
        name = ""

    if collection is None:
        collection = self.get_collection(name)

    # First, generate the curve object, then get it from the context.
    bpy.ops.curve.primitive_bezier_curve_add()
    curve_obj = bpy.context.object
    curve_obj.name = name

    # Link the curve to our collection (remove it from the context one)
    context_col = bpy.context.collection
    if context_col is not collection:
        context_col.objects.unlink(curve_obj)
    collection.objects.link(curve_obj)

    # Retrieve the curve from the object and modify some attributes to make it
    # look cylindric.
    curve = curve_obj.data
    curve.dimensions = '3D'
    curve.fill_mode = 'FULL'
    width = line.get("width")
    curve.bevel_depth = width if width is not None else 0.1
    curve.bevel_resolution = 10
    # Clear all existing splines from the curve, as we are going to add them
    curve.splines.clear()

    xyz = np.array([x, y, z], dtype=float).T

    # To be compatible with other frameworks such as plotly and matplotlib,
    # we allow x, y and z to contain None values that indicate discontinuities
    # E.g.: x=[0, 1, None, 2, 3] means we should draw a line from 0 to 1 and
    # another from 2 to 3. We get the breakpoints (indices where there is a
    # None/NaN), adding -1 and None at the sides to facilitate iterating.
    breakpoint_indices = [-1, *np.where(np.isnan(xyz).any(axis=1))[0], None]

    # Now loop through all segments using the known breakpoints
    for start_i, end_i in zip(breakpoint_indices, breakpoint_indices[1:]):
        segment_xyz = xyz[start_i + 1: end_i]

        # If there is nothing to draw, go to next segment
        if len(segment_xyz) == 0:
            continue

        # Create a new spline (within the curve, not a new object!)
        segment = curve.splines.new("BEZIER")
        # Splines by default have only 1 point, add as many as we need
        segment.bezier_points.add(len(segment_xyz) - 1)
        # Assign the coordinates to each point
        segment.bezier_points.foreach_set('co', np.ravel(segment_xyz))

        # We want linear interpolation between points (cubic would be 3, e.g.).
        segment.resolution_u = 1

    # Give a color to our new curve object if it needs to be colored.
    self._color_obj(curve_obj, line.get("color", None), line.get("opacity", 1))

    return self

def _animate_line_3D(self, x, y, z, line={}, name="", collection=None, frame=0, **kwargs):
    """Insert a keyframe at ``frame`` with the line state given by the arguments."""
    if collection is None:
        collection = self.get_collection(name)

    # If this is the first frame, draw the object as usual
    if frame == 0:
        self.draw_line_3D(x, y, z, line=line, name=name, collection=collection, frame=None, **kwargs)

    # Create a collection that we are just going to use to create new objects
    # from which to copy the properties.
    temp_collection_name = f"__animated_{name}"
    temp_collection = self.get_collection(temp_collection_name)
    self.clear_collection(temp_collection)

    self.draw_line_3D(x, y, z, line=line, name=name, collection=temp_collection, frame=None, **kwargs)

    # Loop through all objects in the collections
    for ani_obj, child_obj in zip(collection.objects, temp_collection.objects):
        # Each curve object has multiple splines, and each spline has multiple
        # points: keyframe the position of every point.
        for ani_spline, child_spline in zip(ani_obj.data.splines, child_obj.data.splines):
            for ani_point, child_point in zip(ani_spline.bezier_points, child_spline.bezier_points):
                ani_point.co = child_point.co
                ani_point.keyframe_insert(data_path="co", frame=frame)

        # Keyframe color and opacity of every material associated to the object.
        for ani_material, child_material in zip(ani_obj.data.materials, child_obj.data.materials):
            ani_mat_inputs = ani_material.node_tree.nodes["Principled BSDF"].inputs
            child_mat_inputs = child_material.node_tree.nodes["Principled BSDF"].inputs

            for input_key in ("Base Color", "Alpha"):
                ani_mat_inputs[input_key].default_value = child_mat_inputs[input_key].default_value
                ani_mat_inputs[input_key].keyframe_insert(data_path="default_value", frame=frame)

    # Remove the temporal collection
    self.remove_collection(temp_collection_name)

def draw_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, collection=None, frame=None, **kwargs):
    """Draw points as spheres, one blender object per point."""
    if frame is not None:
        return self._animate_balls_3D(x, y, z, name=name, marker=marker, row=row, col=col, collection=collection, frame=frame, **kwargs)

    if collection is None:
        collection = self.get_collection(name)

    # Create a template sphere which is copied for every point.
    bpy.ops.surface.primitive_nurbs_surface_sphere_add(radius=1, enter_editmode=False, align='WORLD')
    template_ball = bpy.context.object
    bpy.context.collection.objects.unlink(template_ball)

    style = {
        "color": marker.get("color", "gray"),
        "opacity": marker.get("opacity", 1),
        "size": marker.get("size", 1),
    }

    # Broadcast scalar styles so that each point consumes one value.
    for k, v in style.items():
        if (not isinstance(v, (collections.abc.Sequence, np.ndarray))) or isinstance(v, str):
            style[k] = itertools.repeat(v)

    ball = template_ball
    for i, (x_i, y_i, z_i, color, opacity, size) in enumerate(zip(x, y, z, style["color"], style["opacity"], style["size"])):
        if i > 0:
            ball = template_ball.copy()
            ball.data = template_ball.data.copy()

        ball.location = [x_i, y_i, z_i]
        ball.scale = (size, size, size)

        # Link the atom to the atoms collection
        collection.objects.link(ball)

        ball.name = f"{name}_{i}"
        ball.data.name = f"{name}_{i}"

        self._color_obj(ball, color, opacity=opacity)

def _animate_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, collection=None, frame=0, **kwargs):
    """Insert a keyframe at ``frame`` with the spheres' state given by the arguments."""
    if collection is None:
        collection = self.get_collection(name)

    # If this is the first frame, draw the object as usual
    if frame == 0:
        self.draw_balls_3D(x, y, z, marker=marker, name=name, row=row, col=col, collection=collection, frame=None, **kwargs)

    # Same trick as in _animate_line_3D: draw the new state into a temporary
    # collection and copy its properties into keyframes.
    temp_collection_name = f"__animated_{name}"
    temp_collection = self.get_collection(temp_collection_name)
    self.clear_collection(temp_collection)

    self.draw_balls_3D(x, y, z, marker=marker, name=name, row=row, col=col, collection=temp_collection, frame=None, **kwargs)

    # Loop through all objects in the collections
    for ani_obj, child_obj in zip(collection.objects, temp_collection.objects):
        # Set the atom position
        ani_obj.location = child_obj.location
        ani_obj.keyframe_insert(data_path="location", frame=frame)

        # Set the atom size
        ani_obj.scale = child_obj.scale
        ani_obj.keyframe_insert(data_path="scale", frame=frame)

        # Set the atom color and opacity
        ani_mat_inputs = ani_obj.data.materials[0].node_tree.nodes["Principled BSDF"].inputs
        child_mat_inputs = child_obj.data.materials[0].node_tree.nodes["Principled BSDF"].inputs

        for input_key in ("Base Color", "Alpha"):
            ani_mat_inputs[input_key].default_value = child_mat_inputs[input_key].default_value
            ani_mat_inputs[input_key].keyframe_insert(data_path="default_value", frame=frame)

    self.remove_collection(temp_collection_name)

draw_scatter_3D = draw_balls_3D

def draw_mesh_3D(self, vertices, faces, color=None, opacity=None, name="Mesh", row=None, col=None, **kwargs):
    """Draw a mesh from its vertices and faces."""
    # Renamed the local from `col` to `mesh_col`: it used to shadow the
    # `col` (subplot column) parameter.
    mesh_col = self.get_collection(name)

    mesh = bpy.data.meshes.new(name)
    obj = bpy.data.objects.new(mesh.name, mesh)
    mesh_col.objects.link(obj)

    # No explicit edges: blender derives them from the faces.
    mesh.from_pydata(vertices, [], faces.tolist())

    self._color_obj(obj, color, opacity)

@staticmethod
def _to_rgb_color(color):
    """Convert a string color to an rgb tuple; non-string values pass through unchanged."""
    if isinstance(color, str):
        try:
            import matplotlib.colors

            color = matplotlib.colors.to_rgb(color)
        except ModuleNotFoundError:
            raise ValueError("Blender does not understand string colors."+
                "Please provide the color in rgb (tuple of length 3, values from 0 to 1) or install matplotlib so that we can convert it."
            )

    return color

@classmethod
def _color_obj(cls, obj, color, opacity=1.):
    """Utility method to quickly color a given object.

    Parameters
    -----------
    obj: blender Object
        object to be colored
    color: str or array-like of shape (3,)
        color, it is converted to rgb using `matplotlib.colors.to_rgb`
    opacity:
        the opacity that should be given to the object. It doesn't
        work currently.
    """
    if opacity is None:
        opacity = 1.

    color = cls._to_rgb_color(color)

    if color is not None:
        mat = bpy.data.materials.new("material")
        mat.use_nodes = True

        BSDF_inputs = mat.node_tree.nodes["Principled BSDF"].inputs

        BSDF_inputs["Base Color"].default_value = (*color, 1)
        BSDF_inputs["Alpha"].default_value = opacity

        obj.active_material = mat

def set_axis(self, *args, **kwargs):
    """There are no axes titles and these kind of things in blender.
    At least for now, we might implement it later."""

def set_axes_equal(self, *args, **kwargs):
    """Axes are always "equal" in blender, so we do nothing here"""

def show(self, *args, **kwargs):
    """Link the root collection to the scene so that everything becomes visible."""
    bpy.context.scene.collection.children.link(self._collection)

# NOTE(review): the large block of commented-out backend classes that used to
# follow (BlenderMultiplePlotBackend, BlenderAnimationBackend,
# BlenderGeometryBackend) was dead code and has been removed.
(size, size, size) + +# # Link the atom to the atoms collection +# atoms_col = self.get_collection("Atoms") +# atoms_col.objects.link(atom) + +# atom.name = name +# atom.data.name = name + +# self._color_obj(atom, color, opacity=opacity) + +# def _draw_bonds_3D(self, *args, line=None, **kwargs): +# # Multiply the width of the bonds to 0.2, otherwise they look gigantic. +# line = line or {} +# line["width"] = 0.2 * line.get("width", 1) +# # And call the method to draw bonds (which will use self.draw_line3D) +# collection = self.get_collection("Bonds") +# super()._draw_bonds_3D(*args, line=line, collection=collection, **kwargs) + +# def _draw_cell_3D_box(self, *args, width=None, **kwargs): +# width = width or 0.1 +# # This method is only defined to provide a better default for the width in blender +# # otherwise it looks gigantic, as the bonds +# collection = self.get_collection("Unit cell") +# super()._draw_cell_3D_box(*args, width=width, collection=collection, **kwargs) + +# GeometryPlot.backends.register("blender", BlenderGeometryBackend) \ No newline at end of file diff --git a/src/sisl/viz/figure/figure.py b/src/sisl/viz/figure/figure.py new file mode 100644 index 0000000000..bf2d9a32fb --- /dev/null +++ b/src/sisl/viz/figure/figure.py @@ -0,0 +1,715 @@ +from collections import ChainMap +from typing import Any, Dict, Literal, Optional, Tuple + +import numpy as np + +from sisl.messages import warn +from sisl.viz.plotutils import values_to_colors + +BACKENDS = {} + +class Figure: + """Base figure class that all backends should inherit from. + + It contains all the plotting actions that should be supported by a figure. + + A subclass for a specific backend should implement as many methods as possible + from the ones where Figure returns NotImplementedError. + Other methods are optional because Figure contains a default implementation + using other methods, which should work for most backends. 
+ + To create a new backend, one might take the PlotlyFigure as a template. + """ + _coloraxes: dict = {} + _multi_axes: dict = {} + + _rows: Optional[int] = None + _cols: Optional[int] = None + + # The composite mode of the plot + _composite_mode = 0 + # Here are the different composite methods that can be used. + _NONE = 0 + _SAME_AXES = 1 + _MULTIAXIS = 2 + _SUBPLOTS = 3 + _ANIMATION = 4 + + def __init__(self, plot_actions, *args, **kwargs): + self.plot_actions = plot_actions + self._build(plot_actions, *args, **kwargs) + + def _build(self, plot_actions, *args, **kwargs): + + plot_actions = self._sanitize_plot_actions(plot_actions) + + self._coloraxes = {} + self._multi_axes = {} + + fig = self.init_figure( + composite_method=plot_actions['composite_method'], + plot_actions=plot_actions['plot_actions'], + init_kwargs=plot_actions['init_kwargs'], + ) + + for section_actions in self._composite_iter(self._composite_mode, plot_actions['plot_actions']): + for action in section_actions: + getattr(self, action['method'])(*action.get('args', ()), **action.get('kwargs', {})) + + return fig + + @staticmethod + def _sanitize_plot_actions(plot_actions): + + def _flatten(plot_actions, out, level=0, root_i=0): + for i, section_actions in enumerate(plot_actions): + if level == 0: + out.append([]) + root_i = i + + if isinstance(section_actions, dict): + _flatten(section_actions['plot_actions'], out, level + 1, root_i=root_i) + else: + # If it's a plot object, we need to extract the plot_actions + out[root_i].extend(section_actions) + + if isinstance(plot_actions, dict): + composite_method = plot_actions.get('composite_method') + init_kwargs = plot_actions.get('init_kwargs', {}) + out = [] + _flatten(plot_actions['plot_actions'], out) + plot_actions = out + else: + composite_method = None + plot_actions = [plot_actions] + init_kwargs = {} + + return {"composite_method": composite_method, "plot_actions": plot_actions, "init_kwargs": init_kwargs} + + def init_figure(self, 
composite_method: Literal[None, "same_axes", "multiple", "multiple_x", "multiple_y", "subplots", "animation"] = None, + plot_actions=(), init_kwargs: Dict[str, Any] = {}): + if composite_method is None: + self._composite_mode = self._NONE + return self._init_figure(**init_kwargs) + elif composite_method == "same_axes": + self._composite_mode = self._SAME_AXES + return self._init_figure_same_axes(**init_kwargs) + elif composite_method.startswith("multiple"): + # This could be multiple + self._composite_mode = self._MULTIAXIS + multi_axes = [ax for ax in 'xy' if ax in composite_method[8:]] + return self._init_figure_multiple_axes(multi_axes, plot_actions, **init_kwargs) + elif composite_method == "animation": + self._composite_mode = self._ANIMATION + return self._init_figure_animated(n=len(plot_actions), **init_kwargs) + elif composite_method == "subplots": + self._composite_mode = self._SUBPLOTS + self._rows, self._cols = self._subplots_rows_and_cols( + len(plot_actions), rows=init_kwargs.get('rows'), cols=init_kwargs.get('cols'), + arrange=init_kwargs.pop('arrange', "rows"), + ) + init_kwargs = ChainMap({'rows': self._rows, 'cols': self._cols}, init_kwargs) + return self._init_figure_subplots(**init_kwargs) + else: + raise ValueError(f"Unknown composite method '{composite_method}'") + + def _init_figure(self, **kwargs): + raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a _init_figure method.") + + def _init_figure_same_axes(self, *args, **kwargs): + return self._init_figure(*args, **kwargs) + + def _init_figure_multiple_axes(self, multi_axes, plot_actions, **kwargs): + figure = self._init_figure() + + if len(multi_axes) > 2: + raise ValueError(f"{self.__class__.__name__} doesn't support more than one multiple axes.") + + for axis in multi_axes: + self._multi_axes[axis] = self._init_multiaxis(axis, len(plot_actions)) + + return figure + + def _init_multiaxis(self, axis, n): + raise NotImplementedError(f"{self.__class__.__name__} doesn't 
implement a _init_multiaxis method.") + + def _init_figure_animated(self, **kwargs): + raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a _init_figure_animated method.") + + def _init_figure_subplots(self, rows, cols, **kwargs): + raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a _init_figure_subplots method.") + + def _subplots_rows_and_cols(self, n: int, rows: Optional[int] = None, cols: Optional[int] = None, + arrange: Literal["rows", "cols", "square"] = "rows") -> Tuple[int, int]: + """ Returns the number of rows and columns for a subplot grid. """ + if rows is None and cols is None: + if arrange == 'rows': + rows = n + cols = 1 + elif arrange == 'cols': + cols = n + rows = 1 + elif arrange == 'square': + cols = n ** 0.5 + rows = n ** 0.5 + # we will correct so it *fits*, always have more columns + rows, cols = int(rows), int(cols) + cols = n // rows + min(1, n % rows) + elif rows is None and cols is not None: + # ensure it is large enough by adding 1 if they don't add up + rows = n // cols + min(1, n % cols) + elif cols is None and rows is not None: + # ensure it is large enough by adding 1 if they don't add up + cols = n // rows + min(1, n % rows) + + rows, cols = int(rows), int(cols) + + if cols * rows < n: + warn(f"requested {n} subplots on a {rows}x{cols} grid layout. 
{n - cols*rows} plots will be missing.") + + return rows, cols + + def _composite_iter(self, mode, plot_actions): + if mode == self._NONE: + return plot_actions + elif mode == self._SAME_AXES: + return self._iter_same_axes(plot_actions) + elif mode == self._MULTIAXIS: + return self._iter_multiaxis(plot_actions) + elif mode == self._SUBPLOTS: + return self._iter_subplots(plot_actions) + elif mode == self._ANIMATION: + return self._iter_animation(plot_actions) + else: + raise ValueError(f"Unknown composite mode '{mode}'") + + def _iter_same_axes(self, plot_actions): + return plot_actions + + def _iter_multiaxis(self, plot_actions): + raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a _iter_multiaxis method.") + + def _iter_subplots(self, plot_actions): + raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a _iter_subplots method.") + + def _iter_animation(self, plot_actions): + raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a _iter_animation method.") + + def clear(self): + """Clears the figure so that we can draw again.""" + pass + + def show(self): + pass + + def init_3D(self): + """Called if functions that draw in 3D are going to be called.""" + return + + def init_coloraxis(self, name, cmin=None, cmax=None, cmid=None, colorscale=None, **kwargs): + """Initializes a color axis to be used by the drawing functions""" + self._coloraxes[name] = { + 'cmin': cmin, + 'cmax': cmax, + 'cmid': cmid, + 'colorscale': colorscale, + **kwargs + } + + def draw_line(self, x, y, name=None, line={}, marker={}, text=None, row=None, col=None, **kwargs): + """Draws a line satisfying the specifications + + Parameters + ----------- + x: array-like + the coordinates of the points along the X axis. + y: array-like + the coordinates of the points along the Y axis. + name: str, optional + the name of the line + line: dict, optional + specifications for the line style, following plotly standards. 
The backend + should at least be able to implement `line["color"]` and `line["width"]` + marker: dict, optional + specifications for the markers style, following plotly standards. The backend + should at least be able to implement `marker["color"]` and `marker["size"]` + text: str, optional + contains the text asigned to each marker. On plotly this is seen on hover, + other options could be annotating. However, it is not necessary that this + argument is supported. + row: int, optional + If the figure contains subplots, the row where to draw. + col: int, optional + If the figure contains subplots, the column where to draw. + **kwargs: + should allow other keyword arguments to be passed directly to the creation of + the line. This will of course be framework specific + """ + raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_line method.") + + def draw_multicolor_line(self, *args, line={}, row=None, col=None, **kwargs): + """By default, multicoloured lines are drawn simply by drawing scatter points.""" + marker = { + **kwargs.pop('marker', {}), + 'color': line.get('color'), + 'size': line.get('width'), + 'opacity': line.get('opacity'), + 'coloraxis': line.get('coloraxis') + } + self.draw_multicolor_scatter(*args, marker=marker, row=row, col=col, **kwargs) + + def draw_multisize_line(self, *args, line={}, row=None, col=None, **kwargs): + """By default, multisized lines are drawn simple by drawing scatter points.""" + marker = { + **kwargs.pop('marker', {}), + 'color': line.get('color'), + 'size': line.get('width'), + 'opacity': line.get('opacity'), + 'coloraxis': line.get('coloraxis') + } + self.draw_multisize_scatter(*args, marker=marker, row=row, col=col, **kwargs) + + def draw_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=None, row=None, col=None, **kwargs): + """Same as draw line, but to draw a line with an area. This is for example used to draw fatbands. 
+ + Parameters + ----------- + x: array-like + the coordinates of the points along the X axis. + y: array-like + the coordinates of the points along the Y axis. + name: str, optional + the name of the scatter + line: dict, optional + specifications for the line style, following plotly standards. The backend + should at least be able to implement `line["color"]` and `line["width"]`, but + it is very advisable that it supports also `line["opacity"]`. + text: str, optional + contains the text asigned to each marker. On plotly this is seen on hover, + other options could be annotating. However, it is not necessary that this + argument is supported. + dependent_axis: str, optional + The axis that contains the dependent variable. This is important because + the area is drawn in parallel to that axis. + row: int, optional + If the figure contains subplots, the row where to draw. + col: int, optional + If the figure contains subplots, the column where to draw. + **kwargs: + should allow other keyword arguments to be passed directly to the creation of + the scatter. This will of course be framework specific + """ + raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_area_line method.") + + def draw_multicolor_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=None, row=None, col=None, **kwargs): + """Draw a line with an area with multiple colours. + + Parameters + ----------- + x: array-like + the coordinates of the points along the X axis. + y: array-like + the coordinates of the points along the Y axis. + name: str, optional + the name of the scatter + line: dict, optional + specifications for the line style, following plotly standards. The backend + should at least be able to implement `line["color"]` and `line["width"]`, but + it is very advisable that it supports also `line["opacity"]`. + text: str, optional + contains the text asigned to each marker. On plotly this is seen on hover, + other options could be annotating. 
However, it is not necessary that this + argument is supported. + dependent_axis: str, optional + The axis that contains the dependent variable. This is important because + the area is drawn in parallel to that axis. + row: int, optional + If the figure contains subplots, the row where to draw. + col: int, optional + If the figure contains subplots, the column where to draw. + **kwargs: + should allow other keyword arguments to be passed directly to the creation of + the scatter. This will of course be framework specific + """ + raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_multicolor_area_line method.") + + def draw_multisize_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=None, row=None, col=None, **kwargs): + """Draw a line with an area with multiple colours. + + This is already usually supported by the normal draw_area_line. + + Parameters + ----------- + x: array-like + the coordinates of the points along the X axis. + y: array-like + the coordinates of the points along the Y axis. + name: str, optional + the name of the scatter + line: dict, optional + specifications for the line style, following plotly standards. The backend + should at least be able to implement `line["color"]` and `line["width"]`, but + it is very advisable that it supports also `line["opacity"]`. + text: str, optional + contains the text asigned to each marker. On plotly this is seen on hover, + other options could be annotating. However, it is not necessary that this + argument is supported. + dependent_axis: str, optional + The axis that contains the dependent variable. This is important because + the area is drawn in parallel to that axis. + row: int, optional + If the figure contains subplots, the row where to draw. + col: int, optional + If the figure contains subplots, the column where to draw. + **kwargs: + should allow other keyword arguments to be passed directly to the creation of + the scatter. 
This will of course be framework specific + """ + # Usually, multisized area lines are already supported. + return self.draw_area_line(x, y, name=name, line=line, text=text, dependent_axis=dependent_axis, row=row, col=col, **kwargs) + + def draw_scatter(self, x, y, name=None, marker={}, text=None, row=None, col=None, **kwargs): + """Draws a scatter satisfying the specifications + + Parameters + ----------- + x: array-like + the coordinates of the points along the X axis. + y: array-like + the coordinates of the points along the Y axis. + name: str, optional + the name of the scatter + marker: dict, optional + specifications for the markers style, following plotly standards. The backend + should at least be able to implement `marker["color"]` and `marker["size"]`, but + it is very advisable that it supports also `marker["opacity"]` and `marker["colorscale"]` + text: str, optional + contains the text asigned to each marker. On plotly this is seen on hover, + other options could be annotating. However, it is not necessary that this + argument is supported. + row: int, optional + If the figure contains subplots, the row where to draw. + col: int, optional + If the figure contains subplots, the column where to draw. + **kwargs: + should allow other keyword arguments to be passed directly to the creation of + the scatter. This will of course be framework specific + """ + raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_scatter method.") + + def draw_multicolor_scatter(self, *args, **kwargs): + """Draws a multicoloured scatter. + + Usually the normal scatter can already support this. + """ + # Usually, multicoloured scatter plots are already supported. + return self.draw_scatter(*args, **kwargs) + + def draw_multisize_scatter(self, *args, **kwargs): + """Draws a multisized scatter. + + Usually the normal scatter can already support this. + """ + # Usually, multisized scatter plots are already supported. 
+ return self.draw_scatter(*args, **kwargs) + + def draw_arrows(self, x, y, dxy, arrowhead_scale=0.2, arrowhead_angle=20, scale: float = 1, annotate: bool = False, row=None, col=None, **kwargs): + """Draws multiple arrows using the generic draw_line method. + + Parameters + ----------- + xy: np.ndarray of shape (n_arrows, 2) + the positions where the atoms start. + dxy: np.ndarray of shape (n_arrows, 2) + the arrow vector. + arrow_head_scale: float, optional + how big is the arrow head in comparison to the arrow vector. + arrowhead_angle: angle + the angle that the arrow head forms with the direction of the arrow (in degrees). + scale: float, optional + multiplying factor to display the arrows. It does not affect the underlying data, + therefore if the data is somehow displayed it should be without the scale factor. + annotate: + whether to annotate the arrows with the vector they represent. + row: int, optional + If the figure contains subplots, the row where to draw. + col: int, optional + If the figure contains subplots, the column where to draw. + """ + # Make sure we are treating with numpy arrays + xy = np.array([x, y]).T + dxy = np.array(dxy) * scale + + # Get the destination of the arrows + final_xy = xy + dxy + + # Convert from degrees to radians. + arrowhead_angle = np.radians(arrowhead_angle) + + # Get the rotation matrices to get the tips of the arrowheads + rot_matrix = np.array([[np.cos(arrowhead_angle), -np.sin(arrowhead_angle)], [np.sin(arrowhead_angle), np.cos(arrowhead_angle)]]) + inv_rot = np.linalg.inv(rot_matrix) + + # Calculate the tips of the arrow heads + arrowhead_tips1 = final_xy - (dxy*arrowhead_scale).dot(rot_matrix) + arrowhead_tips2 = final_xy - (dxy*arrowhead_scale).dot(inv_rot) + + # Now build an array with all the information to draw the arrows + # This has shape (n_arrows * 7, 2). The information to draw an arrow + # occupies 7 rows and the columns are the x and y coordinates. 
+ arrows = np.empty((xy.shape[0]*7, xy.shape[1]), dtype=np.float64) + + arrows[0::7] = xy + arrows[1::7] = final_xy + arrows[2::7] = np.nan + arrows[3::7] = arrowhead_tips1 + arrows[4::7] = final_xy + arrows[5::7] = arrowhead_tips2 + arrows[6::7] = np.nan + + # + hovertext = np.tile(dxy / scale, 7).reshape(dxy.shape[0] * 7, -1) + + if annotate: + # Add text annotations just at the tip of the arrows. + annotate_text = np.full((arrows.shape[0],), "", dtype=object) + annotate_text[4::7] = [str(xy / scale) for xy in dxy] + kwargs['text'] = list(annotate_text) + + return self.draw_line(arrows[:, 0], arrows[:, 1], hovertext=list(hovertext), row=row, col=col, **kwargs) + + def draw_line_3D(self, x, y, z, name=None, line={}, marker={}, text=None, row=None, col=None, **kwargs): + """Draws a 3D line satisfying the specifications. + + Parameters + ----------- + x: array-like + the coordinates of the points along the X axis. + y: array-like + the coordinates of the points along the Y axis. + z: array-like + the coordinates of the points along the Z axis. + name: str, optional + the name of the line + line: dict, optional + specifications for the line style, following plotly standards. The backend + should at least be able to implement `line["color"]` and `line["width"]` + marker: dict, optional + specifications for the markers style, following plotly standards. The backend + should at least be able to implement `marker["color"]` and `marker["size"]` + text: str, optional + contains the text asigned to each marker. On plotly this is seen on hover, + other options could be annotating. However, it is not necessary that this + argument is supported. + row: int, optional + If the figure contains subplots, the row where to draw. + col: int, optional + If the figure contains subplots, the column where to draw. + **kwargs: + should allow other keyword arguments to be passed directly to the creation of + the line. 
This will of course be framework specific + """ + raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_line_3D method.") + + def draw_multicolor_line_3D(self, *args, **kwargs): + """Draws a multicoloured 3D line.""" + self.draw_line_3D(*args, **kwargs) + + def draw_multisize_line_3D(self, *args, **kwargs): + """Draws a multisized 3D line.""" + self.draw_line_3D(*args, **kwargs) + + def draw_scatter_3D(self, x, y, z, name=None, marker={}, text=None, row=None, col=None, **kwargs): + """Draws a 3D scatter satisfying the specifications + + Parameters + ----------- + x: array-like + the coordinates of the points along the X axis. + y: array-like + the coordinates of the points along the Y axis. + z: array-like + the coordinates of the points along the Z axis. + name: str, optional + the name of the scatter + marker: dict, optional + specifications for the markers style, following plotly standards. The backend + should at least be able to implement `marker["color"]` and `marker["size"]` + text: str, optional + contains the text asigned to each marker. On plotly this is seen on hover, + other options could be annotating. However, it is not necessary that this + argument is supported. + row: int, optional + If the figure contains subplots, the row where to draw. + col: int, optional + If the figure contains subplots, the column where to draw. + **kwargs: + should allow other keyword arguments to be passed directly to the creation of + the scatter. This will of course be framework specific + """ + raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_scatter_3D method.") + + def draw_multicolor_scatter_3D(self, *args, **kwargs): + """Draws a multicoloured 3D scatter. + + Usually the normal 3D scatter can already support this. + """ + # Usually, multicoloured scatter plots are already supported. 
+ return self.draw_scatter_3D(*args, **kwargs) + + def draw_multisize_scatter_3D(self, *args, **kwargs): + """Draws a multisized 3D scatter. + + Usually the normal 3D scatter can already support this. + """ + # Usually, multisized scatter plots are already supported. + return self.draw_scatter_3D(*args, **kwargs) + + def draw_balls_3D(self, x, y, z, name=None, markers={}, row=None, col=None, **kwargs): + """Draws points as 3D spheres.""" + return NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_balls_3D method.") + + def draw_multicolor_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, **kwargs): + """Draws points as 3D spheres with different colours. + + If marker_color is an array of numbers, a coloraxis is created and values are converted to rgb. + """ + + kwargs['marker'] = marker.copy() + + if 'color' in marker and np.array(marker['color']).dtype in (int, float): + coloraxis = kwargs['marker']['coloraxis'] + coloraxis = self._coloraxes[coloraxis] + + kwargs['marker']['color'] = values_to_colors(kwargs['marker']['color'], coloraxis['colorscale'] or "viridis") + + return self.draw_balls_3D(x, y, z, name=name, row=row, col=col, **kwargs) + + def draw_multisize_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, **kwargs): + """Draws points as 3D spheres with different sizes. + + Usually supported by the normal draw_balls_3D + """ + return self.draw_balls_3D(x, y, z, name=name, row=row, col=col, **kwargs) + + def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_scale=0.3, arrowhead_angle=15, scale: float = 1, row=None, col=None, **kwargs): + """Draws multiple 3D arrows using the generic draw_line_3D method. + + Parameters + ----------- + x: np.ndarray of shape (n_arrows, ) + the X coordinates of the arrow's origin. + y: np.ndarray of shape (n_arrows, ) + the Y coordinates of the arrow's origin. + z: np.ndarray of shape (n_arrows, ) + the Z coordinates of the arrow's origin. 
+ dxyz: np.ndarray of shape (n_arrows, 3) + the arrow vector. + arrowhead_scale: float, optional + how big is the arrow head in comparison to the arrow vector. + arrowhead_angle: angle + the angle that the arrow head forms with the direction of the arrow (in degrees). + scale: float, optional + multiplying factor to display the arrows. It does not affect the underlying data, + therefore if the data is somehow displayed it should be without the scale factor. + row: int, optional + If the figure contains subplots, the row where to draw. + col: int, optional + If the figure contains subplots, the column where to draw. + """ + # Make sure we are dealing with numpy arrays + xyz = np.array([x, y, z]).T + dxyz = np.array(dxyz) * scale + + # Get the destination of the arrows + final_xyz = xyz + dxyz + + # Convert from degrees to radians. + arrowhead_angle = np.radians(arrowhead_angle) + + # Calculate the arrowhead positions. This is a bit more complex than the 2D case, + # since there's no unique plane to rotate all vectors. + # First, we get a unitary vector that is perpendicular to the direction of the arrow in xy. + dxy_norm = np.linalg.norm(dxyz[:, :2], axis=1) + # Some vectors might be only in the Z direction, which will result in dxy_norm being 0. + # We avoid problems by dividing only where dxy_norm != 0. + dx_p = np.divide(dxyz[:, 1], dxy_norm, where=dxy_norm != 0, out=np.zeros(dxyz.shape[0], dtype=np.float64)) + dy_p = np.divide(-dxyz[:, 0], dxy_norm, where=dxy_norm != 0, out=np.ones(dxyz.shape[0], dtype=np.float64)) + + # And then we build the rotation matrices. Since each arrow needs a unique rotation matrix, + # we will have n 3x3 matrices, where n is the number of arrows, for each arrowhead tip. + c = np.cos(arrowhead_angle) + s = np.sin(arrowhead_angle) + + # Rotation matrix to build the first arrowhead tip positions. 
+ rot_matrices = np.array( + [[c + (dx_p ** 2) * (1 - c), dx_p * dy_p * (1 - c), dy_p * s], + [dy_p * dx_p * (1 - c), c + (dy_p ** 2) * (1 - c), -dx_p * s], + [-dy_p * s, dx_p * s, np.full_like(dx_p, c)]]) + + # The opposite rotation matrix, to get the other arrowhead's tip positions. + inv_rots = rot_matrices.copy() + inv_rots[[0, 1, 2, 2], [2, 2, 0, 1]] *= -1 + + # Calculate the tips of the arrow heads. + arrowhead_tips1 = final_xyz - np.einsum("ij...,...j->...i", rot_matrices, dxyz * arrowhead_scale) + arrowhead_tips2 = final_xyz - np.einsum("ij...,...j->...i", inv_rots, dxyz * arrowhead_scale) + + # Now build an array with all the information to draw the arrows + # This has shape (n_arrows * 7, 3). The information to draw an arrow + # occupies 7 rows and the columns are the x and y coordinates. + arrows = np.empty((xyz.shape[0]*7, 3)) + + arrows[0::7] = xyz + arrows[1::7] = final_xyz + arrows[2::7] = np.nan + arrows[3::7] = arrowhead_tips1 + arrows[4::7] = final_xyz + arrows[5::7] = arrowhead_tips2 + arrows[6::7] = np.nan + + return self.draw_line_3D(arrows[:, 0], arrows[:, 1], arrows[:, 2], row=row, col=col, **kwargs) + + def draw_heatmap(self, values, x=None, y=None, name=None, zsmooth=False, coloraxis=None, row=None, col=None): + """Draws a heatmap following the specifications.""" + raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_heatmap method.") + + def draw_mesh_3D(self, vertices, faces, color=None, opacity=None, name=None, row=None, col=None, **kwargs): + """Draws a 3D mesh following the specifications.""" + raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_mesh_3D method.") + + def set_axis(self, **kwargs): + """Sets the axis parameters. + + The specification for the axes is exactly the plotly one. This is to have a good + reference for consistency. Other frameworks should translate the calls to their + functionality. 
+ """ + + def set_axes_equal(self): + """Sets the axes equal.""" + raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a set_axes_equal method.") + + def to(self, key: str): + """Converts the figure to another backend. + + Parameters + ----------- + key: str + the backend to convert to. + """ + return BACKENDS[key](self.plot_actions) + +def get_figure(backend: str, plot_actions, *args, **kwargs) -> Figure: + """Get a figure object. + + Parameters + ---------- + backend : {"plotly", "matplotlib", "py3dmol", "blender"} + the backend to use + plot_actions : list of callable + the plot actions to perform + *args, **kwargs + passed to the figure constructor + """ + return BACKENDS[backend](plot_actions, *args, **kwargs) \ No newline at end of file diff --git a/src/sisl/viz/figure/matplotlib.py b/src/sisl/viz/figure/matplotlib.py new file mode 100644 index 0000000000..56fbb18bab --- /dev/null +++ b/src/sisl/viz/figure/matplotlib.py @@ -0,0 +1,345 @@ +import itertools + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.collections import LineCollection +from matplotlib.pyplot import Normalize +from mpl_toolkits.axisartist.parasite_axes import HostAxes, ParasiteAxes + +from sisl.messages import warn + +from .figure import Figure + + +class MatplotlibFigure(Figure): + """Generic backend for the matplotlib framework. + + On initialization, `matplotlib.pyplot.subplots` is called and the figure and and + axes obtained are stored under `self.figure` and `self.axeses`, respectively. + If an attribute is not found on the backend, it is looked for + in the axes. + + On initialization, we also take the class attribute `_axes_defaults` (a dictionary) + and run `self.axes.update` with those parameters. Therefore this parameter can be used + to provide default parameters for the axes. 
+ """ + + _axes_defaults = {} + + def _init_figure(self, *args, **kwargs): + self.figure, self.axes = self._init_plt_figure() + self._init_axes() + return self.figure + + # def draw_on(self, axes, axes_indices=None): + # """Draws this plot in a different figure. + + # Parameters + # ----------- + # axes: Plot, PlotlyBackend or matplotlib.axes.Axes + # The axes to draw this plot in. + # """ + # if isinstance(axes, Plot): + # axes = axes._backend.axes + # elif isinstance(axes, MatplotlibBackend): + # axes = axes.axes + + # if axes_indices is not None: + # axes = axes[axes_indices] + + # if not isinstance(axes, Axes): + # raise TypeError(f"{self.__class__.__name__} was provided a {axes.__class__.__name__} to draw on.") + + # self_axes = self.axes + # self.axes = axes + # self._init_axes() + # self._plot.get_figure(backend=self._backend_name, clear_fig=False) + # self.axes = self_axes + + def _init_plt_figure(self): + """Initializes the matplotlib figure and axes + + Returns + -------- + Figure: + the matplotlib figure of this plot. + Axes: + the matplotlib axes of this plot. 
+ """ + return plt.subplots() + + def _init_axes(self): + """Does some initial modification on the axes.""" + self.axes.update(self._axes_defaults) + + def _init_figure_subplots(self, rows, cols, **kwargs): + self.figure, self.axes = plt.subplots(rows, cols) + + # Normalize the axes array to have two dimensions + if rows == 1 and cols == 1: + self.axes = np.array([[self.axes]]) + elif rows == 1: + self.axes = np.expand_dims(self.axes, axis=0) + elif cols == 1: + self.axes = np.expand_dims(self.axes, axis=1) + + return self.figure + + def _get_subplot_axes(self, row=None, col=None) -> plt.Axes: + if row is None or col is None: + # This is not a subplot + return self.axes + # Otherwise, it is indeed a subplot, so we get the axes + return self.axes[row, col] + + def _iter_subplots(self, plot_actions): + + it = zip(itertools.product(range(self._rows), range(self._cols)), plot_actions) + + # Start assigning each plot to a position of the layout + for i, ((row, col), section_actions) in enumerate(it): + + row_col_kwargs = {"row": row, "col": col} + # active_axes = { + # ax: f"{ax}axis" if row == 0 and col == 0 else f"{ax}axis{i + 1}" + # for ax in "xyz" + # } + + sanitized_section_actions = [] + for action in section_actions: + action_name = action['method'] + if action_name.startswith("draw_"): + action = {**action, "kwargs": {**action.get("kwargs", {}), **row_col_kwargs}} + elif action_name.startswith("set_ax"): + action = {**action, "kwargs": {**action.get("kwargs", {}), **row_col_kwargs}} + + sanitized_section_actions.append(action) + + yield sanitized_section_actions + + + def _init_figure_multiple_axes(self, multi_axes, plot_actions, **kwargs): + + if len(multi_axes) > 1: + self.figure = plt.figure() + self.axes = self.figure.add_axes([0.15, 0.1, 0.65, 0.8], axes_class=HostAxes) + self._init_axes() + + multi_axis = "xy" + else: + self.figure = self._init_figure() + multi_axis = multi_axes[0] + + self._multi_axes[multi_axis] = self._init_multiaxis(multi_axis, 
len(plot_actions)) + + return self.figure + + def _init_multiaxis(self, axis, n): + + axes = [self.axes] + for i in range(n - 1): + if axis == "x": + axes.append(self.axes.twiny()) + elif axis == "y": + axes.append(self.axes.twinx()) + elif axis == "xy": + new_axes = ParasiteAxes(self.axes, visible=True) + + new_axes.axis["right"].set_visible(True) + new_axes.axis["right"].major_ticklabels.set_visible(True) + new_axes.axis["right"].label.set_visible(True) + new_axes.axis["top"].set_visible(True) + new_axes.axis["top"].major_ticklabels.set_visible(True) + new_axes.axis["top"].label.set_visible(True) + + self.axes.axis["right"].set_visible(False) + self.axes.axis["top"].set_visible(False) + + self.axes.parasites.append(new_axes) + axes.append(new_axes) + + return axes + + def _iter_multiaxis(self, plot_actions): + multi_axis = list(self._multi_axes)[0] + for i, section_actions in enumerate(plot_actions): + axes = self._multi_axes[multi_axis][i] + + sanitized_section_actions = [] + for action in section_actions: + action_name = action['method'] + if action_name.startswith("draw_"): + action = {**action, "kwargs": {**action.get("kwargs", {}), "_axes": axes}} + elif action_name == "set_axis": + action = {**action, "kwargs": {**action.get("kwargs", {}), "_axes": axes}} + + sanitized_section_actions.append(action) + + yield sanitized_section_actions + + def __getattr__(self, key): + if key != "axes": + return getattr(self.axes, key) + raise AttributeError(key) + + def clear(self, layout=False): + """ Clears the plot canvas so that data can be reset + + Parameters + -------- + layout: boolean, optional + whether layout should also be deleted + """ + if layout: + self.axes.clear() + + for artist in self.axes.lines + self.axes.collections: + artist.remove() + + return self + + def get_ipywidget(self): + return self.figure + + def show(self, *args, **kwargs): + return self.figure.show(*args, **kwargs) + + def draw_line(self, x, y, name=None, line={}, marker={}, text=None, 
row=None, col=None, _axes=None, **kwargs): + marker_format = marker.get("symbol", "o") if marker else None + marker_color = marker.get("color") + + axes = _axes or self._get_subplot_axes(row=row, col=col) + + return axes.plot( + x, y, color=line.get("color"), linewidth=line.get("width", 1), + marker=marker_format, markersize=marker.get("size"), markerfacecolor=marker_color, markeredgecolor=marker_color, + label=name + ) + + def draw_multicolor_line(self, x, y, name=None, line={}, marker={}, text=None, row=None, col=None, _axes=None, **kwargs): + # This is heavily based on + # https://matplotlib.org/stable/gallery/lines_bars_and_markers/multicolored_line.html + + color = line.get("color") + + if not np.issubdtype(np.array(color).dtype, np.number): + return self.draw_multicolor_scatter(x, y, name=name, marker=line, text=text, row=row, col=col, _axes=_axes, **kwargs) + + points = np.array([x, y]).T.reshape(-1, 1, 2) + segments = np.concatenate([points[:-1], points[1:]], axis=1) + + lc_kwargs = {} + coloraxis = line.get("coloraxis") + if coloraxis is not None: + coloraxis = self._coloraxes.get(coloraxis) + lc_kwargs["cmap"] = coloraxis.get("colorscale") + if coloraxis.get("cmin") is not None: + lc_kwargs["norm"] = Normalize(coloraxis['cmin'], coloraxis['cmax']) + + lc = LineCollection(segments, **lc_kwargs) + + # Set the values used for colormapping + lc.set_array(line.get("color")) + lc.set_linewidth(line.get("width", 1)) + + axes = _axes or self._get_subplot_axes(row=row, col=col) + + axes.add_collection(lc) + + #self._colorbar = axes.add_collection(lc) + + def draw_multisize_line(self, x, y, name=None, line={}, marker={}, text=None, row=None, col=None, _axes=None, **kwargs): + points = np.array([x, y]).T.reshape(-1, 1, 2) + segments = np.concatenate([points[:-1], points[1:]], axis=1) + + lc = LineCollection(segments) + + # Set the values used for colormapping + lc.set_linewidth(line.get("width", 1)) + + axes = _axes or self._get_subplot_axes(row=row, col=col) + + 
axes.add_collection(lc) + + def draw_area_line(self, x, y, line={}, name=None, dependent_axis=None, row=None, col=None, _axes=None, **kwargs): + + width = line.get('width') + if width is None: + width = 1 + spacing = width / 2 + + axes = _axes or self._get_subplot_axes(row=row, col=col) + + if dependent_axis in ("y", None): + axes.fill_between( + x, y + spacing, y - spacing, + color=line.get('color'), label=name + ) + elif dependent_axis == "x": + axes.fill_betweenx( + y, x + spacing, x - spacing, + color=line.get('color'), label=name + ) + else: + raise ValueError(f"dependent_axis must be one of 'x', 'y', or None, but was {dependent_axis}") + + def draw_scatter(self, x, y, name=None, marker={}, text=None, zorder=2, row=None, col=None, _axes=None, **kwargs): + axes = _axes or self._get_subplot_axes(row=row, col=col) + try: + return axes.scatter(x, y, c=marker.get("color"), s=marker.get("size", 1), cmap=marker.get("colorscale"), alpha=marker.get("opacity"), label=name, zorder=zorder, **kwargs) + except TypeError as e: + if str(e) == "alpha must be a float or None": + warn(f"Your matplotlib version doesn't support multiple opacity values, please upgrade to >=3.4 if you want to use opacity.") + return axes.scatter(x, y, c=marker.get("color"), s=marker.get("size", 1), cmap=marker.get("colorscale"), label=name, zorder=zorder, **kwargs) + else: + raise e + + def draw_multicolor_scatter(self, *args, **kwargs): + marker = {**kwargs.pop("marker",{})} + coloraxis = marker.get("coloraxis") + if coloraxis is not None: + coloraxis = self._coloraxes.get(coloraxis) + marker["colorscale"] = coloraxis.get("colorscale") + return super().draw_multicolor_scatter(*args, marker=marker, **kwargs) + + def draw_heatmap(self, values, x=None, y=None, name=None, zsmooth=False, coloraxis=None, row=None, col=None, _axes=None): + + extent = None + if x is not None and y is not None: + extent = [x[0], x[-1], y[0], y[-1]] + + axes = _axes or self._get_subplot_axes(row=row, col=col) + + coloraxis = 
self._coloraxes.get(coloraxis, {}) + colorscale = coloraxis.get("colorscale") + vmin = coloraxis.get("cmin") + vmax = coloraxis.get("cmax") + + axes.imshow( + values, + cmap=colorscale, + vmin=vmin, vmax=vmax, + label=name, extent=extent, + origin="lower" + ) + + def set_axis(self, axis, range=None, title="", tickvals=None, ticktext=None, showgrid=False, row=None, col=None, _axes=None, **kwargs): + axes = _axes or self._get_subplot_axes(row=row, col=col) + + if range is not None: + updater = getattr(axes, f'set_{axis}lim') + updater(*range) + + if title: + updater = getattr(axes, f'set_{axis}label') + updater(title) + + if tickvals is not None: + updater = getattr(axes, f'set_{axis}ticks') + updater(ticks=tickvals, labels=ticktext) + + axes.grid(visible=showgrid, axis=axis) + + def set_axes_equal(self, row=None, col=None, _axes=None): + axes = _axes or self._get_subplot_axes(row=row, col=col) + axes.axis("equal") \ No newline at end of file diff --git a/src/sisl/viz/figure/plotly.py b/src/sisl/viz/figure/plotly.py new file mode 100644 index 0000000000..264aa8a2c1 --- /dev/null +++ b/src/sisl/viz/figure/plotly.py @@ -0,0 +1,563 @@ +import itertools +from typing import Optional, Sequence + +import numpy as np +import plotly.graph_objs as go +import plotly.io as pio + +from ..processors.coords import sphere +from .figure import Figure + +# Special plotly templates for sisl +pio.templates["sisl"] = go.layout.Template( + layout={ + "plot_bgcolor": "white", + "paper_bgcolor": "white", + **{f"{ax}_{key}": val for ax, (key, val) in itertools.product( + ("xaxis", "yaxis"), + (("visible", True), ("showline", True), ("linewidth", 1), ("mirror", True), + ("color", "black"), ("showgrid", False), ("gridcolor", "#ccc"), ("gridwidth", 1), + ("zeroline", False), ("zerolinecolor", "#ccc"), ("zerolinewidth", 1), + ("ticks", "outside"), ("ticklen", 5), ("ticksuffix", " ")) + )}, + "hovermode": "closest", + "scene": { + **{f"{ax}_{key}": val for ax, (key, val) in itertools.product( + 
("xaxis", "yaxis", "zaxis"), + (("visible", True), ("showline", True), ("linewidth", 1), ("mirror", True), + ("color", "black"), ("showgrid", + False), ("gridcolor", "#ccc"), ("gridwidth", 1), + ("zeroline", False), ("zerolinecolor", + "#ccc"), ("zerolinewidth", 1), + ("ticks", "outside"), ("ticklen", 5), ("ticksuffix", " ")) + )}, + } + #"editrevision": True + #"title": {"xref": "paper", "x": 0.5, "text": "Whhhhhhhat up", "pad": {"b": 0}} + }, +) + +pio.templates["sisl_dark"] = go.layout.Template( + layout={ + "plot_bgcolor": "black", + "paper_bgcolor": "black", + **{f"{ax}_{key}": val for ax, (key, val) in itertools.product( + ("xaxis", "yaxis"), + (("visible", True), ("showline", True), ("linewidth", 1), ("mirror", True), + ("color", "white"), ("showgrid", + False), ("gridcolor", "#ccc"), ("gridwidth", 1), + ("zeroline", False), ("zerolinecolor", "#ccc"), ("zerolinewidth", 1), + ("ticks", "outside"), ("ticklen", 5), ("ticksuffix", " ")) + )}, + "font": {'color': 'white'}, + "hovermode": "closest", + "scene": { + **{f"{ax}_{key}": val for ax, (key, val) in itertools.product( + ("xaxis", "yaxis", "zaxis"), + (("visible", True), ("showline", True), ("linewidth", 1), ("mirror", True), + ("color", "white"), ("showgrid", + False), ("gridcolor", "#ccc"), ("gridwidth", 1), + ("zeroline", False), ("zerolinecolor", + "#ccc"), ("zerolinewidth", 1), + ("ticks", "outside"), ("ticklen", 5), ("ticksuffix", " ")) + )}, + } + #"editrevision": True + #"title": {"xref": "paper", "x": 0.5, "text": "Whhhhhhhat up", "pad": {"b": 0}} + }, +) + +# This will be the default one for sisl plots +# Maybe we should pass it explicitly to plots instead of making it the default +# so that it doesn't affect plots outside sisl. +pio.templates.default = "sisl" + +class PlotlyFigure(Figure): + """Generic canvas for the plotly framework. + + On initialization, a plotly.graph_objs.Figure object is created and stored + under `self.figure`. 
If an attribute is not found on the backend, it is looked for + in the figure. Therefore, you can apply all the methods that are appliable to a plotly + figure! + + On initialization, we also take the class attribute `_layout_defaults` (a dictionary) + and run `update_layout` with those parameters. + """ + _multi_axis = None + + _layout_defaults = {} + + def _init_figure(self, *args, **kwargs): + self.figure = go.Figure() + self.update_layout(**self._layout_defaults) + return self + + def _init_figure_subplots(self, rows, cols, **kwargs): + + figure = self._init_figure() + + figure.set_subplots(**{ + "rows": rows, "cols": cols, **kwargs, + }) + + return figure + + def _iter_subplots(self, plot_actions): + + it = zip(itertools.product(range(self._rows), range(self._cols)), plot_actions) + + # Start assigning each plot to a position of the layout + for i, ((row, col), section_actions) in enumerate(it): + + row_col_kwargs = {"row": row + 1, "col": col + 1} + active_axes = { + ax: f"{ax}axis" if row == 0 and col == 0 else f"{ax}axis{i + 1}" + for ax in "xyz" + } + + sanitized_section_actions = [] + for action in section_actions: + action_name = action['method'] + if action_name.startswith("draw_"): + action = {**action, "kwargs": {**action.get("kwargs", {}), **row_col_kwargs}} + elif action_name.startswith("set_ax"): + action = {**action, "kwargs": {**action.get("kwargs", {}), "_active_axes": active_axes}} + + sanitized_section_actions.append(action) + + yield sanitized_section_actions + + def _init_multiaxis(self, axis, n): + + axes = [f"{axis}{i + 1}" if i > 0 else axis for i in range(n) ] + layout_axes = [f"{axis}axis{i + 1}" if i > 0 else f"{axis}axis" for i in range(n) ] + if axis == "x": + sides = ["bottom", "top"] + elif axis == "y": + sides = ["left", "right"] + else: + raise ValueError(f"Multiple axis are only supported for 'x' or 'y'") + + layout_updates = {} + for ax, side in zip(layout_axes, itertools.cycle(sides)): + layout_updates[ax] = {'side': side, 
'overlaying': axis} + layout_updates[f"{axis}axis"]['overlaying'] = None + self.update_layout(**layout_updates) + + return layout_axes + + def _iter_multiaxis(self, plot_actions): + + for i, section_actions in enumerate(plot_actions): + active_axes = {ax: v[i] for ax, v in self._multi_axes.items()} + active_axes_kwargs = {f"{ax}axis": v.replace("axis", "") for ax, v in active_axes.items()} + + sanitized_section_actions = [] + for action in section_actions: + action_name = action['method'] + if action_name.startswith("draw_"): + action = {**action, "kwargs": {**action.get("kwargs", {}), **active_axes_kwargs}} + elif action_name.startswith("set_ax"): + action = {**action, "kwargs": {**action.get("kwargs", {}), "_active_axes": active_axes}} + + sanitized_section_actions.append(action) + + yield sanitized_section_actions + + def _init_figure_animated(self, frame_names: Optional[Sequence[str]] = None, frame_duration: int = 500, transition: int = 300, redraw: bool = False, **kwargs): + self._animation_settings = { + "frame_names": frame_names, + "frame_duration": frame_duration, + "transition": transition, + "redraw": redraw, + } + self._animate_frame_names = frame_names + self._animate_init_kwargs = kwargs + return self._init_figure(**kwargs) + + def _iter_animation(self, plot_actions): + + frame_duration = self._animation_settings["frame_duration"] + transition = self._animation_settings["transition"] + redraw = self._animation_settings["redraw"] + + frame_names = self._animation_settings["frame_names"] + if frame_names is None: + frame_names = [i for i in range(len(plot_actions))] + + frame_names = [str(name) for name in frame_names] + + frames = [] + for i, section_actions in enumerate(plot_actions): + + yield section_actions + + # Create a frame and append it + frames.append(go.Frame(name=frame_names[i],data=self.figure.data, layout=self.figure.layout)) + + # Reinit the figure + self._init_figure(**self._animate_init_kwargs) + + 
self.figure.update(data=frames[0].data, frames=frames) + + slider_steps = [ + {"args": [ + [frame["name"]], + {"frame": {"duration": int(frame_duration), "redraw": redraw}, + "mode": "immediate", + "transition": {"duration": transition}} + ], + "label": frame["name"], + "method": "animate"} for frame in self.figure.frames + ] + + slider = { + "active": 0, + "yanchor": "top", + "xanchor": "left", + "currentvalue": { + "font": {"size": 20}, + #"prefix": "Bands file:", + "visible": True, + "xanchor": "right" + }, + #"transition": {"duration": 300, "easing": "cubic-in-out"}, + "pad": {"b": 10, "t": 50}, + "len": 0.9, + "x": 0.1, + "y": 0, + "steps": slider_steps + } + + # Buttons to play and pause the animation + updatemenus = [ + + {'type': 'buttons', + 'buttons': [ + { + 'label': '▶', + 'method': 'animate', + 'args': [None, {"frame": {"duration": int(frame_duration), "redraw": redraw}, + "fromcurrent": True, "transition": {"duration": 100}}], + }, + + { + 'label': '⏸', + 'method': 'animate', + 'args': [[None], {"frame": {"duration": 0}, "redraw": redraw, + 'mode': 'immediate', + "transition": {"duration": 0}}], + } + ]} + ] + + self.update_layout(sliders=[slider], updatemenus=updatemenus) + + def __getattr__(self, key): + if key != "figure": + return getattr(self.figure, key) + raise AttributeError(key) + + def show(self, *args, **kwargs): + return self.figure.show(*args, **kwargs) + + def clear(self, frames=True, layout=False): + """ Clears the plot canvas so that data can be reset + + Parameters + -------- + frames: boolean, optional + whether frames should also be deleted + layout: boolean, optional + whether layout should also be deleted + """ + self.figure.data = [] + + if frames: + self.figure.frames = [] + + if layout: + self.figure.layout = {} + + return self + + # -------------------------------- + # METHODS TO STANDARIZE BACKENDS + # -------------------------------- + def init_coloraxis(self, name, cmin=None, cmax=None, cmid=None, colorscale=None, 
**kwargs): + if len(self._coloraxes) == 0: + kwargs['ax_name'] = "coloraxis" + else: + kwargs['ax_name'] = f'coloraxis{len(self._coloraxes) + 1}' + + super().init_coloraxis(name, cmin, cmax, cmid, colorscale, **kwargs) + + ax_name = kwargs['ax_name'] + self.update_layout(**{ax_name: {"colorscale": colorscale, "cmin": cmin, "cmax": cmax, "cmid": cmid}}) + + def _get_coloraxis_name(self, coloraxis: Optional[str]): + + if coloraxis in self._coloraxes: + return self._coloraxes[coloraxis]['ax_name'] + else: + return coloraxis + + def _handle_multicolor_scatter(self, marker, scatter_kwargs): + + if 'coloraxis' in marker: + marker = marker.copy() + coloraxis = marker['coloraxis'] + + if coloraxis is not None: + scatter_kwargs['hovertemplate'] = "x: %{x:.2f}
<br>y: %{y:.2f}<br>
" + coloraxis + ": %{marker.color:.2f}" + marker['coloraxis'] = self._get_coloraxis_name(coloraxis) + + return marker + + def draw_line(self, x, y, name=None, line={}, row=None, col=None, **kwargs): + """Draws a line in the current plot.""" + opacity = kwargs.get("opacity", line.get("opacity", 1)) + + # Define the mode of the scatter trace. If markers or text are passed, + # we enforce the mode to show them. + mode = kwargs.pop("mode", "lines") + if kwargs.get("marker") and "markers" not in mode: + mode += "+markers" + if kwargs.get("text") and "text" not in mode: + mode += "+text" + + # Finally, we add the trace. + self.add_trace({ + 'type': 'scatter', + 'x': x, + 'y': y, + 'mode': mode, + 'name': name, + 'line': {k: v for k, v in line.items() if k != "opacity"}, + 'opacity': opacity, + **kwargs, + }, row=row, col=col) + + def draw_multicolor_line(self, *args, **kwargs): + kwargs['marker_line_width'] = 0 + + super().draw_multicolor_line(*args, **kwargs) + + def draw_multisize_line(self, *args, **kwargs): + kwargs['marker_line_width'] = 0 + + super().draw_multisize_line(*args, **kwargs) + + def draw_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=None, row=None, col=None, **kwargs): + chunk_x = x + chunk_y = y + + width = line.get('width') + if width is None: + width = 1 + chunk_spacing = width / 2 + + if dependent_axis is None: + # We draw the area line using the perpendicular direction to the line, because we don't know which + # direction should we draw it in. 
+ normal = np.array([- np.gradient(y), np.gradient(x)]).T + norms = normal / np.linalg.norm(normal, axis=1).reshape(-1, 1) + + x = [*(chunk_x + norms[:, 0] * chunk_spacing), *reversed(chunk_x - norms[:, 0] * chunk_spacing)] + y = [*(chunk_y + norms[:, 1] * chunk_spacing), *reversed(chunk_y - norms[:, 1] * chunk_spacing)] + elif dependent_axis == "y": + x = [*chunk_x, *reversed(chunk_x)] + y = [*(chunk_y + chunk_spacing), *reversed(chunk_y - chunk_spacing)] + elif dependent_axis == "x": + x = [*(chunk_x + chunk_spacing), *reversed(chunk_x - chunk_spacing)] + y = [*chunk_y, *reversed(chunk_y)] + else: + raise ValueError(f"Invalid dependent axis: {dependent_axis}") + + self.add_trace({ + "type": "scatter", + "mode": "lines", + "x": x, + "y": y, + "line": {"width": 0, "color": line.get('color')}, + "name": name, + "legendgroup": name, + "showlegend": kwargs.pop("showlegend", None), + "fill": "toself" + }, row=row, col=col) + + def draw_scatter(self, x, y, name=None, marker={}, **kwargs): + marker.pop("dash", None) + self.draw_line(x, y, name, marker=marker, mode="markers", **kwargs) + + def draw_multicolor_scatter(self, *args, **kwargs): + + kwargs['marker'] = self._handle_multicolor_scatter(kwargs['marker'], kwargs) + + super().draw_multicolor_scatter(*args, **kwargs) + + def draw_line_3D(self, x, y, z, **kwargs): + self.draw_line(x, y, type="scatter3d", z=z, **kwargs) + + def draw_multicolor_line_3D(self, x, y, z, **kwargs): + kwargs['line'] = self._handle_multicolor_scatter(kwargs['line'], kwargs) + + super().draw_multicolor_line_3D(x, y, z, **kwargs) + + def draw_scatter_3D(self, *args, **kwargs): + self.draw_line_3D(*args, mode="markers", **kwargs) + + def draw_multicolor_scatter_3D(self, *args, **kwargs): + + kwargs['marker'] = self._handle_multicolor_scatter(kwargs['marker'], kwargs) + + super().draw_multicolor_scatter_3D(*args, **kwargs) + + def draw_balls_3D(self, x, y, z, name=None, marker={}, **kwargs): + + style = {} + for k in ("size", "color", "opacity"): 
+ val = marker.get(k) + + if isinstance(val, (str, int, float)): + val = itertools.repeat(val) + + style[k] = val + + iterator = enumerate(zip(np.array(x), np.array(y), np.array(z), style["size"], style["color"], style["opacity"])) + + showlegend = True + for i, (sp_x, sp_y, sp_z, sp_size, sp_color, sp_opacity) in iterator: + self.draw_ball_3D( + xyz=[sp_x, sp_y, sp_z], + size=sp_size, color=sp_color, opacity=sp_opacity, + name=f"{name}_{i}", + legendgroup=name, showlegend=showlegend + ) + showlegend = False + + return + + def draw_ball_3D(self, xyz, size, color="gray", name=None, vertices=15, row=None, col=None, **kwargs): + self.add_trace({ + 'type': 'mesh3d', + **{key: val for key, val in sphere(center=xyz, r=size, vertices=vertices).items()}, + 'alphahull': 0, + 'color': color, + 'showscale': False, + 'name': name, + 'meta': ['({:.2f}, {:.2f}, {:.2f})'.format(*xyz)], + 'hovertemplate': '%{meta[0]}', + **kwargs + }, row=None, col=None) + + def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_angle=20, arrowhead_scale=0.3, scale: float = 1, row=None, col=None, **kwargs): + """Draws 3D arrows in plotly using a combination of a scatter3D and a Cone trace.""" + # Make sure we are dealing with numpy arrays + xyz = np.array([x, y, z]).T + dxyz = np.array(dxyz) * scale + + final_xyz = xyz + dxyz + + line = kwargs.get("line", {}).copy() + color = line.get("color") + if color is None: + color = "red" + line['color'] = color + # 3D lines don't support opacity + line.pop("opacity", None) + + name = kwargs.get("name", "Arrows") + + arrows_coords = np.empty((xyz.shape[0]*3, 3), dtype=np.float64) + + arrows_coords[0::3] = xyz + arrows_coords[1::3] = final_xyz + arrows_coords[2::3] = np.nan + + conebase_xyz = xyz + (1 - arrowhead_scale) * dxyz + + rows_cols = {} + if row is not None: + rows_cols['rows'] = [row, row] + if col is not None: + rows_cols['cols'] = [col, col] + + + self.figure.add_traces([{ + "x": arrows_coords[:, 0], + "y": arrows_coords[:, 1], + "z": arrows_coords[:, 
2], + "mode": "lines", + "type": "scatter3d", + "hoverinfo": "none", + "line": line, + "legendgroup": name, + "name": f"{name} lines", + "showlegend": False, + }, + { + "type": "cone", + "x": conebase_xyz[:, 0], + "y": conebase_xyz[:, 1], + "z": conebase_xyz[:, 2], + "u": arrowhead_scale * dxyz[:, 0], + "v": arrowhead_scale * dxyz[:, 1], + "w": arrowhead_scale * dxyz[:, 2], + "hovertemplate": "[%{u}, %{v}, %{w}]", + "sizemode": "absolute", + "sizeref": arrowhead_scale * np.linalg.norm(dxyz, axis=1).max() / 2, + "colorscale": [[0, color], [1, color]], + "showscale": False, + "legendgroup": name, + "name": name, + "showlegend": True, + }], **rows_cols) + + def draw_heatmap(self, values, x=None, y=None, name=None, zsmooth=False, coloraxis=None, row=None, col=None): + + self.add_trace({ + 'type': 'heatmap', 'z': values, + 'x': x, 'y': y, + 'name': name, + 'zsmooth': zsmooth, + 'coloraxis': self._get_coloraxis_name(coloraxis), + # **kwargs + }, row=row, col=col) + + def draw_mesh_3D(self, vertices, faces, color=None, opacity=None, name=None, row=None, col=None, **kwargs): + + x, y, z = vertices.T + I, J, K = faces.T + + self.add_trace(dict( + type="mesh3d", + x=x, y=y, z=z, + i=I, j=J, k=K, + color=color, + opacity=opacity, + name=name, + showlegend=True, + **kwargs + ), row=row, col=col) + + def set_axis(self, axis, _active_axes={}, **kwargs): + if axis in _active_axes: + ax_name = _active_axes[axis] + else: + ax_name = f"{axis}axis" + + updates = {} + if ax_name.endswith("axis"): + updates = {f"scene_{ax_name}": kwargs} + if axis != "z": + updates.update({ax_name: kwargs}) + + self.update_layout(**updates) + + def set_axes_equal(self, _active_axes={}): + x_axis = _active_axes.get("x", "xaxis") + y_axis = _active_axes.get("y", "yaxis").replace("axis", "") + + self.update_layout({x_axis: {"scaleanchor": y_axis, "scaleratio": 1}}) + self.update_layout(scene_aspectmode="data") \ No newline at end of file diff --git a/src/sisl/viz/figure/py3dmol.py 
b/src/sisl/viz/figure/py3dmol.py new file mode 100644 index 0000000000..d63a815e41 --- /dev/null +++ b/src/sisl/viz/figure/py3dmol.py @@ -0,0 +1,148 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at https://mozilla.org/MPL/2.0/. +import collections.abc +import itertools + +import numpy as np +import py3Dmol + +from .figure import Figure + + +class Py3DmolFigure(Figure): + """Generic canvas for the py3Dmol framework""" + + def _init_figure(self, *args, **kwargs): + self.figure = py3Dmol.view() + + def draw_line(self, x, y, name="", line={}, marker={}, text=None, row=None, col=None, **kwargs): + z = np.full_like(x, 0) + # x = self._2D_scale[0] * x + # y = self._2D_scale[1] * y + return self.draw_line_3D(x, y, z, name=name, line=line, marker=marker, text=text, row=row, col=col, **kwargs) + + def draw_scatter(self, x, y, name=None, marker={}, text=None, row=None, col=None, **kwargs): + z = np.full_like(x, 0) + # x = self._2D_scale[0] * x + # y = self._2D_scale[1] * y + return self.draw_scatter_3D(x, y, z, name=name, marker=marker, text=text, row=row, col=col, **kwargs) + + def draw_line_3D(self, x, y, z, line={}, name="", collection=None, frame=None, **kwargs): + """Draws a line.""" + + xyz = np.array([x, y, z], dtype=float).T + + # To be compatible with other frameworks such as plotly and matplotlib, + # we allow x, y and z to contain None values that indicate discontinuities + # E.g.: x=[0, 1, None, 2, 3] means we should draw a line from 0 to 1 and another + # from 2 to 3. + # Here, we get the breakpoints (i.e. indices where there is a None). We add + # -1 and None at the sides to facilitate iterating. 
+ breakpoint_indices = [-1, *np.where(np.isnan(xyz).any(axis=1))[0], None] + + # Now loop through all segments using the known breakpoints + for start_i, end_i in zip(breakpoint_indices, breakpoint_indices[1:]): + # Get the coordinates of the segment + segment_xyz = xyz[start_i+1: end_i] + + # If there is nothing to draw, go to next segment + if len(segment_xyz) == 0: + continue + + points = [{"x": x, "y": y, "z": z} for x, y, z in segment_xyz] + + # If there's only two points, py3dmol doesn't display the curve, + # probably because it can not smooth it. + if len(points) == 2: + points.append(points[-1]) + + self.figure.addCurve(dict( + points=points, + radius=line.get("width", 0.1), + color=line.get("color"), + opacity=line.get('opacity', 1.) or 1., + smooth=1 + )) + + return self + + def draw_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, collection=None, frame=None, **kwargs): + style = { + "color": marker.get("color", "gray"), + "opacity": marker.get("opacity", 1.), + "size": marker.get("size", 1.), + } + + for k, v in style.items(): + if (not isinstance(v, (collections.abc.Sequence, np.ndarray))) or isinstance(v, str): + style[k] = itertools.repeat(v) + + for i, (x_i, y_i, z_i, color, opacity, size) in enumerate(zip(x, y, z, style["color"], style["opacity"], style["size"])): + self.figure.addSphere(dict( + center={"x": float(x_i), "y": float(y_i), "z": float(z_i)}, radius=size, color=color, opacity=opacity, + quality=5., # This does not work, but sphere quality is really bad by default + )) + + draw_scatter_3D = draw_balls_3D + + def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_scale=0.3, arrowhead_angle=15, scale: float = 1, row=None, col=None, line={},**kwargs): + """Draws multiple arrows using the generic draw_line method. + + Parameters + ----------- + xy: np.ndarray of shape (n_arrows, 2) + the positions where the atoms start. + dxy: np.ndarray of shape (n_arrows, 2) + the arrow vector. 
+ arrow_head_scale: float, optional + how big is the arrow head in comparison to the arrow vector. + arrowhead_angle: angle + the angle that the arrow head forms with the direction of the arrow (in degrees). + scale: float, optional + multiplying factor to display the arrows. It does not affect the underlying data, + therefore if the data is somehow displayed it should be without the scale factor. + row: int, optional + If the figure contains subplots, the row where to draw. + col: int, optional + If the figure contains subplots, the column where to draw. + """ + # Make sure we are dealing with numpy arrays + xyz = np.array([x, y, z]).T + dxyz = np.array(dxyz) * scale + + for (x, y, z), (dx, dy, dz) in zip(xyz, dxyz): + + self.figure.addArrow(dict( + start={"x": x, "y": y, "z": z}, + end={"x": x + dx, "y": y + dy, "z": z + dz}, + radius=line.get("width", 0.1), + color=line.get("color"), + opacity=line.get("opacity", 1.), + radiusRatio=2, + mid=(1 - arrowhead_scale), + )) + + def draw_mesh_3D(self, vertices, faces, color=None, opacity=None, name="Mesh", wireframe=False, row=None, col=None, **kwargs): + + def vec_to_dict(a, labels="xyz"): + return dict(zip(labels,a)) + + self.figure.addCustom(dict( + vertexArr=[vec_to_dict(v) for v in vertices.astype(float)], + faceArr=[int(x) for f in faces for x in f], + color=color, + opacity=float(opacity or 1.), + wireframe=wireframe + )) + + def set_axis(self, *args, **kwargs): + """There are no axes titles and these kind of things in py3dmol. 
+ At least for now, we might implement it later.""" + + def set_axes_equal(self, *args, **kwargs): + """Axes are always "equal" in py3dmol, so we do nothing here""" + + def show(self, *args, **kwargs): + self.figure.zoomTo() + return self.figure.show() \ No newline at end of file diff --git a/src/sisl/viz/input_fields/__init__.py b/src/sisl/viz/input_fields/__init__.py deleted file mode 100644 index b4b91bab02..0000000000 --- a/src/sisl/viz/input_fields/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -"""This submodule implements all input fields. - -We have the basic input fields, that need a GUI implementation, -and the rest of input fields, which are just extensions of the -basic input fields. -""" -from .aiida_node import AiidaNodeInput -from .atoms import AtomSelect, SpeciesSelect -from .axes import GeomAxisSelect -from .basic import * -from .energy import ErangeInput -from .file import FilePathInput -from .orbital import OrbitalQueries, OrbitalsNameSelect -from .programatic import FunctionInput, ProgramaticInput -from .queries import QueriesInput -from .sisl_obj import ( - BandStructureInput, - DistributionInput, - GeometryInput, - PlotableInput, - SileInput, - SislObjectInput, -) -from .spin import SpinSelect diff --git a/src/sisl/viz/input_fields/aiida_node.py b/src/sisl/viz/input_fields/aiida_node.py deleted file mode 100644 index ceeac3a47d..0000000000 --- a/src/sisl/viz/input_fields/aiida_node.py +++ /dev/null @@ -1,22 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-from .._input_field import InputField - -try: - from aiida import orm - AIIDA_AVAILABLE = True -except ModuleNotFoundError: - AIIDA_AVAILABLE = False - - -class AiidaNodeInput(InputField): - - dtype = orm.Node if AIIDA_AVAILABLE else None - - def parse(self, val): - - if AIIDA_AVAILABLE and val is not None and not isinstance(val, self.dtype): - val = orm.load_node(val) - - return val diff --git a/src/sisl/viz/input_fields/atoms.py b/src/sisl/viz/input_fields/atoms.py deleted file mode 100644 index 599d1a054c..0000000000 --- a/src/sisl/viz/input_fields/atoms.py +++ /dev/null @@ -1,172 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from sisl._help import isiterable - -from .basic import ( - CreatableDictInput, - DictInput, - FloatInput, - IntegerInput, - OptionsInput, - RangeInput, - RangeSliderInput, - TextInput, -) -from .category import CategoryInput - - -class AtomCategoryInput(CategoryInput): - pass - - -class AtomIndexCatInput(AtomCategoryInput, DictInput): - - def __init__(self, *args, fields=(), **kwargs): - fields = [ - OptionsInput(key="in", name="Indices", - params={ - "placeholder": "Select indices...", - "options": [], - "isMulti": True, - "isClearable": True, - "isSearchable": True, - } - ), - - *fields - ] - - super().__init__(*args, fields=fields, **kwargs) - - def parse(self, val): - if isinstance(val, int): - return val - elif isiterable(val): - return val - else: - return super().parse(val) - - def update_options(self, geom): - self.get_param("in").modify("inputField.params.options", - [{"label": f"{at} ({geom.atoms[at].symbol})", "value": at} - for at in geom]) - - -class AtomFracCoordsCatInput(AtomCategoryInput, RangeSliderInput): - - _default = { - "default": [0, 1], - "params": {"min": 0, "max": 1, "step": 0.01} - } - - -class AtomCoordsCatInput(AtomCategoryInput, RangeInput): - 
pass - - -class AtomZCatInput(AtomCategoryInput, IntegerInput): - - _default = { - "params": {"min": 0} - } - - -class AtomNeighboursCatInput(AtomCategoryInput, DictInput): - - def __init__(self, *args, fields=(), **kwargs): - fields = [ - RangeInput(key="range", name=""), - - FloatInput(key="R", name="R"), - - TextInput(key="neigh_tag", name="Neighbour tag", default=None), - ] - - super().__init__(*args, fields=fields, **kwargs) - - def parse(self, val): - - if isinstance(val, dict): - val = {**val} - if "range" in val: - val["min"], val["max"] = val.pop("range") - if "neigh_tag" in val: - val["neighbour"] = {"tag": val.pop("neigh_tag")} - - return val - - -class AtomTagCatInput(AtomCategoryInput, TextInput): - pass - - -class AtomSeqCatInput(AtomCategoryInput, TextInput): - pass - - -class AtomSelect(CreatableDictInput): - - _default = {} - - def __init__(self, *args, fields=(), **kwargs): - fields = [ - AtomIndexCatInput(key="index", name="Indices"), - - *[AtomFracCoordsCatInput(key=f"f{ax}", name=f"Fractional {ax.upper()}", default=[0, 1]) - for ax in "xyz"], - - *[AtomCoordsCatInput(key=ax, name=f"{ax.upper()} coordinate") - for ax in "xyz"], - - AtomZCatInput(key="Z", name="Atomic number"), - - AtomNeighboursCatInput(key="neighbours", name="Neighbours"), - - AtomTagCatInput(key="tag", name="Atom tag"), - - AtomSeqCatInput(key="seq", name="Index sequence") - ] - - super().__init__(*args, fields=fields, **kwargs) - - def update_options(self, geom): - - self.get_param("index").update_options(geom) - - return self - - def parse(self, val): - if isinstance(val, dict): - val = super().parse(val) - - return val - - -class SpeciesSelect(OptionsInput): - - _default = { - "default": None, - "params": { - "placeholder": "Select species...", - "options": [], - "isMulti": True, - "isClearable": True, - "isSearchable": True, - } - } - - def update_options(self, geom): - - self.modify("inputField.params.options", - [{"label": unique_at.symbol, "value": unique_at.symbol} - 
for unique_at in geom.atoms.atom]) - - return self - - -class AtomicQuery(DictInput): - - _fields = { - "atoms": {"field": AtomSelect, "name": "Atoms"} - } diff --git a/src/sisl/viz/input_fields/axes.py b/src/sisl/viz/input_fields/axes.py deleted file mode 100644 index 5d8c43895d..0000000000 --- a/src/sisl/viz/input_fields/axes.py +++ /dev/null @@ -1,51 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import re - -import numpy as np - -from .basic import OptionsInput - - -class GeomAxisSelect(OptionsInput): - - _default = { - "params": { - "placeholder": "Choose axis...", - "options": [ - {'label': ax, 'value': ax} for ax in ["x", "y", "z", "-x", "-y", "-z", "a", "b", "c", "-a", "-b", "-c"] - ], - "isMulti": True, - "isClearable": False, - "isSearchable": True, - } - } - - def _sanitize_axis(self, ax): - if isinstance(ax, str): - if re.match("[+-]?[012]", ax): - ax = ax.replace("0", "a").replace("1", "b").replace("2", "c") - ax = ax.lower().replace("+", "") - elif isinstance(ax, int): - ax = 'abc'[ax] - elif isinstance(ax, (list, tuple)): - ax = np.array(ax) - - # Now perform some checks - invalid = True - if isinstance(ax, str): - invalid = not re.match("-?[xyzabc]", ax) - elif isinstance(ax, np.ndarray): - invalid = ax.shape != (3,) - - if invalid: - raise ValueError(f"Incorrect axis passed. 
Axes must be one of [+-]('x', 'y', 'z', 'a', 'b', 'c', '0', '1', '2', 0, 1, 2)" + - " or a numpy array/list/tuple of shape (3, )") - - return ax - - def parse(self, val): - if isinstance(val, str): - val = re.findall("[+-]?[xyzabc012]", val) - return [self._sanitize_axis(ax) for ax in val] diff --git a/src/sisl/viz/input_fields/basic/__init__.py b/src/sisl/viz/input_fields/basic/__init__.py deleted file mode 100644 index 7b834fd8da..0000000000 --- a/src/sisl/viz/input_fields/basic/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -""" -This submodule contains all the basic input fields. - -If a GUI wants to use sisl.viz, it should **build an implementation for -the input fields here**. These are the building blocks that all input -fields use. - -The rest of input fields are just extensions of the ones implemented here. -Extensions only tweak details of the internal functionality (e.g. parsing) -but the graphical interface of the input needs no modification. -""" -from .array import Array1DInput, Array2DInput -from .bool import BoolInput -from .color import ColorInput -from .dict import CreatableDictInput, DictInput -from .list import ListInput -from .number import FloatInput, IntegerInput -from .options import CreatableOptionsInput, OptionsInput -from .range import RangeInput, RangeSliderInput -from .text import TextInput diff --git a/src/sisl/viz/input_fields/basic/array.py b/src/sisl/viz/input_fields/basic/array.py deleted file mode 100644 index 3cd8dc4f8e..0000000000 --- a/src/sisl/viz/input_fields/basic/array.py +++ /dev/null @@ -1,43 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-import numpy as np - -from ..._input_field import InputField - - -class ArrayNDInput(InputField): - - dtype = "array-like" - - _type = 'array' - - _default = {} - - def __init__(self, *args, **kwargs): - - # If the shape of the array was not provided, register it - # This is important because otherwise when the setting is set to None - # there is no way of knowing how to display the input field. - # For variable shapes, a different input (ListInput) should be used - try: - kwargs["params"]["shape"] - except Exception: - shape = np.array(kwargs["default"]).shape - - if kwargs.get("params", False): - kwargs["params"]["shape"] = shape - else: - kwargs["params"] = {"shape": shape} - - super().__init__(*args, **kwargs) - - -class Array1DInput(ArrayNDInput): - - _type = 'vector' - - -class Array2DInput(ArrayNDInput): - - _type = "matrix" diff --git a/src/sisl/viz/input_fields/basic/bool.py b/src/sisl/viz/input_fields/basic/bool.py deleted file mode 100644 index d973c068c3..0000000000 --- a/src/sisl/viz/input_fields/basic/bool.py +++ /dev/null @@ -1,37 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from ..._input_field import InputField - - -class BoolInput(InputField): - """Simple input that controls a boolean variable. - - GUI indications - ---------------- - It can be implemented as a switch or a checkbox, for example. 
- """ - _false_strings = ("f", "false") - _true_strings = ("t", "true") - - dtype = bool - - _type = 'bool' - - _default = {} - - def parse(self, val): - if val is None: - pass - elif isinstance(val, str): - val = val.lower() - if val in self._true_strings: - val = True - elif val in self._false_strings: - val = False - else: - raise ValueError(f"String '{val}' is not understood by {self.__class__.__name__}") - elif not isinstance(val, bool): - self._raise_type_error(val) - - return val diff --git a/src/sisl/viz/input_fields/basic/color.py b/src/sisl/viz/input_fields/basic/color.py deleted file mode 100644 index 476ccb7da9..0000000000 --- a/src/sisl/viz/input_fields/basic/color.py +++ /dev/null @@ -1,21 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from ..._input_field import InputField - - -class ColorInput(InputField): - """Input field to pick a color. - - GUI indications - ---------------- - The best implementation for this input is probably a color picker - that let's the user choose any color they like. - - The value returned by the input field should be a string representing - a color in hex, rgb, rgba or any other named color supported in html. - """ - - dtype = str - - _type = 'color' diff --git a/src/sisl/viz/input_fields/basic/dict.py b/src/sisl/viz/input_fields/basic/dict.py deleted file mode 100644 index f9422e7111..0000000000 --- a/src/sisl/viz/input_fields/basic/dict.py +++ /dev/null @@ -1,136 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from ..._input_field import InputField -from ...configurable import Configurable - - -class DictInput(InputField): - """Input field for a dictionary. 
- - GUI indications - --------------- - This input field is just a container for key-value pairs - of other inputs. Despite its simplicity, it's not trivial - to implement. One must have all the other input fields implemented - in a very modular way for `DictInput` to come up naturally. Otherwise - it can get complicated. - - - `param.inputField["params"]["fields"]` contains a list of all the input - fields that are contained in the dictionary. Each input field can be of - any type. - """ - - dtype = dict - - _type = 'dict' - - _fields = [] - - _default = {} - - @property - def fields(self): - return self.inputField["fields"] - - def __init__(self, *args, fields=[], help="", **kwargs): - - fields = self._sanitize_fields(fields) - - input_field_attrs = { - **kwargs.pop("input_field_attrs", {}), - "fields": fields, - } - - def get_fields_help(): - return "\n\t".join([f"'{param.key}': {param.help}" for param in fields]) - - help += "\n\n Structure of the dict: {\n\t" + get_fields_help() + "\n}" - - super().__init__(*args, **kwargs, help=help, input_field_attrs=input_field_attrs) - - def _sanitize_fields(self, fields): - """Parses the fields, converting strings to the known input fields (under self._fields).""" - sanitized_fields = [] - for i, field in enumerate(fields): - if isinstance(field, str): - if field not in self._fields: - raise KeyError( - f"{self.__class__.__name__} has no pre-built field for '{field}'") - - built_field = self._fields[field]['field']( - key=field, **{key: val for key, val in self._fields[field].items() if key != 'field'} - ) - - sanitized_fields.append(built_field) - else: - sanitized_fields.append(field) - - return sanitized_fields - - def get_param(self, key, **kwargs): - """Gets a parameter from the fields of this dictionary.""" - return Configurable.get_param( - self, key, params_extractor=lambda obj: obj.inputField["fields"], **kwargs - ) - - def modify_param(self, key, *args, **kwargs): - """Modifies a parameter from the fields of 
this dictionary.""" - return Configurable.modify_param(self, key, *args, **kwargs) - - def complete_dict(self, query, **kwargs): - """Completes a partially build dictionary with the missing fields. - - Parameters - ----------- - query: dict - the query to be completed. - **kwargs: - other keys that need to be added to the query IN CASE THEY DON'T ALREADY EXIST - """ - return { - **{param.key: param.default for param in self.fields}, - **kwargs, - **query - } - - def parse(self, val): - if val is None: - val = {} - if not isinstance(val, dict): - self._raise_type_error(val) - - val = {**val} - for field in self.fields: - if field.key in val: - val[field.key] = field.parse(val[field.key]) - - return val - - def __getitem__(self, key): - for field in self.inputField['fields']: - if field.key == key: - return field - - return super().__getitem__(key) - - def __contains__(self, key): - - for field in self.inputField['fields']: - if field.key == key: - return True - - return False - - -class CreatableDictInput(DictInput): - """Input field for a dictionary for which entries can be created and removed. - - GUI indications - --------------- - This input is a bit trickier than `DictInput`. It should be possible to remove - and add the fields as you wish. - """ - - _type = "creatable dict" diff --git a/src/sisl/viz/input_fields/basic/list.py b/src/sisl/viz/input_fields/basic/list.py deleted file mode 100644 index 761d0fafad..0000000000 --- a/src/sisl/viz/input_fields/basic/list.py +++ /dev/null @@ -1,46 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from sisl._help import isiterable - -from ..._input_field import InputField -from .text import TextInput - - -class ListInput(InputField): - """A versatile list input. - - GUI indications - --------------- - This input field is to be used to create and mutate lists. 
Therefore, - it should implement some way of performing the following actions: - - Create a new element - - Delete an existing element - - Reorganize the existing list. - - The input field of each element in the list is indicated in `param.inputField["params"]` - under the `"itemInput"` key. The `"sortable"` key contains a boolean specifying whether - the user should have the ability of sorting the list. - """ - - dtype = "array-like" - - _type = 'list' - - _default = { - "params": {"itemInput": TextInput("-", "-"), "sortable": True} - } - - def get_item_input(self): - return self.inputField["params"]["itemInput"] - - def modify_item_input(self, *args, **kwargs): - return self.get_item_input().modify(*args, **kwargs) - - def parse(self, val): - if val is None: - return val - elif not isiterable(val): - self._raise_type_error(val) - - return [self.get_item_input().parse(v) for v in val] diff --git a/src/sisl/viz/input_fields/basic/number.py b/src/sisl/viz/input_fields/basic/number.py deleted file mode 100644 index 1b67607acb..0000000000 --- a/src/sisl/viz/input_fields/basic/number.py +++ /dev/null @@ -1,88 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import numpy as np - -from sisl._help import isiterable - -from ..._input_field import InputField - - -class NumericInput(InputField): - """Simple input for a number. - - GUI indications - ---------------- - If you have a `param` that uses a `NumericInput`, you will find a dictionary - at `param.inputField["params"]` that has some specifications that your input - field should fulfill. These are: {"min", "max", "step"}. - - E.g. if `param.inputField["params"]` is `{"min": 0, "max": 1, "step": 0.1}`, - your input field needs to make sure that the value is always contained between - 0 and 1 and can be increased/decreased in steps of 0.1. 
- """ - - _type = 'number' - - _default = { - "default": 0, - "params": { - "min": 0, - } - } - - -class IntegerInput(NumericInput): - """Simple input for an integer. - - GUI indications - ---------------- - No implementation needed for this input field, if your `NumericInput` - implementation supports "min", "max" and "step" correctly, you already - have an `IntegerInput`. - """ - - dtype = int - - _default = { - **NumericInput._default, - "params": { - **NumericInput._default["params"], - "step": 1 - } - } - - def parse(self, val): - if val is None: - return val - if isiterable(val): - return np.array(val, dtype=int) - return int(val) - - -class FloatInput(NumericInput): - """Simple input for an integer. - - GUI indications - ---------------- - No implementation needed for this input field, if your `NumericInput` - implementation supports "min", "max" and "step" correctly, you already - have a `FloatInput`. - """ - - dtype = float - - _default = { - **NumericInput._default, - "params": { - **NumericInput._default["params"], - "step": 0.1 - } - } - - def parse(self, val): - if val is None: - return val - if isiterable(val): - return np.array(val, dtype=float) - return float(val) diff --git a/src/sisl/viz/input_fields/basic/options.py b/src/sisl/viz/input_fields/basic/options.py deleted file mode 100644 index af7cc1712f..0000000000 --- a/src/sisl/viz/input_fields/basic/options.py +++ /dev/null @@ -1,86 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from ..._input_field import InputField - - -class OptionsInput(InputField): - """Input to select between different options. - - GUI indications - --------------- - The interface of this input field is left to the choice of the GUI - designer. Some possibilities are: - - Checkboxes or radiobuttons. - - A dropdown, better if there are many options. 
- - Whatever interface one chooses to implement, it should comply with - the following properties described at `param.inputField["params"]`: - placeholder: str - Not meaningful in some implementations. The text shown if - there's no option chosen. This is optional to implement, it just - makes the input field more explicit. - options: list of dicts like {"label": "value_label", "value": value} - Each dictionary represents an available option. `"value"` contains - the value that this option represents, while "label" may be a more - human readable description of the value. The label is what should - be shown to the user. - isMulti: boolean - Whether multiple options can be selected. - isClearable: boolean - Whether the input field can have an empty value (all its options - can be deselected). - """ - - _type = 'options' - - _default = { - "params": { - "placeholder": "Choose an option...", - "options": [ - ], - "isMulti": False, - "isClearable": True, - "isSearchable": True, - } - } - - def __init__(self, *args, **kwargs): - - # Build the help string - params = kwargs.get("params") - if "dtype" not in kwargs and params is not None: - - multiple_choice = getattr(params, "isMulti", False) - if multiple_choice: - self.dtype = "array-like" - - options = getattr(params, "options", None) - if options is not None: - self.valid_vals = [option["value"] for option in options] - - super().__init__(*args, **kwargs) - - def get_options(self, raw=False): - return [opt if raw else opt["value"] for opt in self['inputField.params.options']] - - def _set_options(self, val): - self.modify("inputField.params.options", val) - - options = property(fget=get_options, fset=_set_options) - - -class CreatableOptionsInput(OptionsInput): - """Input to select between different options and potentially create new ones. - - GUI indications - --------------- - This field is very similar to `OptionsInput`. The implementation should follow - the details described for `OptionsInput`. 
Additionally, it should **allow the - creation of new options**. - - This input will be used when there's no specific set of options, but we want to - cover some of the most common ones. - """ - - _type = "creatable options" diff --git a/src/sisl/viz/input_fields/basic/range.py b/src/sisl/viz/input_fields/basic/range.py deleted file mode 100644 index effac72680..0000000000 --- a/src/sisl/viz/input_fields/basic/range.py +++ /dev/null @@ -1,81 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import numpy as np - -from ..._input_field import InputField - - -class RangeInput(InputField): - """Simple range input composed of two values, min and max. - - GUI indications - ---------------- - This input field is the interface to an array of length 2 that specifies - some range. E.g. a valid value could be `[0, 1]`, which would mean "from - 0 to 1". Some simple implementation can be just two numeric inputs. - - It should make sure that if `param.inputField["params"]` contains min and - max, the values of the range never go beyond those limits. - """ - - dtype = "array-like of shape (2,)" - - _type = 'range' - - _default = { - "params": { - 'step': 0.1 - } - } - - -class RangeSliderInput(InputField): - """Slider that controls a range. - - GUI indications - ---------------- - A slider that lets you select a range. - - It is used over `RangeInput` when the bounds of the range (min and max) - are very well defined. The reason to prefer a slider is that visually is - much better. However, if it's not possible to implement, this field can - use the same interface as `RangeInput` without problem. - - It should make sure that if `param.inputField["params"]` contains min and - max, the values of the range never go beyond those limits. 
Also, - `param.inputField["params"]["marks"]` can contain a list of dictionaries - with the values and labels of the ticks that should appear in the slider. - I.e. `[{"value": 0, "label": "mark1"}]` indicates that there should be a - tick at 0 with label "mark1". - """ - - dtype = "array-like of shape (2,)" - - _type = 'rangeslider' - - _default = { - "width": "s100%", - "params": { - "min": -10, - "max": 10, - "step": 0.1, - } - } - - def update_marks(self, marks=None): - """Updates the marks of the rangeslider. - - Parameters - ---------- - marks: dict, optional - a dict like {value: label, ...} for each mark that we want. - - If no marks are passed, the method will try to update the marks acoording to the current - min and max values. - """ - if marks is None: - marks = [{"value": int(val), "label": str(val)} for val in np.arange( - self.inputField["params"]["min"], self.inputField["params"]["max"], 1, dtype=int)] - - self.modify("inputField.params.marks", marks) diff --git a/src/sisl/viz/input_fields/basic/tests/test_parsing.py b/src/sisl/viz/input_fields/basic/tests/test_parsing.py deleted file mode 100644 index 0d17f6a692..0000000000 --- a/src/sisl/viz/input_fields/basic/tests/test_parsing.py +++ /dev/null @@ -1,126 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-from ast import parse - -import numpy as np -import pytest - -from sisl.viz.input_fields import ( - BoolInput, - DictInput, - FloatInput, - IntegerInput, - ListInput, - TextInput, -) - - -def test_text_input_parse(): - input_field = TextInput(key="test", name="Test") - - assert input_field.parse("Some test input") == "Some test input" - - -def test_integer_input_parse(): - input_field = IntegerInput(key="test", name="Test") - - assert input_field.parse(3) == 3 - assert input_field.parse("3") == 3 - assert input_field.parse(3.2) == 3 - - with pytest.raises(ValueError): - input_field.parse("Some non-integer") - - assert input_field.parse(None) is None - - # Test that it can also accept arrays of integers and - # it parses them to numpy arrays. - for val in [3, 3.2]: - parsed_array = input_field.parse([val]) - assert isinstance(parsed_array, np.ndarray) - assert np.all(parsed_array == [3]) - - -def test_float_input_parse(): - input_field = FloatInput(key="test", name="Test") - - assert input_field.parse(3.2) == 3.2 - - with pytest.raises(ValueError): - input_field.parse("Some non-float") - - assert input_field.parse(None) is None - - # Test that it can also accept arrays of floats and - # it parses them to numpy arrays. 
- parsed_array = input_field.parse([3.2]) - assert isinstance(parsed_array, np.ndarray) - assert np.all(parsed_array == [3.2]) - - -def test_bool_input_parse(): - input_field = BoolInput(key="test", name="Test") - - assert input_field.parse(True) == True - assert input_field.parse(False) == False - - with pytest.raises(ValueError): - input_field.parse("Some non-boolean string") - - assert input_field.parse("true") == True - assert input_field.parse("True") == True - assert input_field.parse("t") == True - assert input_field.parse("T") == True - - assert input_field.parse("false") == False - assert input_field.parse("False") == False - assert input_field.parse("f") == False - assert input_field.parse("F") == False - - with pytest.raises(TypeError): - input_field.parse([]) - - assert input_field.parse(None) is None - - -def test_dict_input_parse(): - input_field = DictInput( - key="test", name="Test", - fields=[ - TextInput(key="a", name="A"), - IntegerInput(key="b", name="B"), - ] - ) - - assert input_field.parse({"a": "S", "b": 3}) == {"a": "S", "b": 3} - assert input_field.parse({"a": "S", "b": 3.2}) == {"a": "S", "b": 3} - - with pytest.raises(ValueError): - input_field.parse({"a": "S", "b": "Some non-integer"}) - - assert input_field.parse(None) == {} - assert input_field.parse({}) == {} - - with pytest.raises(TypeError): - input_field.parse(3) - - -def test_list_input_parse(): - - input_field = ListInput(key="test", name="Test", - params={"itemInput": IntegerInput(key="_", name="_")} - ) - - assert input_field.parse([]) == [] - assert input_field.parse([3]) == [3] - assert input_field.parse([3.2]) == [3] - assert input_field.parse((3.2,)) == [3] - assert input_field.parse(np.array([3.2])) == [3] - - with pytest.raises(TypeError): - input_field.parse(3) - with pytest.raises(ValueError): - input_field.parse(["Some non-integer"]) - - assert input_field.parse(None) is None diff --git a/src/sisl/viz/input_fields/basic/text.py 
b/src/sisl/viz/input_fields/basic/text.py deleted file mode 100644 index 585dfd8097..0000000000 --- a/src/sisl/viz/input_fields/basic/text.py +++ /dev/null @@ -1,26 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from ..._input_field import InputField - - -class TextInput(InputField): - """Simple input for text. - - GUI indications - ---------------- - The implementation of this input should be a simple text field. - - Optionally, you may use `param.inputField["params"]["placeholder"]` - as the placeholder. - """ - - dtype = str - - _type = "textinput" - - _default = { - "params": { - "placeholder": "Write your value here...", - } - } diff --git a/src/sisl/viz/input_fields/category.py b/src/sisl/viz/input_fields/category.py deleted file mode 100644 index 42757a1418..0000000000 --- a/src/sisl/viz/input_fields/category.py +++ /dev/null @@ -1,8 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from .._input_field import InputField - - -class CategoryInput(InputField): - pass diff --git a/src/sisl/viz/input_fields/energy.py b/src/sisl/viz/input_fields/energy.py deleted file mode 100644 index 7ed3c4f4aa..0000000000 --- a/src/sisl/viz/input_fields/energy.py +++ /dev/null @@ -1,13 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-from .basic import RangeInput - - -class ErangeInput(RangeInput): - - def __init__(self, key, name="Energy range", default=None, - params={"step": 1}, help="The energy range that is displayed", **kwargs): - - super().__init__(key=key, name=name, default=default, - params=params, help=help, **kwargs) diff --git a/src/sisl/viz/input_fields/file.py b/src/sisl/viz/input_fields/file.py deleted file mode 100644 index a7c25d1f3a..0000000000 --- a/src/sisl/viz/input_fields/file.py +++ /dev/null @@ -1,34 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from pathlib import Path - -from sisl import BaseSile - -from .basic import TextInput - -if not hasattr(BaseSile, "to_json"): - # Little patch so that Siles can be sent to the GUI - def sile_to_json(self): - return str(self.file) - - BaseSile.to_json = sile_to_json - - -class FilePathInput(TextInput): - - _default = { - "params": { - "placeholder": "Write your path here...", - } - } - - def parse(self, val): - - if isinstance(val, BaseSile): - val = val.file - - if isinstance(val, str): - val = Path(val) - - return val diff --git a/src/sisl/viz/input_fields/orbital.py b/src/sisl/viz/input_fields/orbital.py deleted file mode 100644 index 2d63ec1378..0000000000 --- a/src/sisl/viz/input_fields/orbital.py +++ /dev/null @@ -1,363 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-from collections import defaultdict - -import numpy as np - -from .atoms import AtomSelect, SpeciesSelect -from .basic import OptionsInput -from .queries import QueriesInput -from .spin import SpinSelect - - -class OrbitalsNameSelect(OptionsInput): - - _default = { - "default": None, - "params": { - "placeholder": "Select orbitals...", - "options": [], - "isMulti": True, - "isClearable": True, - "isSearchable": True, - } - } - - def update_options(self, geom): - - orbs = set([orb.name() - for unique_at in geom.atoms.atom for orb in unique_at]) - - self.modify("inputField.params.options", - [{"label": orb, "value": orb} - for orb in orbs]) - - return self - - -class OrbitalQueries(QueriesInput): - """ - This class implements an input field that allows you to select orbitals by atom, species, etc... - """ - - _fields = { - "species": {"field": SpeciesSelect, "name": "Species"}, - "atoms": {"field": AtomSelect, "name": "Atoms"}, - "orbitals": {"field": OrbitalsNameSelect, "name": "Orbitals"}, - "spin": {"field": SpinSelect, "name": "Spin"}, - } - - _keys_to_cols = { - "atoms": "atom", - "orbitals": "orbital_name", - } - - def _build_orb_filtering_df(self, geom): - import pandas as pd - - orb_props = defaultdict(list) - del_key = set() - #Loop over all orbitals of the basis - for at, iorb in geom.iter_orbitals(): - - atom = geom.atoms[at] - orb = atom[iorb] - - orb_props["atom"].append(at) - orb_props["Z"].append(atom.Z) - orb_props["species"].append(atom.symbol) - orb_props["orbital_name"].append(orb.name()) - - for key in ("n", "l", "m", "zeta"): - val = getattr(orb, key, None) - if val is None: - del_key.add(key) - orb_props[key].append(val) - - for key in del_key: - del orb_props[key] - - self.orb_filtering_df = pd.DataFrame(orb_props) - - def update_options(self, geometry, spin=""): - """ - Updates the options of the orbital queries. - - Parameters - ----------- - geometry: sisl.Geometry - the geometry that contains the orbitals that can be selected. 
- spin: sisl.Spin, str or int - It is used to indicate the kind of spin so that the spin selector - (in case there is one) can display the appropiate options. - - See also - --------- - sisl.viz.input_fields.dropdown.SpinSelect - sisl.physics.Spin - """ - self.geometry = geometry - - for key in ("species", "atoms", "orbitals"): - try: - self.get_query_param(key).update_options(geometry) - except KeyError: - pass - - try: - self.get_query_param('spin').update_options(spin) - except KeyError: - pass - - self._build_orb_filtering_df(geometry) - - def get_options(self, key, **kwargs): - """ - Gets the options for a given key or combination of keys. - - Parameters - ------------ - key: str, {"species", "atoms", "Z", "orbitals", "n", "l", "m", "zeta", "spin"} - the parameter that you want the options for. - - Note that you can combine them with a "+" to get all the possible combinations. - You can get the same effect also by passing a list. - See examples. - **kwargs: - keyword arguments that add additional conditions to the query. The values of this - keyword arguments can be lists, in which case it indicates that you want a value - that is in the list. See examples. - - Returns - ---------- - np.ndarray of shape (n_options, [n_keys]) - all the possible options. - - If only one key was provided, it is a one dimensional array. - - Examples - ----------- - - >>> plot = H.plot.pdos() - >>> plot.get_param("requests").get_options("l", species="Au") - >>> plot.get_param("requests").get_options("n+l", atoms=[0,1]) - """ - # Get the tadatframe - df = self.orb_filtering_df - - # Filter the dataframe according to the constraints imposed by the kwargs, - # if there are any. 
- if kwargs: - if "atoms" in kwargs: - kwargs["atoms"] = self.geometry._sanitize_atoms(kwargs["atoms"]) - def _repr(v): - if isinstance(v, np.ndarray): - v = list(v.ravel()) - if isinstance(v, dict): - raise Exception(str(v)) - return repr(v) - query = ' & '.join([f'{self._keys_to_cols.get(k, k)}=={_repr(v)}' for k, v in kwargs.items( - ) if self._keys_to_cols.get(k, k) in df]) - if query: - df = df.query(query) - - # If + is in key, it is a composite key. In that case we are going to - # split it into all the keys that are present and get the options for all - # of them. At the end we are going to return a list of tuples that will be all - # the possible combinations of the keys. - keys = [self._keys_to_cols.get(k, k) for k in key.split("+")] - - # Spin values are not stored in the orbital filtering dataframe. If the options - # for spin are requested, we need to pop the key out and get the current options - # for spin from the input field - spin_in_keys = "spin" in keys - if spin_in_keys: - spin_key_i = keys.index("spin") - keys.remove("spin") - - spin_param = self.get_param("spin") - spin_options = spin_param.options - - if spin_param.spin.is_polarized and len(spin_options) > 1: - spin_options = (0, 1) - - # We might have some constraints on what the spin value can be - if "spin" in kwargs: - spin_options = set(spin_options).intersection(kwargs["spin"]) - - # Now get the unique options from the dataframe - if keys: - options = df.drop_duplicates(subset=keys)[ - keys].values.astype(object) - else: - # It might be the only key was "spin", then we are going to fake it - # to get an options array that can be treated in the same way. - options = np.array([[]], dtype=object) - - # If "spin" was one of the keys, we are going to incorporate the spin options, taking into - # account the position (column index) where they are expected to be returned. 
- if spin_in_keys and len(spin_options) > 0: - options = np.concatenate( - [np.insert(options, spin_key_i, spin, axis=1) for spin in spin_options]) - - # Squeeze the options array, just in case there is only one key - # There's a special case: if there is only one option for that key, - # squeeze converts it to a number, so we need to make sure there is at least 1d - if options.shape[1] == 1: - options = options.squeeze() - options = np.atleast_1d(options) - - return options - - def get_orbitals(self, query): - - if "atoms" in query: - query["atoms"] = self.geometry._sanitize_atoms(query["atoms"]) - - filtered_df = self.filter_df( - self.orb_filtering_df, query, self._keys_to_cols) - - return filtered_df.index - - def _split_query(self, query, on, only=None, exclude=None, query_gen=None, ignore_constraints=False, **kwargs): - """ - Splits a query into multiple queries based on one of its parameters. - - Parameters - -------- - query: dict - the query that we want to split - on: str, {"species", "atoms", "Z", "orbitals", "n", "l", "m", "zeta", "spin"}, or list of str - the parameter to split along. - Note that you can combine parameters with a "+" to split along multiple parameters - at the same time. You can get the same effect also by passing a list. - only: array-like, optional - if desired, the only values that should be plotted out of - all of the values that come from the splitting. - exclude: array-like, optional - values of the splitting that should not be plotted. - query_gen: function, optional - the request generator. It is a function that takes all the parameters for each - request that this method has come up with and gets a chance to do some modifications. - - This may be useful, for example, to give each request a color, or a custom name. - ignore_constraints: boolean or array-like, optional - determines whether constraints (imposed by the query that you want to split) - on the parameters that we want to split along should be taken into consideration. 
- - If `False`: all constraints considered. - If `True`: no constraints considered. - If array-like: parameters contained in the list ignore their constraints. - **kwargs: - keyword arguments that go directly to each new request. - - This is useful to add extra filters. For example: - - `self._split_query(request, on="orbitals", spin=[0])` - will split the request on the different orbitals but will take - only the contributions from spin up. - """ - if exclude is None: - exclude = [] - - # Divide the splitting request into all the parameters - if isinstance(on, str): - on = on.split("+") - - # Get the current values of the parameters that we want to split the request on - # because these will be our constraints. If a parameter is set to None or not - # provided, we have no constraints for that parameter. - constraints = {} - if ignore_constraints is not True: - - if ignore_constraints is False: - ignore_constraints = () - - for key in filter(lambda key: key not in ignore_constraints, on): - val = query.get(key, None) - if val is not None: - constraints[key] = val - - # Knowing what are our constraints (which may be none), get the available options - values = self.get_options("+".join(on), **constraints) - - # We are going to make sure that, even if there was only one parameter to split on, - # the values are two dimensional. In this way, we can take the same actions for the - # case when there is only one parameter and the case when there are multiple. 
- if values.ndim == 1: - values = values.reshape(-1, 1) - - # If no function to modify queries was provided we are just going to generate a - # dummy one that just returns the query as it gets it - if query_gen is None: - def query_gen(**kwargs): - return kwargs - - # We ensure that on is a list even if there is only one parameter, for the same - # reason we ensured values was 2 dimensional - if isinstance(on, str): - on = on.split("+") - - # Define the name that we will give to the new queries, using templating - # If a splitting parameter is not used by the name, we are going to - # append it, in order to make names unique and self-explanatory. - base_name = kwargs.pop("name", query.get("name", "")) - first_added = True - for key in on: - kwargs.pop(key, None) - - if f"${key}" not in base_name: - base_name += f"{' | ' if first_added else ', '}{key}=${key}" - first_added = False - - # Now build all the queries - queries = [] - for i, value in enumerate(values): - if value not in exclude and (only is None or value in only): - - # Use the name template to generate the name for this query - name = base_name - for key, val in zip(on, value): - name = name.replace(f"${key}", str(val)) - - # And append the new query to the queries - queries.append( - query_gen(**{ - **query, - **{key: [val] for key, val in zip(on, value)}, - "name": name, **kwargs - }) - ) - - return queries - - def _generate_queries(self, on, only=None, exclude=None, query_gen=None, **kwargs): - """ - Automatically generates queries based on the current options. - - Parameters - -------- - on: str, {"species", "atoms", "Z", "orbitals", "n", "l", "m", "zeta", "spin"} or list of str - the parameter to split along. - Note that you can combine parameters with a "+" to split along multiple parameters - at the same time. You can get the same effect also by passing a list. - only: array-like, optional - if desired, the only values that should be plotted out of - all of the values that come from the splitting. 
- exclude: array-like, optional - values that should not be plotted - query_gen: function, optional - the request generator. It is a function that takes all the parameters for each - request that this method has come up with and gets a chance to do some modifications. - - This may be useful, for example, to give each request a color, or a custom name. - **kwargs: - keyword arguments that go directly to each request. - - This is useful to add extra filters. For example: - `plot._generate_requests(on="orbitals", species=["C"])` - will split the PDOS on the different orbitals but will take - only those that belong to carbon atoms. - """ - return self._split_query({}, on=on, only=only, exclude=exclude, query_gen=query_gen, **kwargs) diff --git a/src/sisl/viz/input_fields/programatic.py b/src/sisl/viz/input_fields/programatic.py deleted file mode 100644 index 9705309872..0000000000 --- a/src/sisl/viz/input_fields/programatic.py +++ /dev/null @@ -1,35 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from .._input_field import InputField - - -class ProgramaticInput(InputField): - - _type="programatic" - - def __init__(self, *args, help="", **kwargs): - - #help = f"only meant to be provided programatically. {help}" - - super().__init__(*args, help=help, **kwargs) - - -class FunctionInput(ProgramaticInput): - """ - This input will be used for those settings that are expecting functions. - - Parameters - --------- - positional: array-like of str, optional - The names of the positional arguments that this function should expect. - keyword: array-like of str, optional - The names of the keyword arguments that this function should expect. - returns: array-like of type - The datatypes that the function is expected to return. 
- """ - - _type="function" - - def __init__(self, *args, positional=None, keyword=None, returns=None, **kwargs): - super().__init__(*args, **kwargs) diff --git a/src/sisl/viz/input_fields/queries.py b/src/sisl/viz/input_fields/queries.py deleted file mode 100644 index aee58e4287..0000000000 --- a/src/sisl/viz/input_fields/queries.py +++ /dev/null @@ -1,145 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import numpy as np - -from .basic import DictInput, ListInput - - -class QueriesInput(ListInput): - """ - Parameters - ---------- - queryForm: list of InputField - The list of input fields that conform a query. - """ - - dtype = "array-like of dict" - - _dict_input = DictInput - - _default = {} - - def __init__(self, *args, queryForm=[], help="", params={}, **kwargs): - - query_form = self._sanitize_queryform(queryForm) - - self._dict_param = self._dict_input(key="", name="", fields=query_form) - - params = { - "sortable": True, - "itemInput": self._dict_param, - **params, - } - - input_field_attrs = { - **kwargs.get("input_field_attrs", {}), - } - - help += f"\n\n Each item is a dict. {self._dict_param.help}" - - super().__init__(*args, **kwargs, help=help, params=params, input_field_attrs=input_field_attrs) - - def get_query_param(self, key, **kwargs): - """Gets the parameter info for a given key.""" - return self._dict_param.get_param(key, **kwargs) - - def get_param(self, *args, **kwargs): - """ - Just a clone of get_query_param. - - Because Configurable looks for this method when modifying parameters, but the other name is clearer. 
- """ - return self.get_query_param(*args, **kwargs) - - def modify_query_param(self, key, *args, **kwargs): - """ - Uses Configurable.modify_param to modify a parameter inside QueryForm - """ - return self._dict_param.modify_param(self, key, *args, **kwargs) - - def complete_query(self, query, **kwargs): - """ - Completes a partially build query with the default values - - Parameters - ----------- - query: dict - the query to be completed. - **kwargs: - other keys that need to be added to the query IN CASE THEY DON'T ALREADY EXIST - """ - return { - "active": True, - **self._dict_param.complete_dict(query, **kwargs), - } - - def filter_df(self, df, query, key_to_cols, raise_not_active=False): - """ - Filters a dataframe according to a query - - Parameters - ----------- - df: pd.DataFrame - the dataframe to filter. - query: dict - the query to be used as a filter. Can be incomplete, it will be completed using - `self.complete_query()` - keys_to_cols: array-like of tuples - An array of tuples that look like (key, col) - where key is the key of the parameter in the query and col the corresponding - column in the dataframe. - """ - query = self.complete_query(query) - - if raise_not_active: - if not query["active"]: - raise ValueError(f"Query {query} is not active and you are trying to use it") - - query_str = [] - for key, val in query.items(): - key = key_to_cols.get(key, key) - if key in df and val is not None: - if isinstance(val, (np.ndarray, tuple)): - val = np.ravel(val).tolist() - query_str.append(f'{key}=={repr(val)}') - - return df.query(" & ".join(query_str)) - - def _sanitize_queryform(self, queryform): - """ - Parses a query form to fields, converting strings - to the known input fields (under self._fields). As an example, - see OrbitalQueries. 
- """ - sanitized_form = [] - for i, field in enumerate(queryform): - if isinstance(field, str): - if field not in self._fields: - raise KeyError( - f"{self.__class__.__name__} has no pre-built field for '{field}'") - - built_field = self._fields[field]['field']( - key=field, **{key: val for key, val in self._fields[field].items() if key != 'field'} - ) - - sanitized_form.append(built_field) - else: - sanitized_form.append(field) - - return sanitized_form - - def parse(self, val): - if isinstance(val, dict): - val = [val] - - return super().parse(val) - - def __getitem__(self, key): - try: - return self._dict_param.get_param(key) - except KeyError: - return super().__getitem__(key) - - def __contains__(self, key): - return self._dict_param.__contains__(key) diff --git a/src/sisl/viz/input_fields/sisl_obj.py b/src/sisl/viz/input_fields/sisl_obj.py deleted file mode 100644 index 6d404e1105..0000000000 --- a/src/sisl/viz/input_fields/sisl_obj.py +++ /dev/null @@ -1,259 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-""" This input field is prepared to receive sisl objects that are plotables """ -from pathlib import Path - -import sisl -from sisl import BaseSile -from sisl.physics import distribution - -from .._input_field import InputField -from .basic import BoolInput, DictInput, FloatInput, IntegerInput, OptionsInput, TextInput -from .file import FilePathInput -from .queries import QueriesInput - -if not hasattr(BaseSile, "to_json"): - # Little patch so that Siles can be sent to the GUI - def sile_to_json(self): - return str(self.file) - - BaseSile.to_json = sile_to_json - - -forced_keys = { - sisl.Geometry: 'geometry', - sisl.Hamiltonian: 'H', - sisl.BandStructure: 'band_structure', - sisl.BrillouinZone: 'brillouin_zone', - sisl.Grid: 'grid', - sisl.EigenstateElectron: 'eigenstate', -} - - -class SislObjectInput(InputField): - - _type = "sisl_object" - - def __init__(self, key, *args, **kwargs): - - super().__init__(key, *args, **kwargs) - - if self.dtype is None: - raise ValueError(f'Please provide a dtype for {key}') - - valid_key = forced_keys.get(self.dtype, None) - - if valid_key is not None and not key.endswith(valid_key): - - raise ValueError( - f'Invalid key ("{key}") for an input that accepts {kwargs["dtype"]}, please use {valid_key}' - 'to help keeping consistency across sisl and therefore make the world a better place.' 
- f'If there are multiple settings that accept {kwargs["dtype"]}, please use *_{valid_key}' - ) - - -class GeometryInput(SislObjectInput): - - dtype = (sisl.Geometry, "sile (or path to file) that contains a geometry") - _dtype = (str, sisl.Geometry, *sisl.get_siles(attrs=['read_geometry'])) - - def parse(self, val): - - if isinstance(val, (str, Path)): - val = sisl.get_sile(val) - if isinstance(val, sisl.io.BaseSile): - val = val.read_geometry() - - return val - - -class HamiltonianInput(SislObjectInput): - pass - - -class BandStructureInput(QueriesInput, SislObjectInput): - - dtype = sisl.BandStructure - - def __init__(self, *args, **kwargs): - kwargs["help"] = """A band structure. it can either be provided as a sisl.BandStructure object or - as a list of points, which will be parsed into a band structure object. - """ - - # Let's define the queryform. Each query will be a point of the path. - kwargs["queryForm"] = [ - - FloatInput( - key="x", name="X", - default=0, - params={ - "step": 0.01 - } - ), - - FloatInput( - key="y", name="Y", - default=0, - params={ - "step": 0.01 - } - ), - - FloatInput( - key="z", name="Z", - default=0, - params={ - "step": 0.01 - } - ), - - IntegerInput( - key="divisions", name="Divisions", - default=50, - params={ - "min": 0, - "step": 10 - } - ), - - TextInput( - key="name", name="Name", - default=None, - params = { - "placeholder": "Name..." - }, - help = "Tick that should be displayed at this corner of the path." - ), - - BoolInput( - key="jump", name="Jump", - default=False, - help="""If True, this point just signals a discontinuity and the rest - of inputs for this point will be ignored. - """ - ), - ] - - super().__init__(*args, **kwargs) - - def parse(self, val): - if not isinstance(val, sisl.BandStructure) and val is not None: - # Then let's parse the list of points into a band structure object. - # Use only those points that are active. 
- val = [point for point in val if point.get("active", True)] - - points = [] - divisions = [] - names = [] - # Loop over all points and construct the inputs for BandStructure - for i_point, point in enumerate(val): - if point.get("jump") is True: - # This is a discontinuity - points.append(None) - if i_point > 0: - divisions.append(1) - else: - # This is an actual point in the band structure. - points.append( - [point.get("x", None) or 0, point.get("y", None) or 0, point.get("z", None) or 0] - ) - names.append(point.get("name", "")) - if i_point > 0: - divisions.append(int(point["divisions"])) - - print(points, divisions, names) - - val = sisl.BandStructure(None, points=points, divisions=divisions, names=names) - - return val - - -class BrillouinZoneInput(SislObjectInput): - pass - - -class GridInput(SislObjectInput): - pass - - -class EigenstateElectronInput(SislObjectInput): - pass - - -class PlotableInput(SislObjectInput): - - _type = "plotable" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - -class DistributionInput(DictInput, SislObjectInput): - - def __init__(self, *args, **kwargs): - # Let's define the queryform (although we only want one point for now we use QueriesInput for convenience) - kwargs["fields"] = [ - - OptionsInput( - key="method", name="Method", - default="gaussian", - params={ - "options": [{"label": dist, "value": dist} for dist in distribution.__all__ if dist != "get_distribution"], - "isMulti": False, - "isClearable": False, - } - ), - - FloatInput( - key="smearing", name="Smearing", - default=0.1, - params={ - "step": 0.01 - } - ), - - FloatInput( - key="x0", name="Center", - default=0.0, - params={ - "step": 0.01 - } - ), - ] - - super().__init__(*args, **kwargs) - - def parse(self, val): - if val and not callable(val): - if isinstance(val, str): - val = distribution.get_distribution(val) - else: - val = distribution.get_distribution(**self.complete_dict(val)) - - return val - - -class 
SileInput(FilePathInput, SislObjectInput): - - def __init__(self, *args, required_attrs=None, **kwargs): - - if required_attrs: - self._required_attrs = required_attrs - kwargs["dtype"] = None - - super().__init__(*args, **kwargs) - - def _get_dtype(self): - """ - This is a temporal fix because for some reason some sile classes can not be pickled - """ - if hasattr(self, "_required_attrs"): - return tuple(sisl.get_siles(attrs=self._required_attrs)) - else: - return self.__dict__["dtype"] - - def _set_dtype(self, val): - self.__dict__["dtype"] = val - - dtype = property(fget=_get_dtype, fset=_set_dtype, ) diff --git a/src/sisl/viz/input_fields/spin.py b/src/sisl/viz/input_fields/spin.py deleted file mode 100644 index bc7191e49a..0000000000 --- a/src/sisl/viz/input_fields/spin.py +++ /dev/null @@ -1,107 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import numpy as np - -from sisl import Spin -from sisl._help import isiterable - -from .basic import OptionsInput - - -class SpinSelect(OptionsInput): - """ Input field that helps selecting and managing the desired spin. - - It has a method to update the options according to spin class. - - Parameters - ------------ - only_if_polarized: bool, optional - If set to `True`, the options can only be either [UP, DOWN] or []. - - That is, no extra options for non collinear and spin orbit calculations. - - Defaults to False. 
- """ - - _default = { - "default": None, - "params": { - "placeholder": "Select spin...", - "options": [], - "isMulti": True, - "isClearable": True, - "isSearchable": True, - }, - "style": { - "width": 200 - } - } - - _options = { - Spin.UNPOLARIZED: [], - Spin.POLARIZED: [{"label": "↑", "value": 0}, {"label": "↓", "value": 1}, - {"label": "Total", "value": "total"}, {"label": "Net z", "value": "z"}], - Spin.NONCOLINEAR: [{"label": val, "value": val} for val in ("total", "x", "y", "z")], - Spin.SPINORBIT: [{"label": val, "value": val} - for val in ("total", "x", "y", "z")] - } - - def __init__(self, *args, only_if_polarized=False, **kwargs): - - super().__init__(*args, **kwargs) - - self._only_if_polarized = only_if_polarized - - def update_options(self, spin, only_if_polarized=None): - """ - Updates the options of the spin selector. - - It does so according to the type of spin that the plot is handling. - - Parameters - ----------- - spin: sisl.Spin, str or int - It is used to indicate the kind of spin. - only_if_polarized: bool, optional - If set to `True`, the options can only be either [UP, DOWN] or []. - - That is, no extra options for non collinear and spin orbit calculations. - - If not provided the initialization value of `only_if_polarized` will be used. - - See also - --------- - sisl.physics.Spin - """ - if not isinstance(spin, Spin): - spin = Spin(spin) - - self.spin = spin - - # Use the default for this input field if only_if_polarized is not provided. 
- if only_if_polarized is None: - only_if_polarized = self._only_if_polarized - - # Determine what are the new options - if only_if_polarized: - if spin.is_polarized: - options = self._options[Spin.POLARIZED] - else: - options = self._options[Spin.UNPOLARIZED] - else: - options = self._options[spin.kind] - - # Update them - self.modify("inputField.params.options", options) - - return self - - def parse(self, val): - if val is None: - return val - - if not isiterable(val): - val = [val] - - return val diff --git a/src/sisl/viz/plot.py b/src/sisl/viz/plot.py index cbd57941e5..fbbecf8ca7 100644 --- a/src/sisl/viz/plot.py +++ b/src/sisl/viz/plot.py @@ -1,1976 +1,24 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -""" -This file contains the Plot class, which should be inherited by all plot classes -""" -import inspect -import itertools -import time -import uuid -from copy import deepcopy -from functools import partial -from pathlib import Path -from types import FunctionType, MethodType +from sisl.messages import deprecate +from sisl.nodes import Workflow -import numpy as np -import sisl -from sisl.messages import info, warn - -from ._presets import get_preset -from ._shortcuts import ShortCutable -from .backends._plot_backends import Backends -from .configurable import ( - Configurable, - ConfigurableMeta, - _populate_with_settings, - vizplotly_settings, -) -from .input_fields import ( - BoolInput, - IntegerInput, - ListInput, - OptionsInput, - ProgramaticInput, - SileInput, - TextInput, -) -from .plotutils import ( - call_method_if_present, - check_widgets, - dictOfLists2listOfDicts, - init_multiple_plots, - repeat_if_children, - running_in_notebook, - spoken_message, - trigger_notification, -) - -__all__ = ["Plot", "MultiplePlot", "Animation", "SubPlots"] - - -class PlotMeta(ConfigurableMeta): - - def 
__call__(cls, *args, **kwargs): - """ This method decides what to return when the plot class is instantiated. - - It is supposed to help the users by making the plot class very functional - without the need for the users to use extra methods. - - It will catch the first argument and initialize the corresponding plot - if the first argument is: - - A string, it will be assumed that it is a path to a file. - - A plotable object (has a _plot attribute) - - Note that both cases are registered in the _plotables.py file, and you - can register new siles/plotables by using the register functions. - """ - if args: - - # This is just so that the plotable framework knows from which plot class - # it is being called so that it can build the corresponding plot. - # Only relevant if the plot is built with obj.plot() - plot_method = kwargs.get("plot_method", cls.suffix()) - - # If a filename is recieved, we will try to find a plot for it - if isinstance(args[0], (str, Path)): - - filename = args[0] - sile = sisl.get_sile(filename) - - if hasattr(sile, "plot"): - plot = sile.plot(**{**kwargs, "method": plot_method}) - else: - raise NotImplementedError( - f'There is no plot implementation for {sile.__class__} yet.') - elif isinstance(args[0], Plot): - plot = args[0].update_settings(**kwargs) - else: - obj = args[0] - # Maybe the first argument is a plotable object (e.g. 
a geometry) - if hasattr(obj, "plot"): - plot = obj.plot(**{**kwargs, "method": plot_method}) - else: - return object.__new__(cls) - - return plot - - elif 'animate' in kwargs or 'varying' in kwargs or 'subplots' in kwargs: - - methods = {'animate': cls.animated, 'varying': cls.multiple, 'subplots': cls.subplots} - # Retrieve the keyword that was actually passed - # and choose the appropiate method - for keyword in ('animate', 'varying', 'subplots'): - variable_settings = kwargs.pop(keyword, None) - if variable_settings is not None: - method = methods[keyword] - break - - # Normalize all accepted input types to a dict - if isinstance(variable_settings, str): - variable_settings = [variable_settings] - if isinstance(variable_settings, (list, tuple, np.ndarray)): - variable_settings = {key: kwargs.pop(key) for key in variable_settings} - - # Just run the method that will get us the desired plot - plot = method(variable_settings, fixed=kwargs, **kwargs) - - return plot - - return super().__call__(cls, *args, **kwargs) - - -class Plot(ShortCutable, Configurable, metaclass=PlotMeta): - """ Parent class of all plot classes - - Implements general things needed by all plots such as settings and shortcut - management. - - Parameters - ---------- - root_fdf: fdfSileSiesta, optional - Path to the fdf file that is the 'parent' of the results. - results_path: str, optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - entry_points_order: array-like, optional - Order with which entry points will be attempted. - backend: optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - - Attributes - ---------- - settings: dict - contains the values for each setting of the plot. - params: dict - for each setting, contains information about their input field. - - This information might include valid values, for example. 
- paramGroups: tuple of dicts - contains the different setting groups present in the plot. - - Each group is a dict like { "key": , "name": ,"icon": , "description": } - ... - - """ - _update_methods = { - "read_data": [], - "set_data": [], - "get_figure": [] - } - - _param_groups = ( - { - "key": "dataread", - "name": "Data reading settings", - "icon": "import_export", - "description": "In such a busy world, one may forget how the files are structured in their computer. Please take a moment to make sure your data is being read exactly in the way you expect." - }, - ) - - _parameters = ( - - SileInput( - key = "root_fdf", name = "Path to fdf file", - dtype=sisl.io.siesta.fdfSileSiesta, - group="dataread", - help="Path to the fdf file that is the 'parent' of the results.", - params={ - "placeholder": "Write the path here..." - } - ), - - TextInput( - key="results_path", name = "Path to your results", - group="dataread", - default="", - params={ - "placeholder": "Write the path here..." - }, - help = "Directory where the files with the simulations results are located.
This path has to be relative to the root fdf.", - ), - - ListInput(key="entry_points_order", name="Entry points order", - group="dataread", - default=[], - params={ - "itemInput": OptionsInput(key="-", name="-", params={"options": []}) - }, - help="""Order with which entry points will be attempted.""" - ), - - OptionsInput( - key="backend", name="Backend", - default=None, - params={}, - help="Directory where the files with the simulations results are located.
This path has to be relative to the root fdf.", - ), - - ) - - @property - def read_data_methods(self): - entry_points_names = [entry_point._method.__name__ for entry_point in self.entry_points] - - return ["_before_read", "_after_read", "_read_from_sources", *entry_points_names, *self._update_methods["read_data"]] - - @property - def set_data_methods(self): - return ["_set_data", *self._update_methods["set_data"]] - - @property - def get_figure_methods(self): - return ["_after_get_figure", *self._update_methods["get_figure"]] - - def _parse_update_funcs(self, func_names): - """ Decides which functions to run when the settings of the plot are updated - - This is called in self._run_updates as a final oportunity to decide what functions to run. - - In the case of plots, all we basically want to know is if we need to read the data again (execute - `read_data`), set new data for the plot without needing to read (execute `set_data`) or just update - some aesthetic aspects of the plot (execute `get_figure`). - - When Plot sees one of the following functions in the list of functions with updated parameters: - - "_before_read", "_after_read", any entry point function, functions in cls._update_methods["read_data"] - - it knows that `read_data` needs to be executed. (which afterwards triggers `set_data` and `get_figure`). - - Otherwise, if any of this functions is present: - - "_set_data", functions in cls._update_methods["set_data"] - - it executes `set_data` (and `get_figure` subsequentially) - - Finally, if it finds: - - "_after_get_figure", functions in cls._update_methods["get_figure"] - - it executes `get_figure`. - - WARNING: If it doesn't find any of these, it will return the unparsed list of functions, - and all functions will get executed. - - Parameters - ----------- - func_names: set of str - the unique functions names that are to be executed unless you modify them. 
- - Returns - ----------- - array-like of str - the final list of functions that will be executed. - - See also - ------------ - ``Configurable._run_updates`` - """ - if len(func_names.intersection(self.read_data_methods)) > 0: - return ["read_data"] - - if len(func_names.intersection(self.set_data_methods)) > 0: - return ["set_data"] - - if len(func_names.intersection(self.get_figure_methods)) > 0: - return ["get_figure"] - - return func_names - - @classmethod - def from_plotly(cls, plotly_fig): - """ Converts a plotly plot to a Plot object - - Parameters - ----------- - plotly_fig: plotly.graph_objs.Figure - the figure that we want to convert into a sisl plot - - Returns - ----------- - Plot - the converted plot that contains the plotly figure. - """ - plot = cls(only_init=True) - plot.figure = plotly_fig - - return plot - - @classmethod - def plot_name(cls): - """ The name of the plot. Used to be displayed in the GUI, for example """ - return getattr(cls, "_plot_type", cls.__name__) - - @classmethod - def suffix(cls): - """ Get the suffix that this class adds to plotting functions - - See sisl/viz/_plotables.py and particularly the `register_plotable` - function to understand this better. - """ - if cls is Plot: - return None - return getattr(cls, "_suffix", cls.__name__.lower().replace("plot", "")) - - @classmethod - def entry_points_help(cls): - """ Generates a helpful message about the entry points of the plot class """ - string = "" - - for entry_point in cls.entry_points: - - string += f"{entry_point._name.capitalize()}\n------------\n\n" - string += (entry_point.help or "").lstrip() - - string += "\nSettings used:\n\t- " - string += '\n\t- '.join(map(lambda ab: ab[1], entry_point._method._settings)) - string += "\n\n" - - return string - - @property - def _innotebook(self): - """ Boolean indicating whether this plot is being used in a notebook - - Used to understand how we should display the plot. 
- """ - return running_in_notebook() - - @property - def _widgets(self): - """ Dictionary that informs of which jupyter notebook widgets are available """ - return check_widgets() - - @classmethod - def multiple(cls, *args, fixed={}, template_plot=None, merge_method='together', **kwargs): - """ Creates a multiple plot out of a class - - This class method returns a multiple plot that contains several plots of the same class. - It is a general method that serves as a common denominator for building MultiplePlot, SubPlot - or Animation objects. Therefore, it is used by the `animated` and `subplots` methods - - If no arguments are passed, you will get the default multiple plot for the class, if there is any. - - Parameters - ----------- - *args: - Depending on what you pass the arguments will be interpreted differently: - - Two arguments: - First: str - Key of the setting that you want to be variable. - Second: array-like - Values that you want the setting to have in each individual plot. - - Ex: BandsPlot.multiple("bands_file", ["file1", "file2", "file3"] ) - will produce a multiple plot where each plot uses a different bands_file. - - - One argument and it is a dictionary: - First: dict - The keys of this dictionary will be the setting keys you want to be variable - and the values are of course the values for each plot for that setting. - - It works exactly as the previous case, but in this case we have multiple settings that vary. - - - One argument and it is a function: - First: function - With this function you can produce any settings you want without limitations. - - It will be used as the `_getInitKwargsList` method of the MultiplePlot object, so it needs to - accept self (the MultiplePlot object) and return a list of dictionaries which are the - settings for each plot. - - fixed: dict, optional - A dictionary containing values for settings that will be fixed for all plots. - For the settings that you don't specify here you will get the defaults. 
- template_plot: sisl Plot, optional - If provided this plot will be used as a template. - - It is important to know that it will not act only as the settings template, - but it will also PROVIDE DATA FOR THE OTHER PLOTS in case the data reading - settings are not being varied through the different plots. - - This is extremely important to provide when possible, because in some cases the data - that a plot gathers can be very large and therefore it may not be even feasable to store - the repeated data in terms of memory/time. - merge_method: {'together', 'subplots', 'animation'}, optional - the way in which the multiple plots are 'related' to each other (i.e. how they should be displayed). - In most cases, instead of using this argument, you should probably use the specific method (`animated` or - `subplots`). They set this argument accordingly but also do some other work to make your life easier. - **kwargs: - Will be passed directly to initialization, so it can contain the settings for the MultiplePlot object, for example. - - If args are not passed and the default multiple plot is being created, some keyword arguments may be used by the method - that generates the default multiple plot. One recurrent example of this is the keyword `wdir`. - - Returns - -------- - MultiplePlot, SubPlots or Animation - The plot that you asked for. 
- """ - #Try to retrieve the default animation if no arguments are provided - if len(args) == 0: - - return call_method_if_present(cls, "_default_animation", fixed=fixed, **kwargs) - - #Define how the getInitkwargsList method will look like - if callable(args[0]): - _getInitKwargsList = args[0] - else: - if len(args) == 2: - variable_settings = {args[0]: args[1]} - elif isinstance(args[0], dict): - variable_settings = args[0] - - def _getInitKwargsList(self): - - #Adding the fixed values to the list - vals = { - **{key: itertools.repeat(val) for key, val in fixed.items()}, - **variable_settings - } - - return dictOfLists2listOfDicts(vals) - - # Choose the specific class that we want to initialize - MultipleClass = {'together': MultiplePlot, 'subplots': SubPlots, 'animation': Animation}[merge_method] - - #Return the initialized multiple plot - return MultipleClass(_plugins={ - "_getInitKwargsList": _getInitKwargsList, - "_plot_classes": cls, - **kwargs.pop('_plugins', {}) - }, template_plot=template_plot, **kwargs) - - @classmethod - def subplots(cls, *args, fixed={}, template_plot=None, rows=None, cols=None, arrange="rows", **kwargs): - """ Creates subplots where each plot has different settings - - Mainly, it uses the `multiple` method to generate them. - - Parameters - ----------- - *args: - Depending on what you pass the arguments will be interpreted differently: - - Two arguments: - First: str - Key of the setting that you want to vary across subplots. - Second: array-like - Values that you want the setting to have in each subplot. - - Ex: BandsPlot.multiple("bands_file", ["file1", "file2", "file3"] ) - will produce a layout where each subplot uses a different bands_file. - - - One argument and it is a dictionary: - First: dict - The keys of this dictionary will be the setting keys you want to vary across subplots - and the values are of course the values for each plot for that setting. 
- - It works exactly as the previous case, but in this case we have multiple settings that vary. - - - One argument and it is a function: - First: function - With this function you can produce any settings you want without limitations. - - It will be used as the `_getInitKwargsList` method of the MultiplePlot object, so it needs to - accept self (the MultiplePlot object) and return a list of dictionaries which are the - settings for each plot. - - fixed: dict, optional - A dictionary containing values for settings that will be fixed for all subplots. - For the settings that you don't specify here you will get the defaults. - template_plot: sisl Plot, optional - If provided this plot will be used as a template. - - It is important to know that it will not act only as the settings template, - but it will also PROVIDE DATA FOR THE OTHER PLOTS in case the data reading - settings are not being varied through the different plots. - - This is extremely important to provide when possible, because in some cases the data - that a plot gathers can be very large and therefore it may not be even feasable to store - the repeated data in terms of memory/time. - rows: int, optional - The number of rows of the plot grid. If not provided, it will be inferred from `cols` - and the number of plots. If neither `cols` or `rows` are provided, the `arrange` parameter will decide - how the layout should look like. - cols: int, optional - The number of columns of the subplot grid. If not provided, it will be inferred from `rows` - and the number of plots. If neither `cols` or `rows` are provided, the `arrange` parameter will decide - how the layout should look like. - arrange: {'rows', 'col', 'square'}, optional - The way in which subplots should be aranged if the `rows` and/or `cols` - parameters are not provided. - **kwargs: - Will be passed directly to SubPlots initialization, so it can contain the settings for it, for example. 
- - If args are not passed and the default multiple plot is being created, some keyword arguments may be used by the method - that generates the default multiple plot. One recurrent example of this is the keyword `wdir`. - """ - return cls.multiple(*args, fixed=fixed, template_plot=None, merge_method='subplots', - rows=rows, cols=cols, arrange=arrange, **kwargs) - - @classmethod - def animated(cls, *args, fixed={}, frame_names=None, template_plot=None, **kwargs): - """ Creates an animation out of a class - - This class method returns an animation with frames belonging to a given plot class. - - For example, if you run `BandsPlot.animated()` you will get an animation made of bands plots. - - If no arguments are passed, you will get the default animation for that plot, if there is any. - - Parameters - ----------- - *args: - Depending on what you pass the arguments will be interpreted differently: - - Two arguments: - First: str - Key of the setting that you want to animate. - Second: array-like - Values that you want the setting to have at each animation frame. - - Ex: BandsPlot.animated("bands_file", ["file1", "file2", "file3"] ) - will produce an animation where each frame uses a different bands_file. - - - One argument and it is a dictionary: - First: dict - The keys of this dictionary will be the setting keys you want to animate - and the values are of course the values for each frame for that setting. - - It works exactly as the previous case, but in this case we have multiple settings to animate. - - - One argument and it is a function: - First: function - With this function you can produce any settings you want without limitations. - - It will be used as the `_getInitKwargsList` method of the animation, so it needs to - accept self (the animation object) and return a list of dictionaries which are the - settings for each frame. - - the function will recieve the parameter and can act on it in any way you like. 
- It doesn't need to return the parameter, just modify it. - In this function, you can call predefined methods of the parameter, for example. - - Ex: obj.modify_param("length", lambda param: param.incrementByOne() ) - - given that you know that this type of parameter has this method. - fixed: dict, optional - A dictionary containing values for settings that will be fixed along the animation. - For the settings that you don't specify here you will get the defaults. - frame_names: list of str or function, optional - If it is a list of strings, each string will be used as the name for the corresponding frame. - - If it is a function, it should accept `self` (the animation object) and return a list of strings - with the frame names. Note that you can access the plot instance responsible for each frame under - `self.children`. The function will be run each time the figure is generated, so in this way your - frame names will be dynamic. - - FRAME NAMES SHOULD BE UNIQUE, OTHERWISE THE ANIMATION WILL HAVE A WEIRD BEHAVIOR. - - If this is not provided, frame names will be generated automatically. - template_plot: sisl Plot, optional - If provided this plot will be used as a template. - - It is important to know that it will not act only as the settings template, - but it also will PROVIDE DATA FOR THE OTHER PLOTS in case the data reading - settings are not animated. - - This is extremely important to provide when possible, because in some cases the data - that a plot gathers can be very large and therefore it may not be even feasable to store - the repeated data in terms of memory/time. - **kwargs: - Will be passed directly to animation initialization, so it can contain the settings for the animation, for example. - - If args are not passed and the default animation is being created. Some keyword arguments may be used by the method - that generates the default animation. One recurrent example of this is the keyword `wdir`. 
- - Returns - -------- - Animation - The Animation that you asked for - """ - # And just let the general multiple plot creator do the work - return cls.multiple(*args, fixed=fixed, template_plot=template_plot, merge_method='animation', - frame_names=frame_names, **kwargs) - - def __init_subclass__(cls): - """ Whenever a plot class is defined, this method is called - - We will use this opportunity to: - - Register entry points. - - Generate a more helpful __init__ method that exposes all the settings. - - We could use this to register plotables (see commented code). - However, there is one major problem: how to specify defaults. - This is a problem because sometimes an input field is inherited from one plot - to another, therefore you can not say: "this is the default plotable input". - - Probably, defaults should be centralized, but I don't know where just yet. - """ - super().__init_subclass__() - - # Register the entry points of this class. - cls.entry_points = [] - for key, val in inspect.getmembers(cls, lambda x: isinstance(x, EntryPoint)): - cls.entry_points.append(val) - # After registering an entry point, we will just set the method - setattr(cls, key, _populate_with_settings(val._method, [param["key"] for param in cls._get_class_params()[0]])) - - entry_points_order = cls.get_class_param("entry_points_order") - entry_points_order.modify_item_input( - "inputField.params.options", - [{"label": entry._name, "value": entry._name} for entry in cls.entry_points] - ) - entry_points_order.modify( - "default", - [entry._name for entry in sorted(cls.entry_points, key=lambda entry: entry._sort_key)] - ) - - cls.backends = Backends(cls) - - @vizplotly_settings('before', init=True) - def __init__(self, *args, H = None, attrs_for_plot={}, only_init=False, presets=None, layout={}, _debug=False, **kwargs): - # Give an ID to the plot - self.id = str(uuid.uuid4()) - - # Inform whether the plot is in debug mode or not: - self._debug = _debug - - # Initialize shortcut 
management - ShortCutable.__init__(self) - - # Give the user the possibility to do things before initialization (IDK why) - call_method_if_present(self, "_before_init") - - #Set the isChildPlot attribute to let the plot know if it is part of a bigger picture (e.g. Animation) - self.isChildPlot = kwargs.get("isChildPlot", False) - - #Initialize the variable to store when has been the last data read (0 means never basically) - self.last_dataread = 0 - self._files_to_follow = [] - - # Check if the user has provided a hamiltonian (which can contain a geometry) - # This is not meant to be used by the GUI (in principle), just programatically - self.PROVIDED_H = False - self.PROVIDED_GEOM = False - if H is not None: - self.PROVIDED_H = True - self.H = H - self.setup_hamiltonian() - - if presets is not None: - if isinstance(presets, str): - presets = [presets] - - # on_figure_change is triggered after get_figure. - self.on_figure_change = None - - # Set all the attributes that have been passed - # It is important that this is here so that it can overwrite any of - # the already written attributes - for key, val in attrs_for_plot.items(): - setattr(self, key, val) - - #If plugins have been provided, then add them. - #Plugins are an easy way of extending a plot. They can be methods, variables... - #They are added to the object instance, not the whole class. 
- if kwargs.get("_plugins"): - for name, plugin in kwargs.get("_plugins").items(): - if isinstance(plugin, FunctionType): - plugin = MethodType(plugin, self) - setattr(self, name, plugin) - - # Add the general plot shortcuts - self._general_plot_shortcuts() - - #Give the user the possibility to overwrite default settings - call_method_if_present(self, "_after_init") - - # If we were supposed to only initialize the plot, stop here - if only_init: - return - - #Try to generate the figure (if the settings required are still not there, it won't be generated) - try: - if MultiplePlot in type.mro(self.__class__): - #If its a multiple plot try to inititialize all its child plots - if self.PLOTS_PROVIDED: - self.get_figure() - else: - self.init_all_plots() - else: - self.read_data() - - except Exception as e: - if self._debug: - raise e - info(f"The plot has been initialized correctly, but the current settings were not enough to generate the figure.\nError: {e}") - - def __str__(self): - """ Information to print about the plot """ - string = ( - f'Plot class: {self.__class__.__name__} Plot type: {self.plot_name()}\n\n' - 'Settings:\n{}'.format("\n".join(["\t- {}: {}".format(key, value) for key, value in self.settings.items()])) - ) - - return string +class Plot(Workflow): + """Base class for all plots""" def __getattr__(self, key): - """ This method is executed only after python has found that there is no such attribute in the instance - - So let's try to find it elsewhere. 
There are two options: - - The attribute is in the backend object (self._backend) - - The attribute is currently being shared with other plots (only possible if it's a childplot) - """ - if key in ["_backend", "_get_shared_attr"]: - pass - elif hasattr(self, "_backend") and hasattr(self._backend, key): - return getattr(self._backend, key) - else: - #If it is a childPlot, maybe the attribute is in the shared storage to save memory and time - try: - return self._get_shared_attr(key) - except (KeyError, AttributeError): - pass - - raise AttributeError(f"The attribute '{key}' was not found either in the plot, its backend, or in shared attributes.") - - def __setattr__(self, key, val): - """ - If is a childplot and it has the attribute `_SHOULD_SHARE_WITH_SIBLINGS` set to True, we will submit the attribute to the shared store. - This happens in animations/multiple plots. There's a "leading plot" that reads the data and then shares it with the rest - so that they don't need to read it again, in a collective effort to save memory and time. - - Otherwise we set the attribute to the plot itself. 
- """ - if key != '_SHOULD_SHARE_WITH_SIBLINGS' and getattr(self, '_SHOULD_SHARE_WITH_SIBLINGS', False): - self.share_attr(key, val) - else: - object.__setattr__(self, key, val) - - def __getitem__(self, key): - """ Getting an item from plot returns the trace(s) that correspond to the requested indices """ - if isinstance(key, (int, slice)): - return self.data[key] - - def _general_plot_shortcuts(self): - """ In this method we set the shortcuts that are general to all plots - - This is called in `__init__` - """ - self._listening_shortcut() - - self.add_shortcut("ctrl+z", "Undo settings", self.undo_settings, _description="Takes the settings of the plot one step back") - - @repeat_if_children - @vizplotly_settings('before') - def read_data(self, update_fig=True, **kwargs): - """ This method is responsible for organizing the data-reading step - - If everything is done succesfully, it calls the next step (`set_data`) - """ - # Restart the files_to_follow variable so that we can start to fill it with the new files - # Apart from the explicit call in this method, setFiles and setup_hamiltonian also add files to follow - self._files_to_follow = [] - - call_method_if_present(self, "_before_read") - - # We try to read from the different entry points available - self._read_from_sources() - - # We don't update the last dataread here in case there has been a succesful data read because we want to - # wait for the after_read() method to be succesful - if self.source is None: - self.last_dataread = 0 - - call_method_if_present(self, "_after_read") - - if self.source is not None: - self.last_dataread = time.time() - - if update_fig: - self.set_data(update_fig = update_fig) - - return self - - def _read_from_sources(self, entry_points_order): - """ Tries to read the data from the different available entry points in the plot class - - If it fails to read from all entry points, it raises an exception. 
- """ - # It is possible that the class does not implement any entry points, - # because it doesn't need to read any data. Then the plotting process - # will basically start at set_data. - if not self.entry_points: - return - - errors = [] - # Try to read data using all the different entry points - # This is just a first implementation. One of the reasons entry points - # have been implemented is that we can do smarter things than this. - for entry_point_name in entry_points_order: - for entry_point in self.entry_points: - if entry_point._name == entry_point_name: - break - else: - warn(f"Entry point {entry_point_name} not found in {self.__class__.__name__}") - continue - - try: - returns = getattr(self, entry_point._method_attr)() - self.source = entry_point - return returns - except Exception as e: - errors.append("\t- {}: {}.{}".format(entry_point._name, type(e).__name__, e)) + if key != "nodes": + return getattr(self.nodes.output.get(), key) else: - self.source = None - raise ValueError("Could not read or generate data for {} from any of the possible sources.\nHere are the errors for each source:\n{}" - .format(self.__class__.__name__, "\n".join(errors))) - - def follow(self, *files, to_abs=True, unfollow=False): - """ Makes sure that the object knows which files to follow in order to trigger updates - - Parameters - ---------- - *files: str - a string that represents the path to the file that needs to be followed. - - You can pass as many as you want as separate arguments. Note that if you have a list of - files you can pass them separately by doing `follow(*my_list_of_files)`, you don't need to - (and you shouldn't) build a loop :) - to_abs: boolean, optional - whether the paths should be converted to absolute paths to make file following procedures - more robust. It is better to leave it as True unless you have a good reason to change it. - unfollow: boolean, optional - whether the previous files should be unfollowed. 
If set to False, we are just adding more files. - """ - new_files = [Path(file_path).resolve() if to_abs else Path(file_path) for file_path in files or []] - - self._files_to_follow = new_files if unfollow else [*self._files_to_follow, *new_files] - - def get_sile(self, path, results_path, root_fdf, *args, follow=True, follow_kwargs={}, file_contents=None, **kwargs): - """ A wrapper around get_sile so that the reading of the file is registered - - It has to main functions: - - Automatically following files that are read, so that you don't neet to go always like: - - ``` - self.follow(file) - sisl.get_sile(file) - ``` - - Infering files from a root file. For example, using the root_fdf. - - Parameters - ---------- - path: str - the path to the file that you want to read. - It can also be the setting key that you want to read. - *args: - passed to sisl.get_sile - follow: boolean, optional - whether the path should be followed. - follow_kwargs: dict, optional - dictionary of keywords that are passed directly to the follow method. - **kwargs: - passed to sisl.get_sile - """ - # If path is a setting name, retrieve it - if path in self.settings: - setting_key = path - path = self.get_setting(path) - - # However, if it wasn't provided, try to infer it. - # For example, if it is a siesta sile, we will try to infer it - # from the fdf file - if not path: - - sile_type = self.get_param(setting_key).dtype - # We need to check here if it is a SIESTA related sile! 
- - fdf_sile = sisl.get_sile(root_fdf) - - for rule in sisl.get_sile_rules(cls=sile_type): - filename = fdf_sile.get('SystemLabel', default='siesta') + f'.{rule.suffix}' - try: - path = fdf_sile.dir_file(filename, results_path) - return self.get_sile(path, *args, follow=True, follow_kwargs={}, file_contents=None, **kwargs) - except Exception: - pass - else: - raise FileNotFoundError(f"Tried to infer {setting_key} from the 'root_fdf', " - f"but didn't find any {sile_type.__name__} in {Path(fdf_sile._directory) / results_path }") - - if follow: - self.follow(path, **follow_kwargs) - - return sisl.get_sile(path, *args, **kwargs) - - def updates_available(self): - """ This function checks whether the read files have changed - - For it to work properly, one should specify the files that have been read by - their reading methods (usually, the entry points). This is done by using the - `follow()` method or by reading files with `self.get_sile()` instead of `sisl.get_sile()`. - """ - def modified(filepath): - - try: - return filepath.stat().st_mtime > self.last_dataread - except FileNotFoundError: - return False # This probably should implement better logic - - files_modified = np.array([modified(file_path) for file_path in self._files_to_follow]) - - return files_modified.any() - - def listen(self, forever=True, show=True, as_animation=False, return_animation=True, return_figWidget=False, - clear_previous=True, notify=False, speak=False, notify_title=None, notify_message=None, speak_message=None, fig_widget=None): - """ Listens for updates in the followed files (see the `updates_available` method) - - Parameters - --------- - forever: boolean, optional - whether to keep listening after the first plot update - show: boolean, optional - whether to show the plot at the beggining and update the layout when the plot is updated. - as_animation: boolean, optional - will add a new frame each time the plot is updated. 
- - The resulting animation is returned unless return_animation is set to False. This is done because - the Plot object iself is not converted to an animation. Instead, a new animation is created and if you - don't save it in a variable it will be lost, you will have no way to access it later. - - If you are seeing two figures at the beggining, it is because you are not storing the animation figure. - Set the return_animation parameter to False if you understand that you are going to "lose" the animation, - you will only be able to see a display of it while it is there. - return_animation: boolean, optional - if as_animation is `True`, whether the animation should be returned. - Important: see as_animation for an explanation on why this is the case - return_figWidget: boolean, optional - it returns the figure widget that is in display in a jupyter notebook in case the plot has - succeeded to display it. Note that, even if you are in a jupyter notebook, you won't get a figure - widget if you don't have the plotly notebook extension enabled. Check `._widgets` to see - if you are missing witget support. - - if return_animation is True, both the animation and the figure widget will be returned in a tuple. - Although right now, this does not make much sense because figure widgets don't support frames. You will get None. - clear_previous: boolean, optional - in case show is True, whether the previous version of the plot should be hidden or kept in display. - notify: boolean, optional - trigger a notification everytime the plot updates. - speak: boolean, optional - trigger a spoken message everytime the plot updates. - notify_title: str, optional - the title of the notification. - notify_message: str, optional - the message of the notification. - speak_message: str, optional - the spoken message. Feel free to get creative here! 
- """ - import asyncio - - from IPython.display import clear_output - - # This is a weird limitation, because multiple listeners could definitely - # be implemented, but I don't have time now, and I need to ensure that no listeners are left untracked - # If you need it, ask me! (Pol) - self.stop_listening() - - pt = self - - if as_animation: - pt = Animation( - plots = [self.clone()] - ) - - if show and fig_widget is None: - fig = pt.show(return_figWidget=True) - fig_widget = fig - - if notify: - trigger_notification("SISL", "Notifications will appear here") - if speak: - spoken_message("I will speak when there is an update.") - - async def listen(): - while True: - if self.updates_available(): - try: - - self.read_data(update_fig=True) - - if as_animation: - new_plot = self.clone() - pt.add_children(new_plot) - pt.get_figure() - - if clear_previous and fig_widget is None: - clear_output() - - if show and fig_widget is None: - pt.show() - else: - pt._backend._update_ipywidget(fig_widget) - - if not forever: - self._listening_task.cancel() - - if notify: - title = notify_title or "SISL PLOT UPDATE" - message = notify_message or f"{getattr(self, 'struct', '')} {self.__class__.__name__} updated" - trigger_notification(title, message) - if speak: - spoken_message(speak_message if speak_message is not None else f"Your {self.__class__.__name__} is updated. 
Check it out") - - except Exception as e: - pass - - await asyncio.sleep(1) - - loop = asyncio.get_event_loop() - self._listening_task = loop.create_task(listen()) - - self.add_shortcut("ctrl+alt+l", "Stop listening", self.stop_listening, fig_widget=fig_widget, _description="Tell the plot to stop listening for updates") - - if as_animation and return_animation: - if return_figWidget: - return pt, fig_widget - else: - return pt - elif return_figWidget: - return fig_widget - - def _listening_shortcut(self, fig_widget=None): - """ Adds the shortcut to start listening for updates - - This is done here and not in `_general_plot_settings` because - we need to be able to toggle this shortcut each time the plot - starts/stops listening. - - Maybe at some point we can have a rule to automatically disable - shortcuts based on state in `ShortCutable`. - """ - self.add_shortcut( - "ctrl+alt+l", "Listen for updates", - self.listen, fig_widget=fig_widget, - _description="Make the plot listen for changes in the files that it reads" - ) - - def stop_listening(self, fig_widget=None): - """ Makes the plot stop listening for updates - - Using this method only makes sense if you have previously made the plot listen - either through `Plot.listen()` or `Plot.show(listen=True)` - - Parameters - ----------- - fig_widget: plotly FigureWidget, optional - the figure widget where the plot is currently being displayed. - - This is just used to reset the listening shortcut. - - YOU WILL MOST LIKELY NOT USE THIS because `Plot` already knows - where is it being displayed in normal situations. 
- """ - task = getattr(self, "_listening_task", None) - - if task is not None: - task.cancel() - self._listening_task = None - self._listening_shortcut(fig_widget=fig_widget) - - return self - - @vizplotly_settings('before') - def setup_hamiltonian(self, **kwargs): - """ Sets up the hamiltonian for calculations with sisl """ - NEW_FDF = True - if len(self.settings_history) > 1: - NEW_FDF = self.settings_history.was_updated("root_fdf") - - if not hasattr(self, "geometry") or NEW_FDF: - try: - fdf_sile = self.get_sile("root_fdf") - self.geometry = fdf_sile.read_geometry(output = True) - except Exception: - pass - - if not self.PROVIDED_H and (not hasattr(self, "H") or NEW_FDF): - #Read the hamiltonian - fdf_sile = self.get_sile("root_fdf") - self.H = fdf_sile.read_hamiltonian() - else: - if isinstance(self.H, (str, Path)): - self.H = self.get_sile(self.H) - - if isinstance(self.H, sisl.BaseSile): - self.H = self.H.read_hamiltonian(geometry=getattr(self, "geometry", None)) - - if not hasattr(self, "geometry"): - self.geometry = self.H.geometry - - return self - - @repeat_if_children - @vizplotly_settings('before') - def set_data(self, update_fig = True, **kwargs): - """ Method to process the data that has been read beforehand by read_data() and prepare the figure - - If everything is succesful, it calls the next step in plotting (`get_figure`) - """ - - self._for_backend = self._set_data() - - if update_fig: - self.get_figure() - - return self - - def get_figure(self, backend, clear_fig=True, **kwargs): - """ - Generates a figure out of the already processed data. - - Parameters - --------- - clear_fig: boolean, optional - whether the figure should be cleared before drawing. - - Returns - --------- - self.figure: plotly.graph_objs.Figure - the plotly figure. - """ - # Initialize the backend - if backend is None: - # It is possible to not use any plotting backend. 
- # In that case, we just simply process the data, but we do not plot anything - return - self.backends.setup(self, backend) - - if clear_fig: - # Clear all the traces from the figure before drawing the new ones - self.clear() - - self.draw(getattr(self, "_for_backend", None)) - - call_method_if_present(self, '_after_get_figure') - - call_method_if_present(self, 'on_figure_change') - - return self - - #------------------------------------------- - # PLOT DISPLAY METHODS - #------------------------------------------- - - def show(self, *args, listen=False, return_figWidget=False, **kwargs): - """ Displays the plot - - Parameters - ------ - listen: bool, optional - after showing, keeps listening for file changes to update the plot. - This is nice for monitoring. - return_figureWidget: bool, optional - if the plot is displayed in a jupyter notebook, whether you want to - get the figure widget as a return so that you can act on it. - """ - if self._backend is None: - return warn("There is no plotting backend selected, so the plot can't be displayed.") - - if listen: - self.listen(show=True, **kwargs) - - if self._innotebook and (len(args) == 0 or 'config' in kwargs): - try: - return self._ipython_display_(listen=listen, return_figWidget=return_figWidget, **kwargs) - except Exception as e: - warn(e) - - return self._backend.show(*args, **kwargs) - - def _ipython_display_(self, return_figWidget=False, **kwargs): - """ Handles all things needed to display the plot in a jupyter notebook - - Plotly already knows how to show a plot in the jupyter notebook, however - here we try to extend it to support shortcuts if the appropiate widget is - there (ipyevents, https://github.com/mwcraig/ipyevents). - - Parameters - ------ - return_figureWidget: bool, optional - if the plot is displayed in a jupyter notebook, whether you want to - get the figure widget as a return so that you can act on it. 
- """ - from IPython.display import display - - def _try_backend(): - if self._backend is None: - return display(repr(self)) - - kwargs.pop("listen", None) - display_method = getattr(self._backend, "_ipython_display_", None) - if display_method is not None: - return display_method(**kwargs) - self._backend.show(**kwargs) - - if not isinstance(self, Animation): - - try: - widget = self._backend.get_ipywidget() - except Exception: - return _try_backend() - - if False and self._widgets["events"]: - # For now we want provide keyboard shortcut support - # If ipyevents is available, show with shortcut support - self._ipython_display_with_shortcuts(widget, **kwargs) - else: - # Else, show without shortcut support - display(widget) - - self._listening_shortcut(fig_widget=widget) - - if return_figWidget: - return widget - - else: - _try_backend() - - def _ipython_display_with_shortcuts(self, fig_widget, **kwargs): - """ - If the appropiate widget is there (ipyevents, https://github.com/mwcraig/ipyevents), - we extend plotly's FigureWidget to support keypress events so that we can trigger - shortcuts from the notebook. - - Parameters - ------ - fig_widget: plotly.graph_objs.FigureWidget - The figure widget that we need to extend. 
- """ - from ipyevents import Event - from IPython.display import display - from ipywidgets import HTML, Output - - h = HTML("") # This is to display help such as available shortcuts - messages = HTML("") # This is to inform about current status - styles = HTML("") - d = Event(source=fig_widget, watched_events=['keydown', 'keyup']) - - def handle_dom_events(event, keys_down=[], last_timestamp=[0], keys_up=[]): - # We will keep track of keydowns because then we will be able to support multiple keys shortcuts - time_threshold = 500 #To remove key up events - - try: - # Clear the list - timestamp = event.get("timeStamp") - duplicates = len(keys_down) != len(set(keys_down)) - time_diff = timestamp - last_timestamp[0] - if time_diff > 2000 or duplicates: - keys_down *= 0 #Clear the list - if time_diff > time_threshold: - keys_up *= 0 - - last_timestamp[0] = timestamp - - # This means that the key has been held down for a long time - if event.get("repeat", False): - return - - key = event.get("key", "").lower() - key_code = event.get("code", "") - - ev_type = event.get("type", None) - - if ev_type == "keydown": - if key == "control": - key = "ctrl" - # If it's a key down event, record it - keys_down.append(key) - elif ev_type == "keyup" and key in keys_down: - if key == "control": - key = "ctrl" - if len(keys_down) == 1: - keys_up.append(key) - # If it's a key up event, anounce that the key is not down anymore - keys_down.remove(key) - - only_down = "+".join(keys_down) - shortcut_key = f'{" ".join(keys_up)} {only_down}'.strip() - - if shortcut_key: - messages.value = f'{shortcut_key}' - - # Get the help message - if shortcut_key == "shift+?": - h.value = self.shortcuts_summary("html") if not h.value else "" - - shortcut = self.shortcut(shortcut_key) - if shortcut is not None: - keys_down *= 0 - keys_up *= 0 - - messages.value = f'Executing "{shortcut["name"]}" because you pressed "{shortcut_key}"...' 
- self.call_shortcut(shortcut_key) - messages.value = "" - - self._update_ipywidget(fig_widget) - - except Exception as e: - messages.value = f'{e}' - - d.on_dom_event(partial(handle_dom_events)) - - display(fig_widget, messages, h, styles, Output()) - - #------------------------------------------- - # PLOT MANIPULATION METHODS - #------------------------------------------- - - def merge(self, others, to="multiple", extend_multiples=True, **kwargs): - """ Merges this plot's instance with the list of plots provided - - Parameters - ------- - others: array-like of Plot() or Plot() - the plots that we want to merge with this plot instance. - to: {"multiple", "subplots", "animation"}, optional - the merge method. Each option results in a different way of putting all the plots - together: - - "multiple": All plots are shown in the same canvas at the same time. Useful for direct - comparison. - - "subplots": The layout is divided in different subplots. - - "animation": Each plot is converted into the frame of an animation. - extend_multiples: boolean, optional - if True, if `MultiplePlot`s are passed, they are splitted into their children, so that the result - is the merge of its children with the rest. - If False, a `MultiplePlot` is treated as a solid unit. - kwargs: - extra arguments that are directly passed to `MultiplePlot`, `Subplots` - or `Animation` initialization. - - Returns - ------- - MultiplePlot, Subplots or Animation - depending on the value of the `to` parameter. 
- """ - #Make sure we deal with a list (user can provide a single plot) - if not isinstance(others, (list, tuple, np.ndarray)): - others = [others] - - children = [self, *others] - if extend_multiples: - children = [[pt] if not isinstance(pt, MultiplePlot) else pt.children for pt in children] - # Flatten the list - children = [pt for plots in children for pt in plots] - - PlotClass = { - "multiple": MultiplePlot, - "subplots": SubPlots, - "animation": Animation - }[to] - - return PlotClass(plots=children, **kwargs) - - def copy(self): - """ Returns a copy of the plot - - If you want a plot with the exact plot configuration but newly initialized, - use `clone()` instead. - """ - return deepcopy(self) - - def clone(self, *args, **kwargs): - """ Gets you and exact clone of this plot - - You can pass extra args that will overwrite the previous parameters though, if you don't want it to be that exact. - - IMPORTANT: IT WILL INITIALIZE A NEW PLOT, THEREFORE IT WILL READ NEW DATA. - IF YOU JUST WANT A COPY, USE THE `copy()` method. - """ - return deepcopy(self) - - return self.__class__(*args, **self.settings, **kwargs) - - #------------------------------------------- - # LISTENING TO EVENTS - #------------------------------------------- - - def dispatch_event(self, event, *args, **kwargs): - """ Not functional yet """ - warn((event, args, kwargs)) - # Of course this needs to be done - raise NotImplementedError - - #------------------------------------------- - # DATA TRANSFER/STORAGE METHODS - #------------------------------------------- - def __getstate__(self): - """Returns the object to be pickled""" - # We just simply remove any sile from the settings history (as they are not pickleable) - # and replace it with a posix path. Note that this does not fix the problem if there are - # nested siles. 
- for key, item in self.settings_history._vals.items(): - self.settings_history._vals[key] = [val.file if isinstance(val, sisl.io.BaseSile) else val for val in item] - return self.__dict__ - - def save(self, path): - """ Saves the plot so that it can be loaded in the future - - Parameters - --------- - path: str - The path to the file where you want to save the plot - - Returns - --------- - self - """ - import dill - - if isinstance(path, str): - path = Path(path) - - with open(path, 'wb') as handle: - dill.dump(self, handle, protocol=dill.HIGHEST_PROTOCOL) - - return True - - -class EntryPoint: - - def __init__(self, name, sort_key, setting_key, method, instance=None): - self._name = name - self._sort_key = sort_key - self._method_attr = method.__name__ - self._setting_key = setting_key - self._method = method - self.help = method.__doc__ - - -def entry_point(name, sort_key=0): - """ Helps registering entry points for plots - - See the usage section to get a fast intuitive way of how to use it. - - Basically, you need to provide some parameters (which are described - in the parameters section), and this function will return a decorator that - you can use in the functions of your plot class that do the reading part. - - A function that is meant to read data but it's not marked as an entry_point - will be invisible to Plot. - - NOTE: A plot class can have no entry points. This is perfectly fine if the - class does not need to read data for some reason. In this case, we will go straight - into the data setting methods (i.e. set_data). - - Examples - ----------- - - >>> class MyPlot(Plot): - >>> @entry_point('siesta_output') - >>> def _lets_read_from_siesta_output(self): - >>> ...do some work here - >>> - >>> @entry_point('ask_mum'): - >>> def _we_are_quite_lost_so_we_better_ask_mum(self): - >>> self.call_mum() - - Parameters - ----------- - name: str - the name of the entry point that the decorated function implements. 
- sort_key: any - the entry points order will be sorted according to this key. - """ - return partial(EntryPoint, name, sort_key, ()) - -#------------------------------------------------ -# CLASSES TO SUPPORT COMPOSITE PLOTS -#------------------------------------------------ - - -class MultiplePlot(Plot): - """ General handler of a group of plots that need to be rendered together - - Parameters - ---------- - root_fdf: fdfSileSiesta, optional - Path to the fdf file that is the 'parent' of the results. - results_path: str, optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - entry_points_order: array-like, optional - Order with which entry points will be attempted. - backend: optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - """ - - _trigger_kw = "varying" - - def __init__(self, *args, plots=None, template_plot=None, **kwargs): - self.shared = {} - - # Take the plots if they have already been created and are provided by the user - self.PLOTS_PROVIDED = plots is not None - if self.PLOTS_PROVIDED: - self.set_children(plots) - - self.has_template_plot = False - if isinstance(template_plot, Plot): - self.template_plot = template_plot - self.has_template_plot = True - - super().__init__(*args, **kwargs) - - def __getitem__(self, i): - """ Gets a given child plot """ - return self.children[i] - - @staticmethod - def _kw_from_cls(cls): - return cls._trigger_kw - - @staticmethod - def _cls_from_kw(key): - for cls in MultiplePlot.__subclasses__(): - if cls._trigger_kw == key: - return cls - else: - return None - - @property - def _attrs_for_children(self): - """ Returns all the attributes that its children should have """ - return { - 'isChildPlot': True, - '_get_shared_attr': lambda key: self.shared_attr(key), - 'share_attr': lambda key, val: self.set_shared_attr(key, val) - } - - def init_all_plots(self, update_fig=True, 
try_sharing=True): - """ Initializes all child plots - - Parameters - ----------- - update_fig: boolean, optional - whether we should build the figure if this step is succesful. - try_sharing: boolean, optional - If `True`, we will check if all plots have exactly the same settings to read - data and, in that case, they will share attributes to avoid memory waste. - - This is specially essential for plots that read big amounts of data (e.g. `GridPlot`), - but also for those that take significant time to read it. - """ - if not self.PLOTS_PROVIDED: - - # If there is a template plot, take its settings as the starting point - template_settings={} - if self.has_template_plot: - self._plot_classes = self.template_plot.__class__ - template_settings = self.template_plot.settings - - SINGLE_CLASS = isinstance(self._plot_classes, type) - - # Initialize all the plots - # In case there is only one class only initialize them, avoid reading data - # In this case, it is extremely important to initialize them all in serial mode, because - # with multiprocessing they won't know that the current instance is their parent - # (objects get copied in multiprocessing) and they won't be able to share data - plots = init_multiple_plots( - self._plot_classes, - kwargsList = [ - {**template_settings, **kwargs, "attrs_for_plot": self._attrs_for_children, "only_init": SINGLE_CLASS and try_sharing} - for kwargs in self._getInitKwargsList() - ], - serial=SINGLE_CLASS and try_sharing - ) - - if SINGLE_CLASS and try_sharing: - - if not self.has_template_plot: - # Our leading plot will be the first one - leading_plot = plots[0] - else: - leading_plot = self.template_plot - - # Now, we get the settings of the first plot - read_data_settings = { - key: leading_plot.get_setting(key) for key, funcs in leading_plot._run_on_update.items() - if set(funcs).intersection(leading_plot.read_data_methods) - } - - for i, plot in enumerate(plots): - if not plot.has_these_settings(read_data_settings): - # If there 
is a plot that needs to read different data, we will just - # make each of them read their own data. (this could be optimized by grouping plots) - self.init_all_plots(try_sharing=False) - break - else: - # In case there is no plot that has different settings, we will - # happily set the data, avoiding the read data step. Plots will take - # their missing attributes from the shared store or from the plot - # template - self.set_children(plots) - - if not self.has_template_plot: - leading_plot._SHOULD_SHARE_WITH_SIBLINGS = True - leading_plot.read_data(update_fig=False) - leading_plot._SHOULD_SHARE_WITH_SIBLINGS = False - self.set_data() - - else: - # If we haven't tried sharing data, the plots are already prepared (with read data of their own) - self.set_children(plots) - - call_method_if_present(self, "_after_children_updated") - - if update_fig: - self.get_figure() - - return self - - def update_children_settings(self, children_sel=None, **kwargs): - """ Updates the settings of all child plots - - Parameters - ----------- - children_sel: array-like of int, optional - The indices of the child plots that you want to update. - **kwargs - Keyword arguments specifying the settings that you want to update - and the values you want them to have - """ - return self.update_settings(on_children=True, on_parent_plot=False, children_sel=children_sel, **kwargs) - - def _update_settings(self, on_children=False, on_parent_plot=True, children_sel=None, **kwargs): - """ This method takes into account that on plots that contain children, one may want to update only the parent settings or all the child's settings. - - Parameters - ----------- - on_children: boolean, optional - whether the settings should be updated on child plots - on_parent_plot: boolean, optional - whether the settings should be updated on the parent plot. - children_sel: array-like of int, optional - The indices of the child plots that you want to update. 
- """ - if on_parent_plot: - super()._update_settings(**kwargs) - - if on_children: - - repeat_if_children(Configurable._update_settings)(self, children_sel=children_sel, **kwargs) - - call_method_if_present(self, "_after_children_updated") - - return self - - def set_children(self, plots, keep=False): - """ Sets the children of a multiple plot - - Parameters - -------- - plots: array-like of sisl.viz.plotly.Plot or plotly Figure - the plots that should be set as children for the animation. - keep: boolean, optional - whether the existing children should be kept. - - If `True`, `plots` is added after them. - """ - for plot in plots: - for key, val in self._attrs_for_children.items(): - setattr(plot, key, val) - - self.children = plots if not keep else [*self.children, *plots] - - return self - - def add_children(self, *plots): - """ Append children to the existing ones - - Parameters - ----------- - *plots: Plot - all the plots that you want to add as child plots of this one. - """ - self.set_children(plots, keep=True) - - def insert_childplot(self, index, plot): - """ Inserts a plot in a given position of the children list - - Parameters - ---------- - index: int - The position where the plot should be inserted - plot: sisl Plot or plotly Figure - The plot to insert in the list - """ - self.children.insert(index, plot) - - def shared_attr(self, key): - """ Gets an attribute that is located in the shared storage of the MultiplePlot - - This method will be given to all children so that they can retreive the shared - attributes. This is done in `set_children`. - - Parameters - ------------ - key: str - the name of the attribute that you want to retrieve - - Returns - ----------- - any - the value that you asked for - """ - # If from the beggining there is a template plot, the shared - # storage is actually that plot. 
- if self.has_template_plot: - return getattr(self.template_plot, key) - - return self.shared[key] - - def set_shared_attr(self, key, val): - """ Sets the value of a shared attribute - - Parameters - ------------ - key: str - the key of the attribute that is to be set. - val: any - the new value for the attribute - """ - self.shared[key] = val - - return self - - def get_figure(self, backend, **kwargs): - self._for_backend = getattr(self, "_for_backend", {}) - self._for_backend["children"] = self.children - return super().get_figure(backend, **kwargs) - - -class Animation(MultiplePlot): - """ Version of MultiplePlot that renders each plot in a different animation frame - - Parameters - ---------- - frame_duration: int, optional - Time (in ms) that each frame will be displayed. This is only - meaningful in the plotly backend - interpolated_frames: int, optional - The number of frames that should be interpolated between two plots. - This is only meaningful in the blender backend. - redraw: bool, optional - Whether each frame of the animation should be redrawn - If False, the animation will try to interpolate between one frame and - the other Set this to False if you are sure that the - frames contain the same number of traces, otherwise new traces will - not appear. - ani_method: optional - It determines how the animation is rendered. - root_fdf: fdfSileSiesta, optional - Path to the fdf file that is the 'parent' of the results. - results_path: str, optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - entry_points_order: array-like, optional - Order with which entry points will be attempted. - backend: optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. 
- """ - - _trigger_kw = "animate" - - _isAnimation = True - - _param_groups = ( - - { - "key": "animation", - "name": "Animation specific settings", - "icon": "videocam", - "description": "The fact that you have not studied cinematography is not a good excuse for creating ugly animations. Customize your animation with these settings" - }, - - ) - - _parameters = ( - - IntegerInput( - key = "frame_duration", name = "Frame duration", - default = 500, - group = "animation", - params = { - "step": 100 - }, - help = "Time (in ms) that each frame will be displayed.
This is only meaningful in the plotly backend" - ), - - IntegerInput( - key="interpolated_frames", name="Frames between images", - default=5, - group="animation", - help = "The number of frames that should be interpolated between two plots. This is only meaningful in the blender backend." - ), - - BoolInput( - key='redraw', name='Redraw each frame', - default=True, - group='animation', - help="""Whether each frame of the animation should be redrawn
- If False, the animation will try to interpolate between one frame and the other
- Set this to False if you are sure that the frames contain the same number of traces, otherwise new traces will not appear.""" - ), - - OptionsInput( - key='ani_method', name="Animation method", - default=None, - group='animation', - params={ - "placeholder": "Choose the animation method...", - "options": [ - {"label": "Update", "value": "update"}, - {"label": "Animate", "value": "animate"}, - ], - "isClearable": True, - "isSearchable": True, - "isMulti": False - }, - help="""It determines how the animation is rendered. """ - ) - - ) - - def __init__(self, *args, frame_names=None, _plugins={}, **kwargs): - if frame_names is not None: - _plugins["_get_frame_names"] = frame_names if callable(frame_names) else lambda self, i: frame_names[i] - elif "_get_frame_names" not in _plugins: - _plugins["_get_frame_names"] = lambda self, i: f"Frame {i}" - - super().__init__(*args, **kwargs, _plugins=_plugins) - - def get_figure(self, backend, interpolated_frames, **kwargs): - self._for_backend = getattr(self, "_for_backend", {}) - - # Get the names for each frame - frame_names = [] - for i, plot in enumerate(self.children): - frame_name = self._get_frame_names(i) - frame_names.append(frame_name) - self._for_backend["frame_names"] = frame_names - self._for_backend["interpolated_frames"] = interpolated_frames - - return super().get_figure(backend, **kwargs) - - -class SubPlots(MultiplePlot): - """ Version of MultiplePlot that renders each plot in a separate subplot - - Parameters - ----------- - arrange: optional - The way in which subplots should be aranged if the `rows` and/or - `cols` parameters are not provided. - rows: int, optional - The number of rows of the plot grid. If not provided, it will be - inferred from `cols` and the number of plots. If neither - `cols` or `rows` are provided, the `arrange` parameter will decide - how the layout should look like. - cols: int, optional - The number of columns of the subplot grid. 
If not provided, it will - be inferred from `rows` and the number of plots. If - neither `cols` or `rows` are provided, the `arrange` parameter will - decide how the layout should look like. - make_subplots_kwargs: dict, optional - Extra keyword arguments that will be passed to make_subplots. - root_fdf: fdfSileSiesta, optional - Path to the fdf file that is the 'parent' of the results. - results_path: str, optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - entry_points_order: array-like, optional - Order with which entry points will be attempted. - backend: optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - """ - - _trigger_kw = "subplots" - - _is_subplots = True - - _parameters = ( - - OptionsInput(key='arrange', name='Automatic arrangement method', - default='rows', - params={ - 'options': [ - {'value': option, 'label': option} for option in ('rows', 'cols', 'square') - ], - 'placeholder': 'Choose a subplot arrangement method...', - 'isMulti': False, - 'isSearchable': True, - 'isClearable': True, - }, - help="""The way in which subplots should be aranged if the `rows` and/or `cols` - parameters are not provided.""" - ), - - IntegerInput(key='rows', name='Rows', - default=None, - help="""The number of rows of the plot grid. If not provided, it will be inferred from `cols` - and the number of plots. If neither `cols` or `rows` are provided, the `arrange` parameter will decide - how the layout should look like.""" - ), - - IntegerInput(key='cols', name='Columns', - default=None, - help="""The number of columns of the subplot grid. If not provided, it will be inferred from `rows` - and the number of plots. 
If neither `cols` or `rows` are provided, the `arrange` parameter will decide - how the layout should look like.""" - ), - - ProgramaticInput(key='make_subplots_kwargs', name='make_subplot additional arguments', - dtype=dict, - default={}, - help="""Extra keyword arguments that will be passed to make_subplots.""" - ) - ) - - def get_figure(self, backend, rows, cols, arrange, make_subplots_kwargs, **kwargs): - """ Builds the subplots layout from the child plots' data """ - nplots = len(self.children) - if rows is None and cols is None: - if arrange == 'rows': - rows = nplots - cols = 1 - elif arrange == 'cols': - cols = nplots - rows = 1 - elif arrange == 'square': - cols = nplots ** 0.5 - rows = nplots ** 0.5 - # we will correct so it *fits*, always have more columns - rows, cols = int(rows), int(cols) - cols = nplots // rows + min(1, nplots % rows) - elif rows is None: - # ensure it is large enough by adding 1 if they don't add up - rows = nplots // cols + min(1, nplots % cols) - elif cols is None: - # ensure it is large enough by adding 1 if they don't add up - cols = nplots // rows + min(1, nplots % rows) - - rows, cols = int(rows), int(cols) - - if cols * rows < nplots: - warn(f"requested {nplots} on a {rows}x{cols} grid layout. {nplots - cols*rows} plots will be missing.") - - self._for_backend = { - "rows": rows, - "cols": cols, - "make_subplots_kwargs": make_subplots_kwargs, - } - - super().get_figure(backend, **kwargs) + return super().__getattr__(key) + + def merge(self, *others, **kwargs): + from .plots.merged import merge_plots + return merge_plots(self, *others, **kwargs) + + def update_settings(self, *args, **kwargs): + deprecate(f"{self.__class__.__name__}.update_settings is deprecated. 
Please use update_inputs.", "0.15") + return self.update_inputs(*args, **kwargs) + + @classmethod + def plot_class_key(cls) -> str: + return cls.__name__.replace("Plot", "").lower() diff --git a/src/sisl/viz/plots/__init__.py b/src/sisl/viz/plots/__init__.py index 9cab6ae8ca..6c74428662 100644 --- a/src/sisl/viz/plots/__init__.py +++ b/src/sisl/viz/plots/__init__.py @@ -1,9 +1,9 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from .bands import BandsPlot -from .bond_length import BondLengthMap -from .fatbands import FatbandsPlot -from .geometry import GeometryPlot -from .grid import GridPlot, WavefunctionPlot -from .pdos import PdosPlot +""" +Module containing all sisl-provided plots, both in a functional form and as Workflows. +""" + +from .bands import BandsPlot, FatbandsPlot, bands_plot, fatbands_plot +from .geometry import GeometryPlot, SitesPlot, geometry_plot, sites_plot +from .grid import GridPlot, WavefunctionPlot, grid_plot, wavefunction_plot +from .merged import merge_plots +from .pdos import PdosPlot, pdos_plot diff --git a/src/sisl/viz/plots/bands.py b/src/sisl/viz/plots/bands.py index 8918f092a3..144d4dea4b 100644 --- a/src/sisl/viz/plots/bands.py +++ b/src/sisl/viz/plots/bands.py @@ -1,1023 +1,244 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-import itertools -from collections import defaultdict -from functools import partial +from typing import Dict, Literal, Optional, Sequence, Tuple import numpy as np -try: - import xarray as xr -except ModuleNotFoundError: - pass - -import sisl -from sisl.physics.brillouinzone import BrillouinZone -from sisl.physics.spin import Spin - -from ..input_fields import ( - AiidaNodeInput, - BandStructureInput, - BoolInput, - ColorInput, - ErangeInput, - FloatInput, - FunctionInput, - QueriesInput, - RangeSliderInput, - SileInput, - SpinSelect, - TextInput, -) -from ..plot import Plot, entry_point -from ..plotutils import find_files - -try: - import pathos - _do_parallel_calc = True -except Exception: - _do_parallel_calc = False - - -class BandsPlot(Plot): - """ - Plot representation of the bands. +from sisl.viz.types import OrbitalQueries, StyleSpec + +from ..data.bands import BandsData +from ..figure import Figure, get_figure +from ..plot import Plot +from ..plotters.plot_actions import combined +from ..plotters.xarray import draw_xarray_xy +from ..plotutils import random_color +from ..processors.bands import calculate_gap, draw_gaps, filter_bands, style_bands +from ..processors.data import accept_data +from ..processors.logic import matches +from ..processors.orbital import get_orbital_queries_manager, reduce_orbital_data +from ..processors.xarray import scale_variable +from .orbital_groups_plot import OrbitalGroupsPlot + + +def bands_plot(bands_data: BandsData, + Erange: Optional[Tuple[float, float]] = None, E0: float = 0., E_axis: Literal["x", "y"] = "y", + bands_range: Optional[Tuple[int, int]] = None, spin: Optional[Literal[0, 1]] = None, + bands_style: StyleSpec = {'color': 'black', 'width': 1, "opacity": 1}, + spindown_style: StyleSpec = {"color": "blue", "width": 1}, + colorscale: Optional[str] = None, + gap: bool = False, gap_tol: float = 0.01, gap_color: str = "red", gap_marker: dict = {"size": 7}, direct_gaps_only: bool = False, + custom_gaps: Sequence[Dict] = 
[], + line_mode: Literal["line", "scatter", "area_line"] = "line", + backend: str = "plotly" +) -> Figure: + """Plots band structure energies, with plenty of customization options. Parameters - ------------- - bands_file: bandsSileSiesta, optional - This parameter explicitly sets a .bands file. Otherwise, the bands - file is attempted to read from the fdf file - band_structure: BandStructure, optional - A band structure. it can either be provided as a sisl.BandStructure - object or as a list of points, which will be parsed into a - band structure object. Each item is a dict. Structure - of the dict: { 'x': 'y': 'z': - 'divisions': 'name': Tick that should be displayed at this - corner of the path. } - wfsx_file: wfsxSileSiesta, optional - The WFSX file to get the eigenstates. In standard SIESTA - nomenclature, this should probably be the *.bands.WFSX file, as it is - the one that contains the eigenstates for the band - structure. - aiida_bands: optional - An aiida BandsData node. - add_band_data: optional - This function receives each band and should return a dictionary with - additional arguments that are passed to the band drawing - routine. It also receives the plot as the second argument. - See the docs of `sisl.viz.backends.templates.Backend.draw_line` to - understand what are the supported arguments to be - returned. Notice that the arguments that the backend is able to - process can be very framework dependant. - Erange: array-like of shape (2,), optional - Energy range where the bands are displayed. - E0: float, optional - The energy to which all energies will be referenced (including - Erange). - bands_range: array-like of shape (2,), optional - The bands that should be displayed. Only relevant if Erange is None. - spin: optional - Determines how the different spin configurations should be displayed. - In spin polarized calculations, it allows you to choose between spin - 0 and 1. 
In non-colinear spin calculations, it allows you - to ask for a given spin texture, by specifying the - direction. - spin_texture_colorscale: str, optional - The plotly colorscale to use for the spin texture (if displayed) - gap: bool, optional - Whether the gap should be displayed in the plot - direct_gaps_only: bool, optional - Whether to show only gaps that are direct, according to the gap - tolerance - gap_tol: float, optional - The difference in k that must exist to consider to gaps - different. If two gaps' positions differ in less than - this, only one gap will be drawn. Useful in cases - where there are degenerated bands with exactly the same values. - gap_color: str, optional - Color to display the gap - custom_gaps: array-like of dict, optional - List of all the gaps that you want to display. Each item is a dict. - Structure of the dict: { 'from': K value where to start - measuring the gap. It can be either the label of - the k-point or the numeric value in the plot. 'to': K value - where to end measuring the gap. It can be either - the label of the k-point or the numeric value in the plot. - 'color': The color with which the gap should be displayed - 'spin': The spin components where the gap should be calculated. } - bands_width: float, optional - Width of the lines that represent the bands - bands_color: str, optional - Choose the color to display the bands. This will be used for the - spin up bands if the calculation is spin polarized - spindown_color: str, optional - Choose the color for the spin down bands.Only used if the - calculation is spin polarized. - root_fdf: fdfSileSiesta, optional - Path to the fdf file that is the 'parent' of the results. - results_path: str, optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - entry_points_order: array-like, optional - Order with which entry points will be attempted. 
- backend: optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. + ---------- + bands_data: + The object containing the data to plot. + Erange: + The energy range to plot. + If None, the range is determined by ``bands_range``. + E0: + The energy reference. + E_axis: + Axis to plot the energies. + bands_range: + The bands to plot. Only used if ``Erange`` is None. + If None, the 15 bands above and below the Fermi level are plotted. + spin: + Which spin channel to display. Only meaningful for spin-polarized calculations. + If None and the calculation is spin polarized, both are plotted. + bands_style: + Styling attributes for bands. + spindown_style: + Styling attributes for the spin down bands (if present). Any missing attribute + will be taken from ``bands_style``. + colorscale: + Colorscale to use for the bands in case the color attribute is an array of values. + If None, the default colorscale is used for each backend. + gap: + Whether to display the gap. + gap_tol: + Tolerance in k for determining whether two gaps are the same. + gap_color: + Color of the gap. + gap_marker: + Marker styles for the gap (as plotly marker's styles). + direct_gaps_only: + Whether to only display direct gaps. + custom_gaps: + List of custom gaps to display. See the showcase notebooks for examples. + line_mode: + The method used to draw the band lines. + backend: + The backend to use to generate the figure. """ - _plot_type = "Bands" - - _parameters = ( - - SileInput(key = "bands_file", name = "Path to bands file", - dtype=sisl.io.siesta.bandsSileSiesta, - group="dataread", - params = { - "placeholder": "Write the path to your bands file here...", - }, - help = """This parameter explicitly sets a .bands file. 
Otherwise, the bands file is attempted to read from the fdf file """ - ), - - BandStructureInput(key="band_structure", name="Band structure"), - - SileInput(key='wfsx_file', name='Path to WFSX file', - dtype=sisl.io.siesta.wfsxSileSiesta, - default=None, - help="""The WFSX file to get the eigenstates. - In standard SIESTA nomenclature, this should probably be the *.bands.WFSX file, as it is the one - that contains the eigenstates for the band structure. - """ - ), - - AiidaNodeInput(key="aiida_bands", name="Aiida BandsData node", - default=None, - help="""An aiida BandsData node.""" - ), - - FunctionInput(key="add_band_data", name="Add band data function", - default=lambda band, plot: {}, - positional=["band", "plot"], - returns=["band_data"], - help="""This function receives each band and should return a dictionary with additional arguments - that are passed to the band drawing routine. It also receives the plot as the second argument. - See the docs of `sisl.viz.backends.templates.Backend.draw_line` to understand what are the supported arguments - to be returned. Notice that the arguments that the backend is able to process can be very framework dependant. - """ - ), - - ErangeInput(key="Erange", - help = "Energy range where the bands are displayed." - ), - - FloatInput(key="E0", name="Reference energy", - default=0, - help="""The energy to which all energies will be referenced (including Erange).""" - ), - - RangeSliderInput(key = "bands_range", name = "Bands range", - default = None, - params = { - 'step': 1, - }, - help = "The bands that should be displayed. Only relevant if Erange is None." - ), - - SpinSelect(key="spin", name="Spin", - default=None, - help="""Determines how the different spin configurations should be displayed. - In spin polarized calculations, it allows you to choose between spin 0 and 1. 
- In non-colinear spin calculations, it allows you to ask for a given spin texture, - by specifying the direction.""" - ), - - TextInput(key="spin_texture_colorscale", name="Spin texture colorscale", - default=None, - help="The plotly colorscale to use for the spin texture (if displayed)" - ), - - BoolInput(key="gap", name="Show gap", - default=False, - params={ - 'onLabel': 'Yes', - 'offLabel': 'No' - }, - help="Whether the gap should be displayed in the plot" - ), - - BoolInput(key="direct_gaps_only", name="Only direct gaps", - default=False, - params={ - 'onLabel': 'Yes', - 'offLabel': 'No' - }, - help="Whether to show only gaps that are direct, according to the gap tolerance" - ), - - FloatInput(key="gap_tol", name="Gap tolerance", - default=0.01, - params={ - 'step': 0.001 - }, - help="""The difference in k that must exist to consider to gaps different.
- If two gaps' positions differ in less than this, only one gap will be drawn.
- Useful in cases where there are degenerated bands with exactly the same values.""" - ), - - ColorInput(key="gap_color", name="Gap color", - default=None, - help="Color to display the gap" - ), - - QueriesInput(key="custom_gaps", name="Custom gaps", - default=[], - help="""List of all the gaps that you want to display.""", - queryForm=[ - - TextInput( - key="from", name="From", - help="""K value where to start measuring the gap. - It can be either the label of the k-point or the numeric value in the plot.""", - default="0", - ), - - TextInput( - key="to", name="To", - help="""K value where to end measuring the gap. - It can be either the label of the k-point or the numeric value in the plot.""", - default="0", - ), - - ColorInput( - key="color", name="Line color", - help="The color with which the gap should be displayed", - default=None, - ), - - SpinSelect( - key="spin", name="Spin", - help="The spin components where the gap should be calculated.", - default=None, - only_if_polarized=True, - ), - - ] - ), - - FloatInput(key="bands_width", name="Band lines width", - default=1, - help="Width of the lines that represent the bands" - ), - - ColorInput(key = "bands_color", name = "No spin/spin up line color", - default = "black", - help = "Choose the color to display the bands.
This will be used for the spin up bands if the calculation is spin polarized" - ), - - ColorInput(key = "spindown_color", name = "Spin down line color", - default = "blue", - help = "Choose the color for the spin down bands.
Only used if the calculation is spin polarized." - ), + bands_data = accept_data(bands_data, cls=BandsData, check=True) - ) + # Filter the bands + filtered_bands = filter_bands(bands_data, Erange=Erange, E0=E0, bands_range=bands_range, spin=spin) - _update_methods = { - "read_data": [], - "set_data": ["_draw_gaps"], - "get_figure": [] - } - - @classmethod - def _default_animation(cls, wdir=None, frame_names=None, **kwargs): - """ - Defines the default animation, which is to look for all .bands files in wdir. - """ - bands_files = find_files(wdir, "*.bands", sort = True) - - def _get_frame_names(self): - - return [childPlot.get_setting("bands_file").name for childPlot in self.child_plots] - - return cls.animated("bands_file", bands_files, frame_names = _get_frame_names, wdir = wdir, **kwargs) - - @property - def bands(self): - return self.bands_data["E"] - - @property - def spin_moments(self): - return self.bands_data["spin_moments"] - - def _after_init(self): - self.spin = sisl.Spin("") - - self.add_shortcut("g", "Toggle gap", self.toggle_gap) - - @entry_point('bands file', 0) - def _read_siesta_output(self, bands_file, band_structure): - """ - Reads the bands information from a SIESTA bands file. - """ - if band_structure: - raise ValueError("A path was provided, therefore we can not use the .bands file even if there is one") - - self.bands_data = self.get_sile(bands_file or "bands_file").read_data(as_dataarray=True) - - # Define the spin class of the results we have retrieved - if len(self.bands_data.spin.values) == 2: - self.spin = sisl.Spin("p") - - @entry_point('aiida bands', 1) - def _read_aiida_bands(self, aiida_bands): - """ - Creates the bands plot reading from an aiida BandsData node. 
- """ - plot_data = aiida_bands._get_bandplot_data(cartesian=True) - bands = plot_data["y"] - - # Expand the bands array to have an extra dimension for spin - if bands.ndim == 2: - bands = np.expand_dims(bands, 0) - - # Get the info about where to put the labels - tick_info = defaultdict(list) - for tick, label in plot_data["labels"]: - tick_info["ticks"].append(tick) - tick_info["ticklabels"].append(label) - - # Construct the dataarray - self.bands_data = xr.DataArray( - bands, - coords={ - "spin": np.arange(0, bands.shape[0]), - "k": plot_data["x"], - "band": np.arange(0, bands.shape[2]), - }, - dims=("spin", "k", "band"), - attrs={**tick_info} - ) - - def _get_eigenstate_wrapper(self, k_vals, extra_vars=(), spin_moments=True): - """Helper function to build the function to call on each eigenstate. - - Parameters - ---------- - k_vals: array_like of shape (nk,) - The (linear) values of the k points. This will be used for plotting - the bands. - extra_vars: array-like of dict, optional - This argument determines the extra quantities that should be included - in the final dataset of the bands. Energy and spin moments (if available) - are already included, so no need to pass them here. - Each item of the array defines a new quantity and should contain a dictionary - with the following keys: - - 'name', str: The name of the quantity. - - 'getter', callable: A function that gets 3 arguments: eigenstate, plot and - spin index, and returns the values of the quantity in a numpy array. This - function will be called for each eigenstate object separately. That is, once - for each (k-point, spin) combination. - - 'coords', tuple of str: The names of the dimensions of the returned array. - The number of coordinates should match the number of dimensions. - of - - 'coords_values', dict: If this variable introduces a new coordinate, you should - pass the values for that coordinate here. If the coordinates were already defined - by another variable, they will already have values. 
If you are unsure that the - coordinates are new, just pass the values for them, they will get overwritten. - spin_moments: bool, optional - Whether to add, if the spin is not diagonal, spin moments. - - Returns - -------- - function: - The function that should be called for each eigenstate and will return a tuple of size - n_vars with the values for each variable. - tuple of dicts: - A tuple containing the dictionaries that define all variables. Exactly the same as - the passed `extra_vars`, but with the added Energy and spin moment (if available) variables. - dict: - Dictionary containing the values for each coordinate involved in the dataset. - """ - # In case it is a non_colinear or spin-orbit calculation we will get the spin moments - if spin_moments and not self.spin.is_diagonal: - def _spin_moment_getter(eigenstate, plot, spin): - return eigenstate.spin_moment().real - - extra_vars = ({ - "coords": ("axis", "band"), "coords_values": dict(axis=["x", "y", "z"]), - "name": "spin_moments", "getter": _spin_moment_getter}, - *extra_vars) - - # Define the available spin indices. Notice that at the end the spin dimension - # is removed from the dataset unless the calculation is spin polarized. So having - # spin_indices = [0] is just for convenience. - spin_indices = [0] - if self.spin.is_polarized: - spin_indices = [0, 1] - - # Add a variable to get the eigenvalues. - all_vars = ({ - "coords": ("band",), "coords_values": {"spin": spin_indices, "k": k_vals}, - "name": "E", "getter": lambda eigenstate, self, spin: eigenstate.eig}, - *extra_vars - ) - - # Now build the function that will be called for each eigenstate and will - # return the values for each variable. - def bands_wrapper(eigenstate, spin_index): - return tuple(var["getter"](eigenstate, self, spin_index) for var in all_vars) - - # Finally get the values for all coordinates involved. 
- coords_values = {} - for var in all_vars: - coords_values.update(var.get("coords_values", {})) - - return bands_wrapper, all_vars, coords_values - - @entry_point('wfsx file', 2) - def _read_from_wfsx(self, root_fdf, wfsx_file, extra_vars=(), need_H=False): - """Plots bands from the eigenvalues contained in a WFSX file. - - It also needs to get a geometry. - """ - if need_H: - self.setup_hamiltonian() - if self.H is None: - raise ValueError("Hamiltonian was not setup, and it is needed for the calculations") - parent = self.H - self.geometry = parent.geometry - else: - # Get the fdf sile - fdf = self.get_sile(root_fdf or "root_fdf") - # Read the geometry from the fdf sile - self.geometry = fdf.read_geometry(output=True) - parent = self.geometry - - # Get the wfsx file - wfsx_sile = self.get_sile(wfsx_file or "wfsx_file", parent=parent) - - # Now read all the information of the k points from the WFSX file - k, weights, nwfs = wfsx_sile.read_info() - # Get the number of wavefunctions in the file while performing a quick check - nwf = np.unique(nwfs) - if len(nwf) > 1: - raise ValueError(f"File {wfsx_sile.file} contains different number of wavefunctions in some k points") - nwf = nwf[0] - # From the k values read in the file, build a brillouin zone object. - # We will use it just to get the linear k values for plotting. - bz = BrillouinZone(self.geometry, k=k, weight=weights) - - # Read the sizes of the file, which contain the number of spin channels - # and the number of orbitals and the number of k points. - nspin, nou, nk, _ = wfsx_sile.read_sizes() - - # Find out the spin class of the calculation. - self.spin = Spin({ - 1: Spin.UNPOLARIZED, 2: Spin.POLARIZED, - 4: Spin.NONCOLINEAR, 8: Spin.SPINORBIT - }[nspin]) - # Now find out how many spin channels we need. Note that if there is only - # one spin channel there will be no "spin" dimension on the final dataset. - nspin = 2 if self.spin.is_polarized else 1 - - # Determine whether spin moments will be calculated. 
- spin_moments = False - if not self.spin.is_diagonal: - # We need to set the parent - self.setup_hamiltonian() - if self.H is not None: - # We could read a hamiltonian, set it as the parent of the wfsx sile - wfsx_sile = sisl.get_sile(wfsx_sile.file, parent=self.H) - spin_moments = True - - # Get the wrapper function that we should call on each eigenstate. - # This also returns the coordinates and names to build the final dataset. - bands_wrapper, all_vars, coords_values = self._get_eigenstate_wrapper( - sisl.physics.linspace_bz(bz), extra_vars=extra_vars, - spin_moments=spin_moments - ) - # Make sure all coordinates have values so that we can assume the shape - # of arrays below. - coords_values['band'] = np.arange(0, nwf) - coords_values['orb'] = np.arange(0, nou) - - self.ticks = None - - # Initialize all the arrays. For each quantity we will initialize - # an array of the needed shape. - arrays = {} - for var in all_vars: - # These are all the extra dimensions of the quantity. Note that a - # quantity does not need to have extra dimensions. - extra_shape = [len(coords_values[coord]) for coord in var['coords']] - # First two dimensions will always be the spin channel and the k index. - # Then add potential extra dimensions. - shape = (nspin, len(bz), *extra_shape) - # Initialize the array. - arrays[var['name']] = np.empty(shape, dtype=var.get('dtype', np.float64)) - - # Loop through eigenstates in the WFSX file and add their contribution to the bands - ik = -1 - for eigenstate in wfsx_sile.yield_eigenstate(): - spin = eigenstate.info.get("spin", 0) - # Every time we encounter spin 0, we are in a new k point. - if spin == 0: - ik +=1 - if ik == 0: - # If this is the first eigenstate we read, get the wavefunction - # indices. We will assume that ALL EIGENSTATES have the same indices. - # Note that we already checked previously that they all have the same - # number of wfs, so this is a fair assumption. 
- coords_values['band'] = eigenstate.info['index'] - - # Get all the values for this eigenstate. - returns = bands_wrapper(eigenstate, spin_index=spin) - # And store them in the respective arrays. - for var, vals in zip(all_vars, returns): - arrays[var['name']][spin, ik] = vals - - # Now that we have all the values, just build the dataset. - self.bands_data = xr.Dataset( - data_vars={ - var['name']: (("spin", "k", *var['coords']), arrays[var['name']]) - for var in all_vars - } - ).assign_coords(coords_values) - - self.bands_data.attrs = {"ticks": None, "ticklabels": None, "parent": bz} - - @entry_point('band structure', 3) - def _read_from_H(self, band_structure, extra_vars=()): - """ - Uses a sisl's `BandStructure` object to calculate the bands. - """ - if band_structure is None: - raise ValueError("No band structure (k points path) was provided") - - if not isinstance(getattr(band_structure, "parent", None), sisl.Hamiltonian): - self.setup_hamiltonian() - band_structure.set_parent(self.H) - else: - self.H = band_structure.parent - - # Define the spin class of this calculation. - self.spin = self.H.spin - - self.ticks = band_structure.lineartick() - - # Get the wrapper function that we should call on each eigenstate. - # This also returns the coordinates and names to build the final dataset. 
- bands_wrapper, all_vars, coords_values= self._get_eigenstate_wrapper( - band_structure.lineark(), extra_vars=extra_vars - ) - - # Get a dataset with all values for all spin indices - spin_datasets = [] - coords = [var['coords'] for var in all_vars] - name = [var['name'] for var in all_vars] - for spin_index in coords_values['spin']: - - # Non collinear routines don't accept the keyword argument "spin" - spin_kwarg = {"spin": spin_index} - if not self.spin.is_diagonal: - spin_kwarg = {} - - with band_structure.apply(pool=_do_parallel_calc, zip=True) as parallel: - spin_bands = parallel.dataarray.eigenstate( - wrap=partial(bands_wrapper, spin_index=spin_index), - **spin_kwarg, - coords=coords, name=name, - ) - - spin_datasets.append(spin_bands) - - # Merge everything into a single dataset with a spin dimension - self.bands_data = xr.concat(spin_datasets, "spin").assign_coords(coords_values) - - # If the band structure contains discontinuities, we will copy the dataset - # adding the discontinuities. - if len(band_structure._jump_idx) > 0: - - old_coords = self.bands_data.coords - coords = { - name: band_structure.insert_jump(old_coords[name]) if name == "k" else old_coords[name].values - for name in old_coords - } - - def _add_jump(array): - if "k" in array.coords: - array = array.transpose("k", ...) 
- return (array.dims, band_structure.insert_jump(array)) - else: - return array - - self.bands_data = xr.Dataset( - {name: _add_jump(self.bands_data[name]) for name in self.bands_data}, - coords=coords - ) - - # Inform of where to place the ticks - self.bands_data.attrs = {"ticks": self.ticks[0], "ticklabels": self.ticks[1], **spin_datasets[0].attrs} - - def _after_read(self): - if isinstance(self.bands_data, xr.DataArray): - attrs = self.bands_data.attrs - self.bands_data = xr.Dataset({"E": self.bands_data}) - self.bands_data.attrs = attrs - - # If the calculation is not spin polarized it makes no sense to - # retain a spin index - if "spin" in self.bands_data and not self.spin.is_polarized: - self.bands_data = self.bands_data.sel(spin=self.bands_data.spin[0], drop=True) - - # Inform the spin input of what spin class are we handling - self.get_param("spin").update_options(self.spin) - self.get_param("custom_gaps").get_param("spin").update_options(self.spin) - - # Make sure that the bands_range control knows which bands are available - i_bands = self.bands.band.values - - if len(i_bands) > 30: - i_bands = i_bands[np.linspace(0, len(i_bands)-1, 20, dtype=int)] - - self.modify_param('bands_range', 'inputField.params', { - **self.get_param('bands_range')["inputField"]["params"], - "min": min(i_bands), - "max": max(i_bands), - "allowCross": False, - "marks": {int(i): str(i) for i in i_bands}, - }) - - def _set_data(self, Erange, E0, bands_range, spin, spin_texture_colorscale, bands_width, bands_color, spindown_color, - gap, gap_tol, gap_color, direct_gaps_only, custom_gaps): - # Calculate all the gaps of this band structure - self._calculate_gaps(E0) - - # Shift all the bands to the reference - filtered_bands = self.bands - E0 - continous_bands = filtered_bands.dropna("k", how="all") - - # Get the bands that matter for the plot - if Erange is None: - - if bands_range is None: - # If neither E range or bands_range was provided, we will just plot the 15 bands below and 
above the fermi level - CB = int(continous_bands.where(continous_bands <= 0).argmax('band').max()) - bands_range = [int(max(continous_bands["band"].min(), CB - 15)), int(min(continous_bands["band"].max() + 1, CB + 16))] - - i_bands = np.arange(*bands_range) - filtered_bands = filtered_bands.where(filtered_bands.band.isin(i_bands), drop=True) - continous_bands = filtered_bands.dropna("k", how="all") - self.update_settings( - run_updates=False, - Erange=np.array([float(f'{val:.3f}') for val in [float(continous_bands.min() - 0.01), float(continous_bands.max() + 0.01)]]), - bands_range=bands_range, no_log=True) - else: - Erange = np.array(Erange) - filtered_bands = filtered_bands.where((filtered_bands <= Erange[1]) & (filtered_bands >= Erange[0])).dropna("band", "all") - continous_bands = filtered_bands.dropna("k", how="all") - self.update_settings(run_updates=False, bands_range=[int(continous_bands['band'].min()), int(continous_bands['band'].max())], no_log=True) - - # Give the filtered bands the same attributes as the full bands - filtered_bands.attrs = self.bands_data.attrs - - # Let's treat the spin if the user requested it - self.spin_texture = False - if spin is not None and len(spin) > 0: - if isinstance(spin[0], int): - # Only use the spin setting if there is a spin index - if "spin" in filtered_bands.coords: - filtered_bands = filtered_bands.sel(spin=spin) - elif isinstance(spin[0], str): - if "spin_moments" not in self.bands_data: - raise ValueError(f"You requested spin texture ({spin[0]}), but spin moments have not been calculated. 
The spin class is {self.spin.kind}") - self.spin_texture = True - - if self.spin_texture: - spin_moments = self.spin_moments.sel(band=filtered_bands.band, axis=spin[0]) - else: - spin_moments = [] - - return { - "draw_bands": { - "filtered_bands": filtered_bands, - "line": {"color": bands_color, "width": bands_width}, - "spindown_line": {"color": spindown_color}, - "spin": self.spin, - "spin_texture": {"show": self.spin_texture, "values": spin_moments, "colorscale": spin_texture_colorscale}, - }, - "gaps": self._get_gaps(gap, gap_tol, gap_color, direct_gaps_only, custom_gaps) - } + # Add the styles + styled_bands = style_bands(filtered_bands, bands_style=bands_style, spindown_style=spindown_style) - def get_figure(self, backend, add_band_data, **kwargs): - self._for_backend["draw_bands"]["add_band_data"] = add_band_data - return super().get_figure(backend, **kwargs) - - def _calculate_gaps(self, E0): - """ - Calculates the gap (or gaps) assuming 0 is the fermi level. - - It creates the attributes `gap` and `gap_info` - """ - # Calculate the band gap to store it - shifted_bands = self.bands - E0 - above_fermi = self.bands.where(shifted_bands > 0) - below_fermi = self.bands.where(shifted_bands < 0) - CBbot = above_fermi.min() - VBtop = below_fermi.max() - - CB = above_fermi.where(above_fermi==CBbot, drop=True).squeeze() - VB = below_fermi.where(below_fermi==VBtop, drop=True).squeeze() - - self.gap = float(CBbot - VBtop) - - self.gap_info = { - 'k': (VB["k"].values, CB['k'].values), - 'bands': (VB["band"].values, CB["band"].values), - 'spin': (VB["spin"].values, CB["spin"].values) if self.spin.is_polarized else (0, 0), - 'Es': [float(VBtop), float(CBbot)] - } + # Determine what goes on each axis + x = matches(E_axis, "x", ret_true="E", ret_false="k") + y = matches(E_axis, "y", ret_true="E", ret_false="k") + + # Get the actions to plot lines + bands_plottings = draw_xarray_xy(data=styled_bands, x=x, y=y, set_axrange=True, what=line_mode, colorscale=colorscale, 
dependent_axis=E_axis) - def _get_gaps(self, gap, gap_tol, gap_color, direct_gaps_only, custom_gaps): - """ - Draws the calculated gaps and the custom gaps in the plot - """ - gaps_to_draw = [] + # Gap calculation + gap_info = calculate_gap(filtered_bands) + # Plot it if the user has asked for it. + gaps_plottings = draw_gaps(bands_data, gap, gap_info, gap_tol, gap_color, gap_marker, direct_gaps_only, custom_gaps, E_axis=E_axis) - # Draw gaps - if gap: + all_plottings = combined(bands_plottings, gaps_plottings, composite_method=None) - gapKs = [np.atleast_1d(k) for k in self.gap_info['k']] + return get_figure(backend=backend, plot_actions=all_plottings) - # Remove "equivalent" gaps - def clear_equivalent(ks): - if len(ks) == 1: - return ks +def _default_random_color(x): + return x.get("color") or random_color() - uniq = [ks[0]] - for k in ks[1:]: - if abs(min(np.array(uniq) - k)) > gap_tol: - uniq.append(k) - return uniq - all_gapKs = itertools.product(*[clear_equivalent(ks) for ks in gapKs]) +def _group_traces(actions): - for gap_ks in all_gapKs: + seen_groups = [] - if direct_gaps_only and abs(gap_ks[1] - gap_ks[0]) > gap_tol: - continue + new_actions = [] + for action in actions: + if action["method"].startswith("draw_"): + group = action["kwargs"].get("name") + action = action.copy() + action['kwargs']['legendgroup'] = group - ks, Es = self._get_gap_coords(*gap_ks, color=gap_color) - name = "Gap" + if group in seen_groups: + action["kwargs"]["showlegend"] = False + else: + seen_groups.append(group) + + new_actions.append(action) + + return new_actions + + +# I keep the fatbands plot here so that one can see how similar they are. +# I am yet to find a nice solution for extending workflows. 
+def fatbands_plot(bands_data: BandsData, + Erange: Optional[Tuple[float, float]] = None, E0: float = 0., E_axis: Literal["x", "y"] = "y", + bands_range: Optional[Tuple[int, int]] = None, spin: Optional[Literal[0, 1]] = None, + bands_style: StyleSpec = {'color': 'black', 'width': 1, "opacity": 1}, + spindown_style: StyleSpec = {"color": "blue", "width": 1}, + gap: bool = False, gap_tol: float = 0.01, gap_color: str = "red", gap_marker: dict = {"size": 7}, direct_gaps_only: bool = False, + custom_gaps: Sequence[Dict] = [], + bands_mode: Literal["line", "scatter", "area_line"] = "line", + # Fatbands inputs + groups: OrbitalQueries = [], + fatbands_var: str = "norm2", + fatbands_mode: Literal["line", "scatter", "area_line"] = "area_line", + fatbands_scale: float = 1., + backend: str = "plotly" +) -> Figure: + """Plots band structure energies showing the contribution of orbitals to each state. - gaps_to_draw.append({"ks": ks, "Es": Es, "color": gap_color, "name": name}) + Parameters + ---------- + bands_data: + The object containing the data to plot. + Erange: + The energy range to plot. + If None, the range is determined by ``bands_range``. + E0: + The energy reference. + E_axis: + Axis to plot the energies. + bands_range: + The bands to plot. Only used if ``Erange`` is None. + If None, the 15 bands above and below the Fermi level are plotted. + spin: + Which spin channel to display. Only meaningful for spin-polarized calculations. + If None and the calculation is spin polarized, both are plotted. + bands_style: + Styling attributes for bands. + spindown_style: + Styling attributes for the spin down bands (if present). Any missing attribute + will be taken from ``bands_style``. + gap: + Whether to display the gap. + gap_tol: + Tolerance in k for determining whether two gaps are the same. + gap_color: + Color of the gap. + gap_marker: + Marker styles for the gap (as plotly marker's styles). + direct_gaps_only: + Whether to only display direct gaps. 
+ custom_gaps: + List of custom gaps to display. See the showcase notebooks for examples. + bands_mode: + The method used to draw the band lines. + groups: + Orbital groups to plots. See showcase notebook for examples. + fatbands_var: + The variable to use from bands_data to determine the width of the fatbands. + This variable must have as coordinates (k, band, orb, [spin]). + fatbands_mode: + The method used to draw the fatbands. + fatbands_scale: + Factor that scales the size of all fatbands. + backend: + The backend to use to generate the figure. + """ + bands_data = accept_data(bands_data, cls=BandsData, check=True) - # Draw the custom gaps. These are gaps that do not necessarily represent - # the maximum and the minimum of the VB and CB. - for custom_gap in custom_gaps: + # Filter the bands + filtered_bands = filter_bands(bands_data, Erange=Erange, E0=E0, bands_range=bands_range, spin=spin) - requested_spin = custom_gap.get("spin", None) - if requested_spin is None: - requested_spin = [0, 1] + # Add the styles + styled_bands = style_bands(filtered_bands, bands_style=bands_style, spindown_style=spindown_style) - avail_spins = self.bands_data.get("spin", [0]) + # Process fatbands + orbital_manager = get_orbital_queries_manager( + bands_data, + key_gens={ + "color": _default_random_color, + } + ) + fatbands_data = reduce_orbital_data( + filtered_bands, groups=groups, orb_dim="orb", spin_dim="spin", sanitize_group=orbital_manager, + group_vars=('color', 'dash'), groups_dim="group", drop_empty=True, + spin_reduce=np.sum, + ) + scaled_fatbands_data = scale_variable(fatbands_data, var=fatbands_var, scale=fatbands_scale, default_value=1, allow_not_present=True) + + # Determine what goes on each axis + x = matches(E_axis, "x", ret_true="E", ret_false="k") + y = matches(E_axis, "y", ret_true="E", ret_false="k") + + sanitized_fatbands_mode = matches(groups, [], ret_true="none", ret_false=fatbands_mode) + + # Get the actions to plot lines + fatbands_plottings = 
draw_xarray_xy( + data=scaled_fatbands_data, x=x, y=y, color="color", width=fatbands_var, what=sanitized_fatbands_mode, dependent_axis=E_axis, + name="group" + ) + grouped_fatbands_plottings = _group_traces(fatbands_plottings) + bands_plottings = draw_xarray_xy(data=styled_bands, x=x, y=y, set_axrange=True, what=bands_mode, dependent_axis=E_axis) - for spin in avail_spins: - if spin in requested_spin: - from_k = custom_gap["from"] - to_k = custom_gap["to"] - color = custom_gap.get("color", None) - name = f"Gap ({from_k}-{to_k})" - ks, Es = self._get_gap_coords(from_k, to_k, color=color, gap_spin=spin) + # Gap calculation + gap_info = calculate_gap(filtered_bands) + # Plot it if the user has asked for it. + gaps_plottings = draw_gaps(bands_data, gap, gap_info, gap_tol, gap_color, gap_marker, direct_gaps_only, custom_gaps, E_axis=E_axis) - gaps_to_draw.append({"ks": ks, "Es": Es, "color": color, "name": name}) + all_plottings = combined(grouped_fatbands_plottings, bands_plottings, gaps_plottings, composite_method=None) - return gaps_to_draw + return get_figure(backend=backend, plot_actions=all_plottings) - def _sanitize_k(self, k): - """Returns the float value of a k point in the plot. +class BandsPlot(Plot): - Parameters - ------------ - k: float or str - The k point that you want to sanitize. - If it can be parsed into a float, the result of `float(k)` will be returned. - If it is a string and it is a label of a k point, the corresponding k value for that - label will be returned + function = staticmethod(bands_plot) - Returns - ------------ - float - The sanitized k value. 
- """ - san_k = None +class FatbandsPlot(OrbitalGroupsPlot): - try: - san_k = float(k) - except ValueError: - if k in self.bands_data.attrs["ticklabels"]: - i_tick = self.bands_data.attrs["ticklabels"].index(k) - san_k = self.bands_data.attrs["ticks"][i_tick] - else: - pass - # raise ValueError(f"We can not interpret {k} as a k-location in the current bands plot") - # This should be logged instead of raising the error - - return san_k - - def _get_gap_coords(self, from_k, to_k=None, gap_spin=0, **kwargs): - """ - Calculates the coordinates of a gap given some k values. - Parameters - ----------- - from_k: float or str - The k value where you want the gap to start (bottom limit). - If "to_k" is not provided, it will be interpreted also as the top limit. - If a k-value is a float, it will be directly interpreted - as the position in the graph's k axis. - If a k-value is a string, it will be attempted to be parsed - into a float. If not possible, it will be interpreted as a label - (e.g. "Gamma"). - to_k: float or str, optional - same as "from_k" but in this case represents the top limit. - If not provided, "from_k" will be used. - gap_spin: int, optional - the spin component where you want to draw the gap. - **kwargs: - keyword arguments that are passed directly to the new trace. - - Returns - ----------- - tuple - A tuple containing (k_values, E_values) - """ - if to_k is None: - to_k = from_k - - ks = [None, None] - # Parse the names of the kpoints into their numeric values - # if a string was provided. 
- for i, val in enumerate((from_k, to_k)): - ks[i] = self._sanitize_k(val) - - VB, CB = self.gap_info["bands"] - spin_bands = self.bands.sel(spin=gap_spin) if "spin" in self.bands.coords else self.bands - Es = [spin_bands.dropna("k", "all").sel(k=k, band=band, method="nearest") for k, band in zip(ks, (VB, CB))] - # Get the real values of ks that have been obtained - # because we might not have exactly the ks requested - ks = [np.ravel(E.k)[0] for E in Es] - Es = [np.ravel(E)[0] for E in Es] - - return ks, Es - - def toggle_gap(self): - """ - If the gap was being displayed, hide it. Else, show it. - """ - return self.update_settings(gap= not self.settings["gap"]) - - def plot_Ediff(self, band1, band2): - """ - Plots the energy difference between two bands. - - Parameters - ---------- - band1, band2: int - the indices of the two bands you want to get the difference for. - - Returns - --------- - Plot - a new plot with the plotted information. - """ - import plotly.express as px - - two_bands = self.bands.sel(band=[band1, band2]).squeeze().values - - diff = two_bands[:, 1] - two_bands[:, 0] - - fig = px.line(x=self.bands.k.values, y=diff) - - fig.update_layout({"title": f"Energy difference between bands {band1} and {band2}", "yaxis_range": [np.min(diff), np.max(diff)]}) - - return fig - - def _plot_Kdiff(self, band1, band2, E=None, offsetE=False): - """ - ONLY WORKING FOR A PAIR OF BANDS THAT ARE ALWAYS INCREASING OR ALWAYS DECREASING - AND ARE ISOLATED (sorry) - - Plots the k difference between two bands. - - Parameters - ----------- - band1, band2: int - the indices of the two bands you want to get the difference for. - E: array-like, optional - the energy values for which we want the K difference between the two bands - offsetE: boolean - whether the energy should be referenced to the minimum of the first band - - Returns - --------- - Plot - a new plot with the plotted information. 
- """ - import plotly.express as px - b1, b2 = self.bands.sel(band=[band1, band2]).squeeze().values.T - ks = self.bands.k.values - - if E is None: - #Interpolate the values of K for band2 that correspond to band1's energies. - b2Ks_for_b1Es = np.interp(b1, b2, ks) - - E = b1 - diff = ks - b2Ks_for_b1Es - - else: - if offsetE: - E += np.min(b1) - - diff = np.interp(E, b1, ks) - \ - np.interp(E, b2, ks) - - E -= np.min(b1) if offsetE else 0 - - fig = px.line(x=diff, y=E) - - plt = super().from_plotly(fig) - - plt.update_layout({"title": f"Delta K between bands {band1} and {band2}", 'xaxis_title': 'Delta k', 'yaxis_title': 'Energy [eV]'}) - - return plt - - def effective_mass(self, band, k, k_direction, band_spin=0, n_points=10): - """Calculates the effective mass from the curvature of a band in a given k point. - - It works by fitting the band to a second order polynomial. - - Notes - ----- - Only valid if there are no band-crossings in the fitted range. - The effective mass may be highly dependent on the `k_direction` parameter, as well as the - number of points fitted. - - Parameters - ----------- - band: int - The index of the band that we want to fit - k: float or str - The k value where we want to find the curvature of the band to calculate the effective mass. - band_spin: int, optional - The spin value for which we want the effective mass. - n_points: int - The number of points that we want to use for the polynomial fit. - k_direction: {"symmetric", "right", "left"}, optional - Indicates in which direction -starting from `k`- should the band be fitted. - "left" and "right" mean that the fit will only be done in one direction, while - "symmetric" indicates that points from both sides will be used. - - Return - ----------- - float - The efective mass, in atomic units. 
- """ - from sisl.unit.base import units - - # Get the band that we want to fit - bands = self.bands - if "spin" in bands.coords: - band_vals = bands.sel(band=band, spin=band_spin) - else: - band_vals = bands.sel(band=band) - - # Sanitize k to a float - k = self._sanitize_k(k) - # Find the index of the requested k - k_index = abs(self.bands.k -k).values.argmin() - - # Determine which slice of the band will we take depending on k_direction and n_points - if k_direction == "symmetric": - sel_slice = slice(k_index - n_points // 2, k_index + n_points // 2 + 1) - elif k_direction == "left": - sel_slice = slice(k_index - n_points + 1, k_index + 1) - elif k_direction == "right": - sel_slice = slice(k_index, k_index + n_points) - else: - raise ValueError(f"k_direction must be one of ['symmetric', 'left', 'right'], {k_direction} was passed") - - # Grab the slice of the band that we are going to fit - sel_band = band_vals[sel_slice] * units("eV", "Hartree") - sel_k = bands.k[sel_slice] - k - - # Fit the band to a second order polynomial - polyfit = np.polynomial.Polynomial.fit(sel_k, sel_band, 2) - - # Get the coefficient for the second order term - coeff_2 = polyfit.convert().coef[2] - - # Calculate the effective mass from the dispersion relation. - # Note that hbar = m_e = 1, since we are using atomic units. - eff_m = 1 / (2 * coeff_2) - - return eff_m + function = staticmethod(fatbands_plot) \ No newline at end of file diff --git a/src/sisl/viz/plots/bond_length.py b/src/sisl/viz/plots/bond_length.py deleted file mode 100644 index b542081d9e..0000000000 --- a/src/sisl/viz/plots/bond_length.py +++ /dev/null @@ -1,410 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-from functools import partial - -import sisl -from sisl.utils.mathematics import fnorm - -from ..input_fields import BoolInput, FloatInput, IntegerInput, SileInput, TextInput -from ..plotutils import find_files -from .geometry import BoundGeometry, GeometryPlot - - -class BondLengthMap(GeometryPlot): - """ - Colorful representation of bond lengths. - - Parameters - ------------- - geom_from_output: bool, optional - In case the geometry is read from the fdf file, this will determine - whether the input or the output geometry is taken.This setting - will be ignored if geom_file is passed - strain_ref: str or Geometry, optional - The path to a geometry or a Geometry object used to calculate strain - from. This geometry will probably be the relaxed - one If provided, colors can indicate strain values. - Otherwise they are just bond length - strain: bool, optional - Determines whether strain values should be displayed instead of - lengths - bond_thresh: float, optional - Maximum distance between two atoms to draw a bond - colorscale: str, optional - This determines the colormap to be used for the bond lengths - display. You can see all valid colormaps here: - https://plot.ly/python/builtin-colorscales/ - Note that you can reverse a color map by adding _r - cmin: float, optional - Minimum color scale - cmax: float, optional - Maximum color scale - cmid: float, optional - Sets the middle point of the color scale. Only meaningful in - diverging colormaps If this is set 'cmin' and 'cmax' - are ignored. In strain representations this might be set to 0. - colorbar: bool, optional - Whether the color bar should be displayed or not. - points_per_bond: int, optional - Number of points that fill a bond. More points will make it look - more like a line but will slow plot rendering down. - geometry: Geometry, optional - A geometry object - geom_file: str, optional - A file name that can read a geometry - show_bonds: bool, optional - Show bonds between atoms. 
- bonds_style: dict, optional - Customize the style of the bonds by passing style specifications. - Currently, you can only pass one style specification. Styling bonds - individually is not supported yet, but it will be in the future. - Structure of the dict: { } - axes: optional - The axis along which you want to see the geometry. You - can provide as many axes as dimensions you want for your plot. - Note that the order is important and will result in setting the plot - axes diferently. For 2D and 1D representations, you can - pass an arbitrary direction as an axis (array of shape (3,)) - dataaxis_1d: array-like or function, optional - If you want a 1d representation, you can provide a data axis. - It determines the second coordinate of the atoms. - If it's a function, it will recieve the projected 1D coordinates and - needs to returns the coordinates for the other axis as - an array. If not provided, the other axis - will just be 0 for all points. - show_cell: optional - Specifies how the cell should be rendered. (False: not - rendered, 'axes': render axes only, 'box': render a bounding box) - nsc: array-like, optional - Make the geometry larger by tiling it along each lattice vector - atoms: dict, optional - The atoms that are going to be displayed in the plot. - This also has an impact on bonds (see the `bind_bonds_to_ats` and - `show_atoms` parameters). If set to None, all atoms are - displayed Structure of the dict: { 'index': Structure of - the dict: { 'in': } 'fx': 'fy': - 'fz': 'x': 'y': 'z': 'Z': - 'neighbours': Structure of the dict: { 'range': - 'R': 'neigh_tag': } 'tag': 'seq': } - atoms_style: array-like of dict, optional - Customize the style of the atoms by passing style specifications. - Each style specification can have an "atoms" key to select the atoms - for which that style should be used. If an atom fits into - more than one selector, the last specification is used. - Each item is a dict. 
Structure of the dict: { 'atoms': - Structure of the dict: { 'index': Structure of the dict: { - 'in': } 'fx': 'fy': 'fz': 'x': - 'y': 'z': 'Z': 'neighbours': Structure - of the dict: { 'range': 'R': 'neigh_tag': - } 'tag': 'seq': } 'color': 'size': - 'opacity': 'vertices': In a 3D representation, the number of - vertices that each atom sphere is composed of. } - arrows: array-like of dict, optional - Add arrows centered at the atoms to display some vector property. - You can add as many arrows as you want, each with different styles. - Each item is a dict. Structure of the dict: { 'atoms': - Structure of the dict: { 'index': Structure of the dict: { - 'in': } 'fx': 'fy': 'fz': 'x': - 'y': 'z': 'Z': 'neighbours': Structure - of the dict: { 'range': 'R': 'neigh_tag': - } 'tag': 'seq': } 'data': 'scale': - 'color': 'width': 'name': - 'arrowhead_scale': 'arrowhead_angle': } - atoms_scale: float, optional - A scaling factor for atom sizes. This is a very quick way to rescale. - atoms_colorscale: str, optional - The colorscale to use to map values to colors for the atoms. - Only used if atoms_color is provided and is an array of values. - bind_bonds_to_ats: bool, optional - whether only the bonds that belong to an atom that is present should - be displayed. If False, all bonds are displayed - regardless of the `atoms` parameter - show_atoms: bool, optional - If set to False, it will not display atoms. Basically - this is a shortcut for ``atoms = [], bind_bonds_to_ats=False``. - Therefore, it will override these two parameters. - points_per_bond: int, optional - Number of points that fill a bond in 2D in case each bond has a - different color or different size. More points will make it look - more like a line but will slow plot rendering down. - cell_style: dict, optional - The style of the unit cell lines Structure of the dict: { - 'color': 'width': 'opacity': } - root_fdf: fdfSileSiesta, optional - Path to the fdf file that is the 'parent' of the results. 
- results_path: str, optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - entry_points_order: array-like, optional - Order with which entry points will be attempted. - backend: optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - """ - - _plot_type = "Bond length" - - _parameters = ( - - BoolInput( - key = "geom_from_output", name = "Geometry from output", - default = True, - group = "dataread", - params = { - "offLabel": "No", - "onLabel": "Yes", - }, - help = "In case the geometry is read from the fdf file, this will determine whether the input or the output geometry is taken.
This setting will be ignored if geom_file is passed" - ), - - SileInput( - key = "strain_ref", name = "Strain reference geometry", - hasattr=['read_geometry'], - dtype=(str, sisl.Geometry), - group = "dataread", - params = { - "placeholder": "Write the path to your strain reference file here..." - }, - help = """The path to a geometry or a Geometry object used to calculate strain from.
- This geometry will probably be the relaxed one
- If provided, colors can indicate strain values. Otherwise they are just bond length""" - ), - - BoolInput( - key = "strain", name = "Display strain", - default = True, - params = { - "offLabel": False, - "onLabel": True - }, - help = """Determines whether strain values should be displayed instead of lengths""" - ), - - FloatInput( - key = "bond_thresh", name = "Bond length threshold", - default = 1.7, - params = { - "step": 0.01 - }, - help = "Maximum distance between two atoms to draw a bond" - ), - - TextInput( - key="colorscale", name="Plotly colormap", - default="viridis", - params={ - "placeholder": "Write a valid plotly colormap here..." - }, - help="""This determines the colormap to be used for the bond lengths display.
- You can see all valid colormaps here: https://plot.ly/python/builtin-colorscales/
- Note that you can reverse a color map by adding _r""" - ), - - FloatInput( - key = "cmin", name = "Color scale low limit", - default = 0, - params = { - "step": 0.01 - }, - help="Minimum color scale" - ), - - FloatInput( - key = "cmax", name = "Color scale high limit", - default = 0, - params = { - "step": 0.01 - }, - help="Maximum color scale" - ), - - FloatInput( - key = "cmid", name = "Color scale mid point", - default = None, - params = { - "step": 0.01 - }, - help = """Sets the middle point of the color scale. Only meaningful in diverging colormaps
- If this is set 'cmin' and 'cmax' are ignored. In strain representations this might be set to 0. - """ - ), - - BoolInput( - key='colorbar', name='Show colorbar', - default=True, - help="""Whether the color bar should be displayed or not.""" - ), - - IntegerInput( - key="points_per_bond", name="Points per bond", - default=10, - help="Number of points that fill a bond.
More points will make it look more like a line but will slow plot rendering down." - ), - - ) - - _layout_defaults = { - 'xaxis_title': 'X [Ang]', - 'yaxis_title': "Y [Ang]", - 'yaxis_zeroline': False - } - - @classmethod - def _default_animation(self, wdir=None, frame_names=None, **kwargs): - """By default, we will animate all the *XV files that we find""" - geom_files = find_files(wdir, "*.XV", sort = True) - - return BondLengthMap.animated("geom_file", geom_files, wdir = wdir, **kwargs) - - @property - def on_relaxed_geom(self): - """ - Returns a bound geometry, which you can apply methods to so that the plot - updates automatically. - """ - return BoundGeometry(self.relaxed_geom, self) - - _read_geom = GeometryPlot.entry_points[0] - _read_file = GeometryPlot.entry_points[1] - - def _read_strain_ref(self, ref): - """Reads the strain reference, if there is any.""" - strain_ref = ref - - if isinstance(strain_ref, str): - self.relaxed_geom = self.get_sile(strain_ref).read_geometry() - elif isinstance(strain_ref, sisl.Geometry): - self.relaxed_geom = strain_ref - else: - self.relaxed_geom = None - - def _after_read(self, strain_ref, nsc): - self._read_strain_ref(strain_ref) - - is_strain_ref = self.relaxed_geom is not None - - self._tiled_geometry = self.geometry - for ax, reps in enumerate(nsc): - self._tiled_geometry = self._tiled_geometry.tile(reps, ax) - if is_strain_ref: - self.relaxed_geom = self.relaxed_geom.tile(reps, ax) - - self.geom_bonds = self.find_all_bonds(self._tiled_geometry) - - if is_strain_ref: - self.relaxed_bonds = self.find_all_bonds(self.relaxed_geom) - - self.get_param("atoms").update_options(self.geometry) - - def _wrap_bond3D(self, bond, bonds_styles, show_strain=False): - """ - Receives a bond and sets its color to the bond length for the 3D case - """ - if show_strain: - color = self._bond_strain(self.relaxed_geom, self._tiled_geometry, bond) - name = f'Strain: {color:.3f}' - else: - color = self._bond_length(self._tiled_geometry, bond) 
- name = f'{color:.3f} Ang' - - self.colors.append(color) - - return { - **self._default_wrap_bond3D(bond, bonds_styles=bonds_styles), - "color": color, - "name": name - } - - def _wrap_bond2D(self, bond, xys, bonds_styles, show_strain=False): - """ - Receives a bond and sets its color to the bond length for the 2D case - """ - if show_strain: - color = self._bond_strain(self.relaxed_geom, self._tiled_geometry, bond) - name = f'Strain: {color:.3f}' - else: - color = self._bond_length(self._tiled_geometry, bond) - name = f'{color:.3f} Ang' - - self.colors.append(color) - - return { - **self._default_wrap_bond2D(bond, xys, bonds_styles=bonds_styles), - "color": color, "name": name - } - - @staticmethod - def _bond_length(geom, bond): - """ - Returns the length of a bond between two atoms. - - Parameters - ------------ - geom: Geometry - the structure where the atoms are - bond: array-like of two int - the indices of the atoms that form the bond - """ - return fnorm(geom[bond[1]] - geom[bond[0]]) - - @staticmethod - def _bond_strain(relaxed_geom, geom, bond): - """ - Calculates the strain of a bond using a reference geometry. 
- - Parameters - ------------ - relaxed_geom: Geometry - the structure to take as a reference - geom: Geometry - the structure to take as the "current" one - bond: array-like of two int - the indices of the atoms that form the bond - """ - relaxed_bl = BondLengthMap._bond_length(relaxed_geom, bond) - bond_length = BondLengthMap._bond_length(geom, bond) - - return (bond_length - relaxed_bl) / relaxed_bl - - def _set_data(self, strain, axes, atoms, show_atoms, bind_bonds_to_ats, points_per_bond, cmin, cmax, colorscale, colorbar, - kwargs3d={}, kwargs2d={}, kwargs1d={}): - - # Set the bonds to the relaxed ones if there is a strain reference - show_strain = strain and hasattr(self, "relaxed_bonds") - if show_strain: - self.bonds = self.relaxed_bonds - - self.geometry.set_nsc(self.relaxed_geom.lattice.nsc) - else: - self.bonds = self.geom_bonds - - # We will initialize the colors list so that it is filled by - # the methods that generate them and we can at the end set the limits - # of the color scale - self.colors = [] - - # Let GeometryPlot set the data - for_backend = super()._set_data( - kwargs3d={ - "wrap_bond": partial(self._wrap_bond3D, show_strain=show_strain), - **kwargs3d - }, - kwargs2d={ - "wrap_bond": partial(self._wrap_bond2D, show_strain=show_strain), - "points_per_bond": points_per_bond, - **kwargs2d - }, - kwargs1d=kwargs1d - ) - - if self.colors: - for_backend["bonds_coloraxis"] = { - "cmin": cmin or min(self.colors), - "cmax": cmax or max(self.colors), - "colorscale": colorscale, - 'showscale': colorbar, - 'colorbar_title': 'Strain' if show_strain else 'Bond length [Ang]' - } - - return for_backend diff --git a/src/sisl/viz/plots/experimental/__init__.py b/src/sisl/viz/plots/experimental/__init__.py deleted file mode 100644 index df4b41bb82..0000000000 --- a/src/sisl/viz/plots/experimental/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. 
If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -""" -This module contains plots that are not robust and probably not ready for production -but still might be helpful for people. - -Whoever uses them might also give helpful feedback to move them into the main plots folder! -""" - -# Somehow we need a way to don't break users codes when moving plots from experimental -# to production, while still makin it clear that they should expect bugs by using them -# in experimental mode -from .ldos import LDOSmap diff --git a/src/sisl/viz/plots/experimental/ldos.py b/src/sisl/viz/plots/experimental/ldos.py deleted file mode 100644 index 6dcd0b4320..0000000000 --- a/src/sisl/viz/plots/experimental/ldos.py +++ /dev/null @@ -1,412 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import os -import shutil - -import numpy as np - -import sisl - -from ...input_fields import ( - BoolInput, - ColorInput, - FloatInput, - IntegerInput, - OptionsInput, - ProgramaticInput, - QueriesInput, - RangeSliderInput, - TextInput, -) -from ...plot import Plot, entry_point -from ...plotutils import run_multiple - - -class LDOSmap(Plot): - """ - Generates a heat map with the STS spectra along a path. - - Parameters - ------------ - %%configurable_settings%% - - """ - - _plot_type = "LDOS map" - - _requirements = { - "siesOut": { - "files": ["$struct$.DIM", "$struct$.PLD", "*.ion", "$struct$.selected.WFSX"], - "codes": { - - "denchar": { - "reason": "The 'denchar' code is used in this case to generate STS spectra." 
- } - - } - }, - - } - - _parameters = ( - - RangeSliderInput( - key = "Erange", name = "Energy range", - default = [-2, 4], - width = "s90%", - params = { - "min": -10, - "max": 10, - "allowCross": False, - "step": 0.1, - "marks": {**{i: str(i) for i in range(-10, 11)}, 0: "Ef", }, - }, - help = "Energy range where the STS spectra are computed." - ), - - IntegerInput( - key = "nE", name = "Energy points", - default = 100, - params = { - "min": 1 - }, - help = "The number of energy points that are calculated for each spectra" - ), - - FloatInput( - key = "STSEta", name = "Smearing factor (eV)", - default = 0.05, - params = { - "min": 0.01, - "step": 0.01 - }, - help = """This determines the smearing factor of each STS spectra. You can play with this to modify sensibility in the vertical direction. -
If the smearing value is too high, your map will have a lot of vertical noise""" - ), - - FloatInput( - key = "dist_step", name = "Distance step (Ang)", - default = 0.1, - params = { - "min": 0, - "step": 0.01, - }, - help = "The step in distance between one point and the next one in the path." - ), - - ProgramaticInput( - key = "trajectory", name = "Trajectory", - default = [], - help = """You can directly provide a trajectory instead of the corner points.
- This option has preference over 'points', but can't be used through the GUI.
- It is useful if you want a non-straight trajectory.""" - ), - - ProgramaticInput( - key = "widen_func", name = "Widen function", - default = None, - help = """You can widen the path with this parameter. - This option has preference over 'widenX', 'widenY' and 'widenZ', but can't be used through the GUI.
- This must be a function that gets a point of the path and returns a set of points surrounding it (including the point itself).
- All points of the path must be widened with the same amount of points, otherwise you will get an error.""" - ), - - OptionsInput( - key = "widen_method", name = "Widen method", - default = "sum", - width = "s100% m50% l40%", - params = { - "options": [{"label": "Sum", "value": "sum"}, {"label": "Average", "value": "average"}], - "isMulti": False, - "placeholder": "", - "isClearable": False, - "isSearchable": True, - }, - help = "Determines whether values surrounding a point should be summed or averaged" - ), - - QueriesInput( - key = "points", name = "Path corners", - default = [{"x": 0, "y": 0, "z": 0, "atom": None, "active": True}], - queryForm = [ - - *[FloatInput( - key = key, name = key.upper(), - default = 0, - width = "s30%", - params = { - "step": 0.01 - } - ) for key in ("x", "y", "z")], - - OptionsInput( - key = "atom", name = "Atom index", - default = None, - params = { - "options": [], - "isMulti": False, - "placeholder": "", - "isClearable": True, - "isSearchable": True, - }, - help = """You can provide an atom index instead of the coordinates
- If an atom is provided, x, y and z will be interpreted as the supercell indices.
- That is: atom 23 [x=0,y=0,z=0] is atom 23 in the primary cell, while atom 23 [x=1,y=0,z=0] - is the image of atom 23 in the adjacent cell in the direction of x""" - ) - ], - help = """Provide the points to generate the path through which STS need to be calculated.""" - ), - - FloatInput( - key = "cmin", name = "Lower color limit", - default = 0, - params = { - "step": 10*-6 - }, - help = "All points below this value will be displayed as 0." - ), - - FloatInput( - key = "cmax", name = "Upper color limit", - default = 0, - params = { - "step": 10*-6 - }, - help = "All points above this value will be displayed as the maximum.
Decreasing this value will increase saturation." - ), - - ) - - _layout_defaults = { - 'xaxis_title': "Path coordinate", - 'yaxis_title': "E-Ef (eV)" - } - - def _getdencharSTSfdf(self, stsPosition, Erange, nE, STSEta): - - return """ - Denchar.PlotSTS .true. - Denchar.PlotWaveFunctions .false. - Denchar.PlotCharge .false. - - %block Denchar.STSposition - {} {} {} - %endblock Denchar.STSposition - - Denchar.STSEmin {} eV - Denchar.STSEmax {} eV - Denchar.STSEnergyPoints {} - Denchar.CoorUnits Ang - Denchar.STSEta {} eV - """.format(*stsPosition, *(np.array(Erange) + self.fermi), nE, STSEta) - - @entry_point('siesta') - def _read_siesta_output(self, Erange, nE, STSEta, root_fdf, trajectory, points, dist_step, widen_func): - """Function that uses denchar to get STSpecra along a path""" - import xarray as xr - - fdf_sile = self.get_sile(root_fdf) - root_dir = fdf_sile._directory - - self.geom = fdf_sile.read_geometry(output = True) - - #Find fermi level - self.fermi = False - for out_fileName in (self.struct, self.fdf_sile.base_file.replace(".fdf", "")): - try: - for line in open(fdf_sile.dir_file(f"{out_fileName}.out")): - if "Fermi =" in line: - self.fermi = float(line.split()[-1]) - print("\nFERMI LEVEL FOUND: {} eV\n Energies will be relative to this level (E-Ef)\n".format(self.fermi)) - break - except FileNotFoundError: - pass - - if not self.fermi: - print("\nFERMI LEVEL NOT FOUND IN THE OUTPUT FILE. 
\nEnergy values will be absolute\n") - self.fermi = 0 - - #Get the path (this also sets some attributes: 'distances', 'pointsByStage', 'totalPoints') - self._getPath(trajectory, points, dist_step, widen_func) - - #Prepare the array that will store all the spectra - self.spectra = np.zeros((self.path.shape[0], self.path.shape[1], nE)) - #Other helper arrays - pathIs = np.linspace(0, self.path.shape[0] - 1, self.path.shape[0]) - Epoints = np.linspace(*(np.array(Erange) + self.fermi), nE) - - #Copy selected WFSX into WFSX if it exists (denchar reads from .WFSX) - system_label = fdf_sile.get("SystemLabel", default="siesta") - shutil.copyfile(fdf_sile.dir_file(f"{system_label}.selected.WFSX"), - fdf_sile.dir_file(f"{system_label}.WFSX")) - - #Get the fdf file and replace include paths so that they work - with open(root_fdf, "r") as f: - self.fdfLines = f.readlines() - - for i, line in enumerate(self.fdfLines): - if "%include" in line and not os.path.isabs(line.split()[-1]): - - self.fdfLines[i] = "%include {}\n".format(os.path.join("../", line.split()[-1])) - - #Denchar needs to be run from the directory where everything is stored - cwd = os.getcwd() - os.chdir(root_dir) - - #Inform that the WFSX file is used so that changes in it can be followed - self.follow(fdf_sile.dir_file(f"{system_label}.WFSX")) - - def getSpectraForPath(argsTuple): - - path, nE, iPath, root_dir, struct, STSflags, args, kwargs = argsTuple - - #Generate a temporal directory so that we don't interfere with the other processes - tempDir = "{}tempSTS".format(iPath) - - os.makedirs(tempDir, exist_ok = True) - os.chdir(tempDir) - - tempFdf = os.path.join('{}STS.fdf'.format(struct)) - outputFile = os.path.join('{}.STS'.format(struct)) - - #Link all the needed files to this directory - os.system("ln -s ../*fdf ../*out ../*ion* ../*WFSX ../*DIM ../*PLD . 
") - - spectra = []; failedPoints = 0 - - for i, point in enumerate(path): - - #Write the fdf - with open(tempFdf, "w") as fh: - fh.writelines(kwargs["fdfLines"]) - fh.write(STSflags[i]) - - #Do the STS calculation for the point - os.system("denchar < {} > /dev/null".format(tempFdf)) - - if i%100 == 0 and i != 0: - print("PATH {}. Points calculated: {}".format(int(iPath), i)) - - #Retrieve and save the output appropiately - try: - spectrum = np.loadtxt(outputFile) - - spectra.append(spectrum[:, 1]) - except Exception as e: - - print("Error calculating the spectra for point {}: \n{}".format(point, e)) - failedPoints += 1 - #If any spectrum was read, just fill it with zeros - spectra.append(np.zeros(nE)) - - if failedPoints: - print("Path {} finished with {} error{} ({}/{} points succesfully calculated)".format(int(iPath), failedPoints, "s" if failedPoints > 1 else "", len(path) - failedPoints, len(path))) - - os.chdir("..") - shutil.rmtree(tempDir, ignore_errors=True) - - return spectra - - self.spectra = run_multiple( - getSpectraForPath, - self.path, - nE, - pathIs, - root_dir, self.struct, - #All the strings that need to be added to each file - [[self._getdencharSTSfdf(point, Erange, nE, STSEta) for point in points] for points in self.path], - kwargsList = {"root_fdf": root_fdf, "fdfLines": self.fdfLines}, - messageFn = lambda nTasks, nodes: "Calculating {} simultaneous paths in {} nodes".format(nTasks, nodes), - serial = self.isChildPlot - ) - - self.spectra = np.array(self.spectra) - - #WITH XARRAY - self.xarr = xr.DataArray( - name = "LDOSmap", - data = self.spectra, - dims = ["iPath", "x", "E"], - coords = [pathIs, list(range(self.path.shape[1])), Epoints] - ) - - os.chdir(cwd) - - #Update the values for the limits so that they are automatically set - self.update_settings(run_updates = False, cmin = 0, cmax = 0) - - def _getPath(self, trajectory, points, dist_step, widen_func): - - if list(trajectory): - #If the user provides a trajectory, we are going to use 
that without questioning it - self.path = np.array(trajectory) - - #At the moment these make little sense, but in the future there will be the possibility to add breakpoints - self.pointsByStage = np.array([len(self.path)]) - self.distances = np.array([np.linalg.norm(self.path[-1] - self.path[0])]) - else: - #Otherwise, we will calculate the trajectory according to the points provided - points = [] - for reqPoint in points: - - if reqPoint.get("atom"): - translate = np.array([reqPoint.get("x", 0), reqPoint.get("y", 0), reqPoint.get("z", 0)]).dot(self.geom.cell) - points.append(self.geom[reqPoint["atom"]] + translate) - else: - points.append([reqPoint["x"], reqPoint["y"], reqPoint["z"]]) - points = np.array(points) - - nCorners = len(points) - if nCorners < 2: - raise ValueError("You need more than 1 point to generate a path! You better provide 2 next time...\n") - - #Generate an evenly distributed path along the points provided - self.path = [] - #This array will store the number of points that each stage has - self.pointsByStage = np.zeros(nCorners - 1) - self.distances = np.zeros(nCorners - 1) - - for i, point in enumerate(points[1:]): - - prevPoint = points[i] - - self.distances[i] = np.linalg.norm(point - prevPoint) - nSteps = int(round(self.distances[i]/dist_step)) + 1 - - #Add the trajectory from the previous point to this one to the path - self.path = [*self.path, *np.linspace(prevPoint, point, nSteps)] - - self.pointsByStage[i] = nSteps - - self.path = np.array(self.path) - - #Then, let's widen the path if the user wants to do it (check also points that surround the path) - if callable(widen_func): - self.path = widen_func(self.path) - else: - #This is just to normalize path - self.path = np.expand_dims(self.path, 0) - - #Store the total number of points of the path - self.nPathPoints = self.path.shape[1] - self.totalPoints = self.path.shape[0] * self.path.shape[1] - self.iCorners = self.pointsByStage.cumsum() - - def _set_data(self, widen_method, cmin, 
cmax, Erange, nE): - - #With xarray - if widen_method == "sum": - spectraToPlot = self.xarr.sum(dim = "iPath") - elif widen_method == "average": - spectraToPlot = self.xarr.mean(dim = "iPath") - - self.data = [{ - 'type': 'heatmap', - 'z': spectraToPlot.transpose("E", "x").values, - #These limits determine the contrast of the image - 'zmin': cmin, - 'zmax': cmax, - #Yaxis is the energy axis - 'y': np.linspace(*Erange, nE)}] diff --git a/src/sisl/viz/plots/fatbands.py b/src/sisl/viz/plots/fatbands.py deleted file mode 100644 index fd74eb9024..0000000000 --- a/src/sisl/viz/plots/fatbands.py +++ /dev/null @@ -1,437 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import numpy as np - -import sisl -from sisl.physics.spin import Spin - -from ..input_fields import ( - BoolInput, - ColorInput, - FloatInput, - OrbitalQueries, - SileInput, - TextInput, -) -from ..plot import entry_point -from ..plotutils import random_color -from .bands import BandsPlot - - -class FatbandsPlot(BandsPlot): - """Colorful representation of orbital weights in bands. - - Parameters - ------------- - wfsx_file: wfsxSileSiesta, optional - The WFSX file to get the weights of the different orbitals in the - bands. In standard SIESTA nomenclature, this should be - the *.bands.WFSX file, as it is the one that contains the - weights that correspond to the bands. This - file is only meaningful (and required) if fatbands are plotted from - the .bands file. Otherwise, the bands and weights will be - generated from the hamiltonian by sisl. If the *.bands - file is provided but the wfsx one isn't, we will try to find it. - If `bands_file` is SystemLabel.bands, we will look for - SystemLabel.bands.WFSX - scale: float, optional - The factor by which the width of all fatbands should be multiplied. 
- Note that each group has an additional individual factor that you can - also tweak. - groups: array-like of dict, optional - The different groups that are displayed in the fatbands Each item - is a dict. Structure of the dict: { 'name': - 'species': 'atoms': Structure of the dict: { - 'index': Structure of the dict: { 'in': } 'fx': - 'fy': 'fz': 'x': 'y': 'z': - 'Z': 'neighbours': Structure of the dict: { - 'range': 'R': 'neigh_tag': } 'tag': - 'seq': } 'orbitals': 'spin': 'normalize': - 'color': 'scale': } - bands_file: bandsSileSiesta, optional - This parameter explicitly sets a .bands file. Otherwise, the bands - file is attempted to read from the fdf file - band_structure: BandStructure, optional - A band structure. it can either be provided as a sisl.BandStructure - object or as a list of points, which will be parsed into a - band structure object. Each item is a dict. Structure - of the dict: { 'x': 'y': 'z': - 'divisions': 'names': Tick that should be displayed at this - corner of the path. } - aiida_bands: optional - An aiida BandsData node. - add_band_data: optional - This function receives each band and should return a dictionary with - additional arguments that are passed to the band drawing - routine. It also receives the plot as the second argument. - See the docs of `sisl.viz.backends.templates.Backend.draw_line` to - understand what are the supported arguments to be - returned. Notice that the arguments that the backend is able to - process can be very framework dependant. - Erange: array-like of shape (2,), optional - Energy range where the bands are displayed. - E0: float, optional - The energy to which all energies will be referenced (including - Erange). - bands_range: array-like of shape (2,), optional - The bands that should be displayed. Only relevant if Erange is None. - spin: optional - Determines how the different spin configurations should be displayed. - In spin polarized calculations, it allows you to choose between spin - 0 and 1. 
In non-colinear spin calculations, it allows you - to ask for a given spin texture, by specifying the - direction. - spin_texture_colorscale: str, optional - The plotly colorscale to use for the spin texture (if displayed) - gap: bool, optional - Whether the gap should be displayed in the plot - direct_gaps_only: bool, optional - Whether to show only gaps that are direct, according to the gap - tolerance - gap_tol: float, optional - The difference in k that must exist to consider to gaps - different. If two gaps' positions differ in less than - this, only one gap will be drawn. Useful in cases - where there are degenerated bands with exactly the same values. - gap_color: str, optional - Color to display the gap - custom_gaps: array-like of dict, optional - List of all the gaps that you want to display. Each item is a dict. - Structure of the dict: { 'from': K value where to start - measuring the gap. It can be either the label of - the k-point or the numeric value in the plot. 'to': K value - where to end measuring the gap. It can be either - the label of the k-point or the numeric value in the plot. - 'color': The color with which the gap should be displayed - 'spin': The spin components where the gap should be calculated. } - bands_width: float, optional - Width of the lines that represent the bands - bands_color: str, optional - Choose the color to display the bands. This will be used for the - spin up bands if the calculation is spin polarized - spindown_color: str, optional - Choose the color for the spin down bands.Only used if the - calculation is spin polarized. - root_fdf: fdfSileSiesta, optional - Path to the fdf file that is the 'parent' of the results. - results_path: str, optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - entry_points_order: array-like, optional - Order with which entry points will be attempted. 
- backend: optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - """ - - _plot_type = 'Fatbands' - - _update_methods = { - "read_data": [], - "set_data": ["_draw_gaps", "_get_groups_weights"], - "get_figure": [] - } - - _parameters = ( - - FloatInput(key='scale', name='Scale factor', - default=None, - help="""The factor by which the width of all fatbands should be multiplied. - Note that each group has an additional individual factor that you can also tweak.""" - # Probably scale should not multiply but normalize everything relative to the energy range! - ), - - OrbitalQueries( - key="groups", name="Fatbands groups", - default=None, - help="""The different groups that are displayed in the fatbands""", - queryForm=[ - - TextInput( - key="name", name="Name", - default="Group", - params={ - "placeholder": "Name of the line..." - }, - ), - - 'species', 'atoms', 'orbitals', 'spin', - - BoolInput( - key="normalize", name="Normalize", - default=True, - params={ - "offLabel": "No", - "onLabel": "Yes" - } - ), - - ColorInput( - key="color", name="Color", - default=None, - ), - - FloatInput( - key="scale", name="Scale factor", - default=1, - ), - ] - ), - - ) - - @property - def weights(self): - return self.bands_data["weight"] - - @entry_point("wfsx file", 0) - def _read_from_wfsx(self, root_fdf, wfsx_file): - """Generates fatbands from SIESTA output. - - Uses the `.wfsx` file to retrieve the eigenstates. From them, it computes - all the needed quantities (eigenvalues, orbital contribution, ...). 
- """ - self._entry_point_with_extra_vars(super()._read_from_wfsx, need_H=True) - - @entry_point("hamiltonian", 1) - def _read_from_H(self): - """Calculates the fatbands from a sisl hamiltonian.""" - self._entry_point_with_extra_vars(super()._read_from_H) - - def _entry_point_with_extra_vars(self, entry_point, **kwargs): - # Define the function that will "catch" each eigenstate and - # build the weights array. See BandsPlot._read_from_H to understand where - # this will go exactly - def _weights_from_eigenstate(eigenstate, plot, spin_index): - - weights = eigenstate.norm2(sum=False) - - if not plot.spin.is_diagonal: - # If it is a non-colinear or spin orbit calculation, we have two weights for each - # orbital (one for each spin component of the state), so we just pair them together - # and sum their contributions to get the weight of the orbital. - weights = weights.reshape(len(weights), -1, 2).sum(2) - - return weights.real - - # We make bands plot read the bands, which will also populate the weights - # thanks to the above step - bands_read = False; err = None - try: - entry_point(extra_vars=[{"coords": ("band", "orb"), "name": "weight", "getter": _weights_from_eigenstate}], **kwargs) - bands_read = True - except Exception as e: - # Let's keep this error, we are going to at least set the group options so that the - # user knows what can they choose (specially important for the GUI) - err = e - - self._set_group_options() - if not bands_read: - raise err - - def _set_group_options(self): - - # Try to find a geometry if there isn't already one - if not hasattr(self, "geometry"): - - # From the hamiltonian - band_struct = self.get_setting("band_structure") - if band_struct is not None: - self.geometry = band_struct.parent.geometry - - self.get_param('groups').update_options(self.geometry, self.spin) - - def _set_data(self): - # We get the information that the Bandsplot wants to send to the drawer - from_bandsplot = super()._set_data() - - # And add some extra 
information related to the weights. - return { - **from_bandsplot, - **self._get_groups_weights() - } - - def _get_groups_weights(self, groups, E0, bands_range, scale): - """Returns a dictionary with information about all the weights that have been requested - The return of this function is expected to be passed to the drawers. - """ - # We get the bands range that is going to be plotted - # Remember that the BandsPlot will have updated this setting accordingly, - # so it's safe to use it directly - min_band, max_band = bands_range - - # Get the weights that matter - plot_weights = self.weights.sel(band=slice(min_band, max_band)) - - if groups is None: - groups = () - - if scale is None: - # Probably we can calculate a more suitable scale - scale = 1 - - groups_weights = {} - groups_metadata = {} - # Here we get the values of the weights for each group of orbitals. - for i, group in enumerate(groups): - group = {**group} - - # Use only the active requests - if not group.get("active", True): - continue - - # Give a name to the request in case it didn't have one. - if group.get("name") is None: - group["name"] = f"Group {i}" - - # Multiply the groups' scale by the global scale - group["scale"] = group.get("scale", 1) * scale - - # Get the weight values for the request and store them to send to the drawer - self._get_group_weights(group, plot_weights, values_storage=groups_weights, metadata_storage=groups_metadata) - - return {"groups_weights": groups_weights, "groups_metadata": groups_metadata} - - def _get_group_weights(self, group, weights=None, values_storage=None, metadata_storage=None): - """Extracts the weight values that correspond to a specific fatbands request. - Parameters - -------------- - group: dict - the request to process. - weights: DataArray, optional - the part of the weights dataarray that falls in the energy range that we want to draw. - If not provided, the full weights data stored in `self.weights` is used. 
- values_storage: dict, optional - a dictionary where the weights values will be stored using the request's name as the key. - metadata_storage: dict, optional - a dictionary where metadata for the request will be stored using the request's name as the key. - Returns - ---------- - xarray.DataArray - The weights resulting from the request. They are indexed by spin, band and k value. - """ - - if weights is None: - weights = self.weights - if "spin" not in weights.coords: - weights = weights.expand_dims("spin") - - groups_param = self.get_param("groups") - - group = groups_param.complete_query(group) - - orb = groups_param.get_orbitals(group) - - # Get the weights for the requested orbitals - weights = weights.sel(orb=orb) - - # Now get a particular spin component if the user wants it - if group["spin"] is not None: - weights = weights.sel(spin=group["spin"]) - - if group["normalize"]: - weights = weights.mean("orb") - else: - weights = weights.sum("orb") - - if group["color"] is None: - group["color"] = random_color() - - group_name = group["name"] - values = weights.transpose("spin", "band", "k") * group["scale"] - - if values_storage is not None: - if group_name in values_storage: - raise ValueError(f"There are multiple groups that are named '{group_name}'") - values_storage[group_name] = values - - if metadata_storage is not None: - # Build the dictionary that contains metadata for this group. - metadata = { - "style": { - "line": {"color": group["color"]} - } - } - - metadata_storage[group_name] = metadata - - return values - - # ------------------------------------- - # Convenience methods - # ------------------------------------- - - def split_groups(self, on="species", only=None, exclude=None, clean=True, colors=(), **kwargs): - """ - Builds groups automatically to draw their contributions. 
- Works exactly the same as `PdosPlot.split_DOS` - Parameters - -------- - on: str, {"species", "atoms", "Z", "orbitals", "n", "l", "m", "zeta", "spin"} or list of str - the parameter to split along. - Note that you can combine parameters with a "+" to split along multiple parameters - at the same time. You can get the same effect also by passing a list. - only: array-like, optional - if desired, the only values that should be plotted out of - all of the values that come from the splitting. - exclude: array-like, optional - values that should not be plotted - clean: boolean, optional - whether the plot should be cleaned before drawing. - If False, all the groups that come from the method will - be drawn on top of what is already there. - colors: array-like, optional - A list of colors to be used. There can be more colors than - needed, or less. If there are less colors than groups, the colors - will just be repeated. - **kwargs: - keyword arguments that go directly to each request. - This is useful to add extra filters. For example: - `plot.split_groups(on="orbitals", species=["C"])` - will split the PDOS on the different orbitals but will take - only those that belong to carbon atoms. - Examples - ----------- - >>> plot = H.plot.fatbands() - >>> - >>> # Split the fatbands in n and l but show only the fatbands from Au - >>> # Also use "Au $ns" as a template for the name, where $n will - >>> # be replaced by the value of n. 
- >>> plot.split_groups(on="n+l", species=["Au"], name="Au $ns") - """ - groups = self.get_param('groups')._generate_queries( - on=on, only=only, exclude=exclude, **kwargs) - - if len(colors) > 0: - # Repeat the colors in case there are more groups than colors - colors = np.tile(colors, len(groups) // len(colors) + 1) - - # Asign colors - for i, _ in enumerate(groups): - groups[i]['color'] = colors[i] - - # If the user doesn't want to clean the plot, we will just add the groups to the existing ones - if not clean: - groups = [*self.get_setting("groups"), *groups] - - return self.update_settings(groups=groups) - - def scale_fatbands(self, factor, from_current=False): - """Scales all bands by a given factor. - Basically, it updates 'scale' setting. - Parameters - ----------- - factor: float - the factor that should be used to scale. - from_current: boolean, optional - whether 'factor' is meant to multiply the current scaling factor. - If False, it will just replace the current factor. - """ - - if from_current: - scale = self.get_setting('scale') * factor - else: - scale = factor - - return self.update_settings(scale=scale) diff --git a/src/sisl/viz/plots/geometry.py b/src/sisl/viz/plots/geometry.py index a1c3d3ea1c..580f8afaf0 100644 --- a/src/sisl/viz/plots/geometry.py +++ b/src/sisl/viz/plots/geometry.py @@ -1,1132 +1,333 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-import itertools -import re -from functools import wraps +from typing import Callable, Literal, Optional, Sequence, Tuple, TypeVar, Union import numpy as np -from sisl import Atom, AtomGhost, Geometry, PeriodicTable -from sisl._dispatcher import AbstractDispatch -from sisl._lattice import cell_invert -from sisl.messages import warn -from sisl.utils import direction -from sisl.utils.mathematics import fnorm - -from ..input_fields import ( - Array1DInput, - AtomSelect, - BoolInput, - ColorInput, - DictInput, - FilePathInput, - FloatInput, - GeomAxisSelect, - IntegerInput, - OptionsInput, - PlotableInput, - ProgramaticInput, - QueriesInput, - TextInput, +from sisl import BrillouinZone, Geometry +from sisl.typing import AtomsArgument +from sisl.viz.figure import Figure, get_figure +from sisl.viz.plotters import plot_actions as plot_actions +from sisl.viz.types import AtomArrowSpec, AtomsStyleSpec, Axes, StyleSpec + +from ..plot import Plot +from ..plotters.cell import cell_plot_actions, get_ndim, get_z +from ..plotters.xarray import draw_xarray_xy +from ..processors.axes import sanitize_axes +from ..processors.coords import project_to_axes +from ..processors.geometry import ( + add_xyz_to_bonds_dataset, + add_xyz_to_dataset, + bonds_to_lines, + find_all_bonds, + get_sites_units, + parse_atoms_style, + sanitize_arrows, + sanitize_atoms, + sanitize_bonds_selection, + sites_obj_to_geometry, + stack_sc_data, + style_bonds, + tile_data_sc, ) -from ..plot import Plot, entry_point -from ..plotutils import values_to_colors - - -class BoundGeometry(AbstractDispatch): - """ - Updates the plot after a method is run on the plot's geometry. 
- """ - - def __init__(self, geom, parent_plot): - - self.parent_plot = parent_plot - super().__init__(geom) +from ..processors.logic import matches, switch +from ..processors.xarray import scale_variable, select - def dispatch(self, method): - @wraps(method) - def with_plot_update(*args, **kwargs): - - ret = method(*args, **kwargs) - - # Maybe the returned value is not a geometry - if isinstance(ret, Geometry): - self.parent_plot.update_settings(geometry=ret) - return self.parent_plot.on_geom - - return ret - - return with_plot_update - - -class GeometryPlot(Plot): - """ - Versatile representation of geometries. - - This class contains all functions necessary to plot geometries in very diverse ways. +def _get_atom_mode(drawing_mode, ndim): + if drawing_mode is None: + if ndim == 3: + return 'balls' + else: + return 'scatter' + + return drawing_mode + +def _get_arrow_plottings(atoms_data, arrows, nsc=[1,1,1]): + + reps = np.prod(nsc) + actions = [] + atoms_data = atoms_data.unstack("sc_atom") + for arrows_spec in arrows: + filtered = atoms_data.sel(atom=arrows_spec['atoms']) + dxy = arrows_spec['data'][arrows_spec['atoms']] + dxy = np.tile(np.ravel(dxy), reps).reshape(-1, arrows_spec['data'].shape[-1]) + + # If it is a 1D plot, make sure that the arrows have two coordinates, being 0 the second one. 
+ if dxy.shape[-1] == 1: + dxy = np.array([dxy[:, 0], np.zeros_like(dxy[:, 0])]).T + + kwargs = {} + kwargs['line'] = {'color': arrows_spec['color'], 'width': arrows_spec['width'], 'opacity': arrows_spec.get('opacity', 1)} + kwargs['name'] = arrows_spec['name'] + kwargs['arrowhead_scale'] = arrows_spec['arrowhead_scale'] + kwargs['arrowhead_angle'] = arrows_spec['arrowhead_angle'] + kwargs['annotate'] = arrows_spec.get('annotate', False) + kwargs['scale'] = arrows_spec['scale'] + + if dxy.shape[-1] < 3: + action = plot_actions.draw_arrows(x=np.ravel(filtered.x), y=np.ravel(filtered.y), dxy=dxy, **kwargs) + else: + action = plot_actions.draw_arrows_3D(x=np.ravel(filtered.x), y=np.ravel(filtered.y), z=np.ravel(filtered.z), dxyz=dxy, **kwargs) + actions.append(action) + + return actions + +def _sanitize_scale(scale: float, ndim: int, ndim_scale: Tuple[float, float, float] = (16, 16, 1)): + return ndim_scale[ndim-1] * scale + +def geometry_plot(geometry: Geometry, + axes: Axes = ["x", "y", "z"], + atoms: AtomsArgument = None, + atoms_style: Sequence[AtomsStyleSpec] = [], + atoms_scale: float = 1., + atoms_colorscale: Optional[str] = None, + drawing_mode: Literal["scatter", "balls", None] = None, + bind_bonds_to_ats: bool = True, + points_per_bond: int = 20, + bonds_style: StyleSpec = {}, + bonds_scale: float = 1., + bonds_colorscale: Optional[str] = None, + show_atoms: bool = True, + show_bonds: bool = True, + show_cell: Literal["box", "axes", False] = "box", + cell_style: StyleSpec = {}, + nsc: Tuple[int, int, int] = (1, 1, 1), + atoms_ndim_scale: Tuple[float, float, float] = (16, 16, 1), + bonds_ndim_scale: Tuple[float, float, float] = (1, 1, 10), + dataaxis_1d: Optional[Union[np.ndarray, Callable]] = None, + arrows: Sequence[AtomArrowSpec] = (), + backend="plotly", +) -> Figure: + """Plots a geometry structure, with plentiful of customization options. 
+ Parameters - ------------- - geometry: Geometry, optional - A geometry object - geom_file: str, optional - A file name that can read a geometry - show_bonds: bool, optional - Show bonds between atoms. - bonds_style: dict, optional - Customize the style of the bonds by passing style specifications. - Currently, you can only pass one style specification. Styling bonds - individually is not supported yet, but it will be in the future. - Structure of the dict: { } - axes: optional - The axis along which you want to see the geometry. You - can provide as many axes as dimensions you want for your plot. - Note that the order is important and will result in setting the plot - axes diferently. For 2D and 1D representations, you can - pass an arbitrary direction as an axis (array of shape (3,)) - dataaxis_1d: array-like or function, optional - If you want a 1d representation, you can provide a data axis. - It determines the second coordinate of the atoms. - If it's a function, it will recieve the projected 1D coordinates and - needs to returns the coordinates for the other axis as - an array. If not provided, the other axis - will just be 0 for all points. - show_cell: optional - Specifies how the cell should be rendered. (False: not - rendered, 'axes': render axes only, 'box': render a bounding box) - nsc: array-like, optional - Make the geometry larger by tiling it along each lattice vector - atoms: dict, optional - The atoms that are going to be displayed in the plot. - This also has an impact on bonds (see the `bind_bonds_to_ats` and - `show_atoms` parameters). If set to None, all atoms are - displayed Structure of the dict: { 'index': Structure of - the dict: { 'in': } 'fx': 'fy': - 'fz': 'x': 'y': 'z': 'Z': - 'neighbours': Structure of the dict: { 'range': - 'R': 'neigh_tag': } 'tag': 'seq': } - atoms_style: array-like of dict, optional - Customize the style of the atoms by passing style specifications. 
- Each style specification can have an "atoms" key to select the atoms - for which that style should be used. If an atom fits into - more than one selector, the last specification is used. - Each item is a dict. Structure of the dict: { 'atoms': - Structure of the dict: { 'index': Structure of the dict: { - 'in': } 'fx': 'fy': 'fz': 'x': - 'y': 'z': 'Z': 'neighbours': Structure - of the dict: { 'range': 'R': 'neigh_tag': - } 'tag': 'seq': } 'color': 'size': - 'opacity': 'vertices': In a 3D representation, the number of - vertices that each atom sphere is composed of. } - arrows: array-like of dict, optional - Add arrows centered at the atoms to display some vector property. - You can add as many arrows as you want, each with different styles. - Each item is a dict. Structure of the dict: { 'atoms': - Structure of the dict: { 'index': Structure of the dict: { - 'in': } 'fx': 'fy': 'fz': 'x': - 'y': 'z': 'Z': 'neighbours': Structure - of the dict: { 'range': 'R': 'neigh_tag': - } 'tag': 'seq': } 'data': 'scale': - 'color': 'width': 'name': - 'arrowhead_scale': 'arrowhead_angle': } - atoms_scale: float, optional - A scaling factor for atom sizes. This is a very quick way to rescale. - atoms_colorscale: str, optional - The colorscale to use to map values to colors for the atoms. - Only used if atoms_color is provided and is an array of values. - bind_bonds_to_ats: bool, optional - whether only the bonds that belong to an atom that is present should - be displayed. If False, all bonds are displayed - regardless of the `atoms` parameter - show_atoms: bool, optional - If set to False, it will not display atoms. Basically - this is a shortcut for ``atoms = [], bind_bonds_to_ats=False``. - Therefore, it will override these two parameters. - points_per_bond: int, optional - Number of points that fill a bond in 2D in case each bond has a - different color or different size. More points will make it look - more like a line but will slow plot rendering down. 
- cell_style: dict, optional - The style of the unit cell lines Structure of the dict: { - 'color': 'width': 'opacity': } - root_fdf: fdfSileSiesta, optional - Path to the fdf file that is the 'parent' of the results. - results_path: str, optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - entry_points_order: array-like, optional - Order with which entry points will be attempted. - backend: optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. + ---------- + geometry: + The geometry to plot. + axes: + The axes to project the geometry to. + atoms: + The atoms to plot. If None, all atoms are plotted. + atoms_style: + List of style specifications for the atoms. See the showcase notebooks for examples. + atoms_scale: + Scaling factor for the size of all atoms. + atoms_colorscale: + Colorscale to use for the atoms in case the color attribute is an array of values. + If None, the default colorscale is used for each backend. + drawing_mode: + The method used to draw the atoms. + bind_bonds_to_ats: + Whether to display only bonds between atoms that are being displayed. + points_per_bond: + When the points are drawn using points instead of lines (e.g. in some frameworks + to draw multicolor bonds), the number of points used per bond. + bonds_style: + Style specification for the bonds. See the showcase notebooks for examples. + bonds_scale: + Scaling factor for the width of all bonds. + bonds_colorscale: + Colorscale to use for the bonds in case the color attribute is an array of values. + If None, the default colorscale is used for each backend. + show_atoms: + Whether to display the atoms. + show_bonds: + Whether to display the bonds. + show_cell: + Mode to display the cell. If False, the cell is not displayed. + cell_style: + Style specification for the cell. See the showcase notebooks for examples. 
+ nsc: + Number of unit cells to display in each direction. + atoms_ndim_scale: + Scaling factor for the size of the atoms for different dimensionalities (1D, 2D, 3D). + bonds_ndim_scale: + Scaling factor for the width of the bonds for different dimensionalities (1D, 2D, 3D). + dataaxis_1d: + Only meaningful for 1D plots. The data to plot on the Y axis. + arrows: + List of arrow specifications to display. See the showcase notebooks for examples. + backend: + The backend to use to generate the figure. """ - _plot_type = "Geometry" - - _param_groups = ( - { - "key": "cell", - "name": "Cell display", - "icon": "check_box_outline_blank", - "description": "These are all inputs related to the geometry's cell." - }, - - { - "key": "atoms", - "name": "Atoms display", - "icon": "album", - "description": "Inputs related to which and how atoms are displayed." - }, - - { - "key": "bonds", - "name": "Bonds display", - "icon": "power_input", - "description": "Inputs related to which and how bonds are displayed." - }, - + # INPUTS ARE NOT GETTING PARSED BECAUSE WORKFLOWS RUN GET ON FINAL NODE + # SO PARSING IS DELEGATED TO NODES. + axes = sanitize_axes(axes) + sanitized_atoms = sanitize_atoms(geometry, atoms=atoms) + ndim = get_ndim(axes) + z = get_z(ndim) + + # Atoms and bonds are processed in parallel paths, which means that one is able + # to update without requiring the other. This means: 1) Faster updates if only one + # of them needs to update; 2) It should be possible to run each path in a different + # thread/process, potentially increasing speed. 
+ parsed_atom_style = parse_atoms_style(geometry, atoms_style=atoms_style) + atoms_dataset = add_xyz_to_dataset(parsed_atom_style) + atoms_filter = switch(show_atoms, sanitized_atoms, []) + filtered_atoms = select(atoms_dataset, "atom", atoms_filter) + tiled_atoms = tile_data_sc(filtered_atoms, nsc=nsc) + sc_atoms = stack_sc_data(tiled_atoms, newname="sc_atom", dims=["atom"]) + projected_atoms = project_to_axes(sc_atoms, axes=axes, sort_by_depth=True, dataaxis_1d=dataaxis_1d) + + atoms_scale = _sanitize_scale(atoms_scale, ndim, atoms_ndim_scale) + final_atoms = scale_variable(projected_atoms, "size", scale=atoms_scale) + atom_mode = _get_atom_mode(drawing_mode, ndim) + atom_plottings = draw_xarray_xy( + data=final_atoms, x="x", y="y", z=z, width="size", what=atom_mode, colorscale=atoms_colorscale, + set_axequal=True, name="Atoms" ) - - _parameters = ( - - PlotableInput(key='geometry', name="Geometry", - dtype=Geometry, - default=None, - group="dataread", - help="A geometry object", - ), - - FilePathInput(key="geom_file", name="Geometry file", - group="dataread", - default=None, - help="A file name that can read a geometry", - ), - - BoolInput(key='show_bonds', name='Show bonds', - default=True, - group="bonds", - help="Show bonds between atoms." - ), - - DictInput(key="bonds_style", name="Bonds style", - default={}, - group="bonds", - help = """Customize the style of the bonds by passing style specifications. - Currently, you can only pass one style specification. Styling bonds - individually is not supported yet, but it will be in the future. - """, - queryForm = [ - - ColorInput(key="color", name="Color", default="#cccccc"), - - FloatInput(key="width", name="Width", default=None), - - FloatInput(key="opacity", name="Opacity", - default=1, - params={"min": 0, "max": 1}, - ), - - ] - ), - - GeomAxisSelect( - key="axes", name="Axes to display", - default=["x", "y", "z"], - group="cell", - help="""The axis along which you want to see the geometry. 
- You can provide as many axes as dimensions you want for your plot. - Note that the order is important and will result in setting the plot axes diferently. - For 2D and 1D representations, you can pass an arbitrary direction as an axis (array of shape (3,))""" - ), - - ProgramaticInput( - key="dataaxis_1d", name="1d data axis", - default=None, - dtype="array-like or function", - help="""If you want a 1d representation, you can provide a data axis. - It determines the second coordinate of the atoms. - - If it's a function, it will recieve the projected 1D coordinates and needs to returns - the coordinates for the other axis as an array. - - If not provided, the other axis will just be 0 for all points. - """ - ), - - OptionsInput(key="show_cell", name="Cell display", - default="box", - params={ - 'options': [ - {'label': 'False', 'value': False}, - {'label': 'axes', 'value': 'axes'}, - {'label': 'box', 'value': 'box'} - ], - 'isMulti': False, - 'isSearchable': True, - 'isClearable': False - }, - group="cell", - help="""Specifies how the cell should be rendered. - (False: not rendered, 'axes': render axes only, 'box': render a bounding box)""" - ), - - Array1DInput( - key="nsc", name="Supercell", - default=[1, 1, 1], - params={ - 'inputType': 'number', - 'shape': (3,), - 'extendable': False, - }, - group="cell", - help="""Make the geometry larger by tiling it along each lattice vector""" - ), - - AtomSelect(key="atoms", name="Atoms to display", - default=None, - params={ - "options": [], - "isSearchable": True, - "isMulti": True, - "isClearable": True - }, - group="atoms", - help="""The atoms that are going to be displayed in the plot. - This also has an impact on bonds (see the `bind_bonds_to_ats` and `show_atoms` parameters). - If set to None, all atoms are displayed""" - ), - - QueriesInput(key="atoms_style", name="Atoms style", - default=[], - group="atoms", - help = """Customize the style of the atoms by passing style specifications. 
- Each style specification can have an "atoms" key to select the atoms for which - that style should be used. If an atom fits into more than one selector, the last - specification is used. - """, - queryForm = [ - - AtomSelect(key="atoms", name="Atoms", default=None), - - ColorInput(key="color", name="Color", default=None), - - FloatInput(key="size", name="Size", default=None), - - FloatInput(key="opacity", name="Opacity", - default=1, - params={"min": 0, "max": 1}, - ), - - IntegerInput(key="vertices", name="Vertices", default=15, - help="""In a 3D representation, the number of vertices that each atom sphere is composed of."""), - - ] - ), - - QueriesInput(key="arrows", name="Arrows", - default=[], - group="atoms", - help = """Add arrows centered at the atoms to display some vector property. - You can add as many arrows as you want, each with different styles.""", - queryForm = [ - - AtomSelect(key="atoms", name="Atoms", default=None), - - Array1DInput(key="data", name="Data", default=None, params={"shape": (3,)}), - - FloatInput(key="scale", name="Scale", default=1), - - ColorInput(key="color", name="Color", default=None), - - FloatInput(key="width", name="Width", default=None), - - TextInput(key="name", name="Name", default=None), - - FloatInput(key="arrowhead_scale", name="Arrowhead scale", default=0.2), - - FloatInput(key="arrowhead_angle", name="Arrowhead angle", default=20), - ] - ), - - FloatInput(key="atoms_scale", name="Atoms scale", - default=1., - group="atoms", - help="A scaling factor for atom sizes. This is a very quick way to rescale." - ), - - TextInput(key="atoms_colorscale", name="Atoms colorscale", - group="atoms", - default="viridis", - help="""The colorscale to use to map values to colors for the atoms. 
- Only used if atoms_color is provided and is an array of values.""" - ), - - BoolInput(key="bind_bonds_to_ats", name="Bind bonds to atoms", - default=True, - group="bonds", - help="""whether only the bonds that belong to an atom that is present should be displayed. - If False, all bonds are displayed regardless of the `atoms` parameter""" - ), - - BoolInput(key="show_atoms", name="Show atoms", - default=True, - group="atoms", - help="""If set to False, it will not display atoms. - Basically this is a shortcut for ``atoms = [], bind_bonds_to_ats=False``. - Therefore, it will override these two parameters.""" - ), - - IntegerInput( - key="points_per_bond", name="Points per bond", - group="bonds", - default=10, - help="Number of points that fill a bond in 2D in case each bond has a different color or different size.
More points will make it look more like a line but will slow plot rendering down." - ), - - DictInput(key="cell_style", name="Cell style", - default={"color": "green"}, - group="cell", - help="""The style of the unit cell lines""", - fields=[ - ColorInput(key="color", name="Color", default="green"), - - FloatInput(key="width", name="Width", default=None), - - FloatInput(key="opacity", name="Opacity", default=1), - ] - ), - + + # Here we start to process bonds + bonds = find_all_bonds(geometry) + show_bonds = matches(ndim, 1, False, show_bonds) + styled_bonds = style_bonds(bonds, bonds_style) + bonds_dataset = add_xyz_to_bonds_dataset(styled_bonds) + bonds_filter = sanitize_bonds_selection(bonds_dataset, sanitized_atoms, bind_bonds_to_ats, show_bonds) + filtered_bonds = select(bonds_dataset, "bond_index", bonds_filter) + tiled_bonds = tile_data_sc(filtered_bonds, nsc=nsc) + + projected_bonds = project_to_axes(tiled_bonds, axes=axes) + bond_lines = bonds_to_lines(projected_bonds, points_per_bond=points_per_bond) + + bonds_scale = _sanitize_scale(bonds_scale, ndim, bonds_ndim_scale) + final_bonds = scale_variable(bond_lines, "width", scale=bonds_scale) + bond_plottings = draw_xarray_xy(data=final_bonds, x="x", y="y", z=z, set_axequal=True, name="Bonds", colorscale=bonds_colorscale) + + # And now the cell + show_cell = matches(ndim, 1, False, show_cell) + cell_plottings = cell_plot_actions( + cell=geometry, show_cell=show_cell, cell_style=cell_style, + axes=axes, dataaxis_1d=dataaxis_1d ) + + # And the arrows + arrow_data = sanitize_arrows(geometry, arrows, atoms=sanitized_atoms, ndim=ndim, axes=axes) + arrow_plottings = _get_arrow_plottings(projected_atoms, arrow_data, nsc=nsc) - # Colors of the atoms following CPK rules - _atoms_colors = { - "H": "#cccccc", # Should be white but the default background is white - "O": "red", - "Cl": "green", - "N": "blue", - "C": "grey", - "S": "yellow", - "P": "orange", - "Au": "gold", - "else": "pink" - } - - _pt = PeriodicTable() - 
- _update_methods = { - "read_data": [], - "set_data": ["_prepare1D", "_prepare2D", "_prepare3D"], - "get_figure": [] - } - - @entry_point('geometry', 0) - def _read_nosource(self, geometry): - """ - Reads directly from a sisl geometry. - """ - self.geometry = geometry or getattr(self, "geometry", None) - - if self.geometry is None: - raise ValueError("No geometry has been provided.") - - @entry_point('geometry file', 1) - def _read_siesta_output(self, geom_file, root_fdf): - """ - Reads from a sile that contains a geometry using the `read_geometry` method. - """ - geom_file = geom_file or root_fdf - - self.geometry = self.get_sile(geom_file).read_geometry() - - def _after_read(self, show_bonds, nsc): - # Tile the geometry. It shouldn't be done here, since we will need to calculate the bonds for - # the whole supercell. FIND A SMARTER WAY!! - self._tiled_geometry = self.geometry - for ax, reps in enumerate(nsc): - self._tiled_geometry = self._tiled_geometry.tile(reps, ax) - - if show_bonds: - self.bonds = self.find_all_bonds(self._tiled_geometry) - - self.get_param("atoms").update_options(self.geometry) - self.get_param("atoms_style").get_param("atoms").update_options(self.geometry) - self.get_param("arrows").get_param("atoms").update_options(self.geometry) - - def _parse_atoms_style(self, atoms_style, ndim): - """Parses the `atoms_style` setting to a dictionary of style specifications. - - Parameters - ----------- - atoms_style: - the value of the atoms_style setting. - ndim: int - the number of dimensions of the plot, only used for the default atom sizes. 
- """ - - # Set the radius scale for the different representations (1D and 2D measure size in pixels, - # while in 3D this is a real distance) - radius_scale = [16, 16, 1][ndim-1] - - # Add the default styles first - atoms_style = [ - { - "color": [self.atom_color(atom.Z) for atom in self.geometry.atoms], - "size": [self._pt.radius(abs(atom.Z))*radius_scale for atom in self.geometry.atoms], - "opacity": [0.4 if isinstance(atom, AtomGhost) else 1 for atom in self.geometry.atoms], - "vertices": 15, - }, - *atoms_style - ] - - def _tile_if_needed(atoms, spec): - """Function that tiles an array style specification. - - It does so if the specification needs to be applied to more atoms - than items are in the array.""" - if isinstance(spec, (tuple, list, np.ndarray)): - n_ats = len(atoms) - n_spec = len(spec) - if n_ats != n_spec and n_ats % n_spec == 0: - spec = np.tile(spec, n_ats // n_spec) - return spec - - # Initialize the styles. - parsed_atoms_style = { - "color": np.empty((self.geometry.na, ), dtype=object), - "size": np.empty((self.geometry.na, ), dtype=float), - "vertices": np.empty((self.geometry.na, ), dtype=int), - "opacity": np.empty((self.geometry.na), dtype=float), - } - - # Go specification by specification and apply the styles - # to the corresponding atoms. 
- for style_spec in atoms_style: - atoms = self.geometry._sanitize_atoms(style_spec.get("atoms")) - for key in parsed_atoms_style: - if style_spec.get(key) is not None: - parsed_atoms_style[key][atoms] = _tile_if_needed(atoms, style_spec[key]) - - return parsed_atoms_style - - def _parse_arrows(self, arrows, atoms, ndim, axes, nsc): - arrows_param = self.get_param("arrows") - - def _sanitize_spec(arrow_spec): - arrow_spec = arrows_param.complete_query(arrow_spec) - - arrow_spec["atoms"] = np.atleast_1d(self.geometry._sanitize_atoms(arrow_spec["atoms"])) - arrow_atoms = arrow_spec["atoms"] - - not_displayed = set(arrow_atoms) - set(atoms) - if not_displayed: - warn(f"Arrow data for atoms {not_displayed} will not be displayed because these atoms are not displayed.") - if set(atoms) == set(atoms) - set(arrow_atoms): - # Then it makes no sense to store arrows, as nothing will be drawn - return None - - arrow_data = np.full((self.geometry.na, ndim), np.nan, dtype=np.float64) - provided_data = np.array(arrow_spec["data"]) - - # Get the projected directions if we are not in 3D. 
- if ndim == 1: - provided_data = self._projected_1Dcoords(self.geometry, provided_data, axis=axes[0]) - provided_data = np.expand_dims(provided_data, axis=-1) - elif ndim == 2: - provided_data = self._projected_2Dcoords(self.geometry, provided_data, xaxis=axes[0], yaxis=axes[1]) - - arrow_data[arrow_atoms] = provided_data - arrow_spec["data"] = arrow_data[atoms] - - arrow_spec["data"] = self._tile_atomic_data(arrow_spec["data"]) - - return arrow_spec - - arrows = [_sanitize_spec(arrow_spec) for arrow_spec in arrows] - - return [arrow_spec for arrow_spec in arrows if arrow_spec is not None] - - def _tile_atomic_data(self, data): - tiles = np.ones(np.array(data).ndim, dtype=int) - tiles[0] = self._tiled_geometry.na // self.geometry.na - return np.tile(data, tiles) - - def _tiled_atoms(self, atoms): - if len(atoms) == 0: - return atoms - - n_tiles = self._tiled_geometry.na // self.geometry.na - - tiled_atoms = np.tile(atoms, n_tiles).reshape(-1, atoms.shape[0]) - - tiled_atoms += np.linspace(0, self.geometry.na*(n_tiles - 1), n_tiles, dtype=int).reshape(-1, 1) - return tiled_atoms.ravel() - - def _tiled_coords(self, atoms): - return self._tiled_geometry[self._tiled_atoms(atoms)] - - def _set_data(self, axes, - atoms, atoms_style, atoms_scale, atoms_colorscale, show_atoms, bind_bonds_to_ats, bonds_style, - arrows, dataaxis_1d, show_cell, cell_style, nsc, kwargs3d={}, kwargs2d={}, kwargs1d={}): - self._ndim = len(axes) - - if show_atoms == False: - atoms = [] - bind_bonds_to_ats = False - - atoms = np.atleast_1d(self.geometry._sanitize_atoms(atoms)) + all_actions = plot_actions.combined(bond_plottings, atom_plottings, cell_plottings, arrow_plottings, composite_method=None) + + return get_figure(backend=backend, plot_actions=all_actions) - arrows = self._parse_arrows(arrows, atoms, self._ndim, axes, nsc) - - atoms_styles = self._parse_atoms_style(atoms_style, self._ndim) - atoms_styles["colorscale"] = atoms_colorscale - - atoms_kwargs = {"atoms": atoms, "atoms_styles": 
atoms_styles, "atoms_scale": atoms_scale} - - if self._ndim == 3: - xaxis, yaxis, zaxis = axes - backend_info = self._prepare3D( - **atoms_kwargs, bonds_styles=bonds_style, - bind_bonds_to_ats=bind_bonds_to_ats, **kwargs3d - ) - elif self._ndim == 2: - xaxis, yaxis = axes - backend_info = self._prepare2D( - xaxis=xaxis, yaxis=yaxis, bonds_styles=bonds_style, **atoms_kwargs, - bind_bonds_to_ats=bind_bonds_to_ats, nsc=nsc, **kwargs2d - ) - elif self._ndim == 1: - xaxis = axes[0] - yaxis = dataaxis_1d - backend_info = self._prepare1D(**atoms_kwargs, coords_axis=xaxis, data_axis=yaxis, nsc=nsc, **kwargs1d) - - # Define the axes titles - backend_info["axes_titles"] = { - "xaxis": self._get_ax_title(xaxis), - "yaxis": self._get_ax_title(yaxis), - } - if self._ndim == 3: - backend_info["axes_titles"]["zaxis"] = self._get_ax_title(zaxis) - - backend_info["ndim"] = self._ndim - backend_info["show_cell"] = show_cell - backend_info["arrows"] = arrows - - cell_style = self.get_param("cell_style").complete_dict(cell_style) - backend_info["cell_style"] = cell_style - - return backend_info - - @staticmethod - def _get_ax_title(ax): - """Generates the title for a given axis""" - if hasattr(ax, "__name__"): - title = ax.__name__ - elif isinstance(ax, np.ndarray) and ax.shape == (3,): - title = str(ax) - elif not isinstance(ax, str): - title = "" - elif re.match("[+-]?[xXyYzZ]", ax): - title = f'{ax.upper()} axis [Ang]' - elif re.match("[+-]?[aAbBcC]", ax): - title = f'{ax.upper()} lattice vector' - else: - title = ax +class GeometryPlot(Plot): - return title + function = staticmethod(geometry_plot) - # From here, we start to define all the helper methods: @property - def on_geom(self): - return BoundGeometry(self.geometry, self) - - @staticmethod - def _sphere(center=[0, 0, 0], r=1, vertices=10): - phi, theta = np.mgrid[0.0:np.pi: 1j*vertices, 0.0:2.0*np.pi: 1j*vertices] - - x = center[0] + r*np.sin(phi)*np.cos(theta) - y = center[1] + r*np.sin(phi)*np.sin(theta) - z = center[2] + 
r*np.cos(phi) - - return {'x': x, 'y': y, 'z': z} - - @classmethod - def atom_color(cls, atom): - - atom = Atom(atom) - - ghost = isinstance(atom, AtomGhost) - - color = cls._atoms_colors.get(atom.symbol, cls._atoms_colors["else"]) - - if ghost: - import matplotlib.colors - - color = (np.array(matplotlib.colors.to_rgb(color))*255).astype(int) - color = f'rgba({",".join(color.astype(str))}, 0.4)' - - return color - - @staticmethod - def find_all_bonds(geometry, tol=0.2): - """ - Finds all bonds present in a geometry. - - Parameters - ----------- - geometry: sisl.Geometry - the structure where the bonds should be found. - tol: float - the fraction that the distance between atoms is allowed to differ from - the "standard" in order to be considered a bond. - - Return - --------- - np.ndarray of shape (nbonds, 2) - each item of the array contains the 2 indices of the atoms that participate in the - bond. - """ - pt = PeriodicTable() - - bonds = [] - for at in geometry: - neighs = geometry.close(at, R=[0.1, 3])[-1] - - for neigh in neighs[neighs > at]: - summed_radius = pt.radius([abs(geometry.atoms[at].Z), abs(geometry.atoms[neigh % geometry.na].Z)]).sum() - bond_thresh = (1+tol) * summed_radius - if bond_thresh > fnorm(geometry[neigh] - geometry[at]): - bonds.append([at, neigh]) - - return np.array(bonds, dtype=int) - - @staticmethod - def _direction(ax, cell=None): - if isinstance(ax, (int, str)): - sign = 1 - # If the axis contains a -, we need to mirror the direction. - if isinstance(ax, str) and ax[0] == "-": - sign = -1 - ax = ax[1] - ax = sign * direction(ax, abc=cell, xyz=np.diag([1., 1., 1.])) - - return ax - - @classmethod - def _cross_product(cls, v1, v2, cell=None): - """An enhanced version of the cross product. - - It is an enhanced version because both bectors accept strings that represent - the cartesian axes or the lattice vectors (see `v1`, `v2` below). 
It has been built - so that cross product between lattice vectors (-){"a", "b", "c"} follows the same rules - as (-){"x", "y", "z"} - Parameters - ---------- - v1, v2: array-like of shape (3,) or (-){"x", "y", "z", "a", "b", "c"} - The vectors to take the cross product of. - cell: array-like of shape (3, 3) - The cell of the structure, only needed if lattice vectors {"a", "b", "c"} - are passed for `v1` and `v2`. - """ - # Make abc follow the same rules as xyz to find the orthogonal direction - # That is, a X b = c; -a X b = -c and so on. - if isinstance(v1, str) and isinstance(v2, str): - if re.match("([+-]?[abc]){2}", v1 + v2): - v1 = v1.replace("a", "x").replace("b", "y").replace("c", "z") - v2 = v2.replace("a", "x").replace("b", "y").replace("c", "z") - ort = cls._cross_product(v1, v2) - ort_ax = "abc"[np.where(ort != 0)[0][0]] - if ort.sum() == -1: - ort_ax = "-" + ort_ax - return cls._direction(ort_ax, cell) - - # If the vectors are not abc, we just need to take the cross product. - return np.cross(cls._direction(v1, cell), cls._direction(v2, cell)) - - @staticmethod - def _get_cell_corners(cell, unique=False): - """Gets the coordinates of a cell's corners. - - Parameters - ---------- - cell: np.ndarray of shape (3, 3) - the cell for which you want the corner's coordinates. - unique: bool, optional - if `False`, a full path to draw a cell is returned. - if `True`, only unique points are returned, in no particular order. - - Returns - --------- - np.ndarray of shape (x, 3) - where x is 16 if unique=False and 8 if unique=True. - """ - if unique: - verts = list(itertools.product([0, 1], [0, 1], [0, 1])) - else: - # Define the vertices of the cube. 
They follow an order so that we can - # draw a line that represents the cell's box - verts = [ - (0, 0, 0), (0, 1, 0), (1, 1, 0), (1, 1, 1), (0, 1, 1), (0, 1, 0), - (np.nan, np.nan, np.nan), - (0, 1, 1), (0, 0, 1), (0, 0, 0), (1, 0, 0), (1, 0, 1), (0, 0, 1), - (np.nan, np.nan, np.nan), - (1, 1, 0), (1, 0, 0), - (np.nan, np.nan, np.nan), - (1, 1, 1), (1, 0, 1) - ] - - verts = np.array(verts, dtype=np.float64) - - return verts.dot(cell) - - @classmethod - def _projected_1Dcoords(cls, geometry, xyz=None, axis="x"): - """ - Moves the 3D positions of the atoms to a 2D supspace. - - In this way, we can plot the structure from the "point of view" that we want. - - NOTE: If axis is one of {"a", "b", "c", "1", "2", "3"} the function doesn't - project the coordinates in the direction of the lattice vector. The fractional - coordinates, taking in consideration the three lattice vectors, are returned - instead. - - Parameters - ------------ - geometry: sisl.Geometry - the geometry for which you want the projected coords - xyz: array-like of shape (natoms, 3), optional - the 3D coordinates that we want to project. - otherwise they are taken from the geometry. - axis: {"x", "y", "z", "a", "b", "c", "1", "2", "3"} or array-like of shape 3, optional - the direction to be displayed along the X axis. - nsc: array-like of shape (3, ), optional - only used if `axis` is a lattice vector. It is used to rescale everything to the unit - cell lattice vectors, otherwise `GeometryPlot` doesn't play well with `GridPlot`. - - Returns - ---------- - np.ndarray of shape (natoms, ) - the 1D coordinates of the geometry, with all positions projected into the line - defined by axis. 
- """ - if xyz is None: - xyz = geometry.xyz - - if isinstance(axis, str) and axis in ("a", "b", "c", "0", "1", "2"): - return cls._projected_2Dcoords(geometry, xyz, xaxis=axis, yaxis="a" if axis == "c" else "c")[..., 0] - - # Get the direction that the axis represents - axis = cls._direction(axis, geometry.cell) - - return xyz.dot(axis/fnorm(axis)) / fnorm(axis) - - @classmethod - def _projected_2Dcoords(cls, geometry, xyz=None, xaxis="x", yaxis="y"): - """ - Moves the 3D positions of the atoms to a 2D supspace. - - In this way, we can plot the structure from the "point of view" that we want. - - NOTE: If xaxis/yaxis is one of {"a", "b", "c", "1", "2", "3"} the function doesn't - project the coordinates in the direction of the lattice vector. The fractional - coordinates, taking in consideration the three lattice vectors, are returned - instead. - - Parameters - ------------ - geometry: sisl.Geometry - the geometry for which you want the projected coords - xyz: array-like of shape (natoms, 3), optional - the 3D coordinates that we want to project. - otherwise they are taken from the geometry. - xaxis: {"x", "y", "z", "a", "b", "c"} or array-like of shape 3, optional - the direction to be displayed along the X axis. - yaxis: {"x", "y", "z", "a", "b", "c"} or array-like of shape 3, optional - the direction to be displayed along the X axis. - - Returns - ---------- - np.ndarray of shape (2, natoms) - the 2D coordinates of the geometry, with all positions projected into the plane - defined by xaxis and yaxis. 
- """ - if xyz is None: - xyz = geometry.xyz - - try: - all_lattice_vecs = len(set([xaxis, yaxis]).intersection(["a", "b", "c"])) == 2 - except Exception: - # If set fails it is because xaxis/yaxis is unhashable, which means it - # is a numpy array - all_lattice_vecs = False - - if all_lattice_vecs: - coord_indices = ["abc".index(ax) for ax in (xaxis, yaxis)] - - icell = cell_invert(geometry.cell) - else: - # Get the directions that these axes represent - xaxis = cls._direction(xaxis, geometry.cell) - yaxis = cls._direction(yaxis, geometry.cell) - - fake_cell = np.array([xaxis, yaxis, np.cross(xaxis, yaxis)], dtype=np.float64) - icell = cell_invert(fake_cell) - coord_indices = [0, 1] - - return np.dot(xyz, icell.T)[..., coord_indices] - - def _get_atoms_bonds(self, bonds, atoms): - """ - Gets the bonds where the given atoms are involved - """ - return [bond for bond in bonds if np.any([at in atoms for at in bond])] - - #--------------------------------------------------- - # 1D plotting - #--------------------------------------------------- - - def _prepare1D(self, atoms=None, atoms_styles=None, coords_axis="x", data_axis=None, wrap_atoms=None, atoms_scale=1., - nsc=(1, 1, 1), **kwargs): - """ - Returns a 1D representation of the plot's geometry. - - Parameters - ----------- - atoms: array-like of int, optional - the indices of the atoms that you want to plot - coords_axis: {0,1,2, "x", "y", "z", "a", "b", "c"} or array-like of shape 3, optional - the axis onto which all the atoms are projected. - data_axis: function or array-like, optional - determines the second coordinate of the atoms - - If it's a function, it will recieve the projected 1D coordinates and needs to returns - the coordinates for the other axis as an array. - - If not provided, the other axis will just be 0 for all points. - atoms_styles: dict, optional - dictionary containing all the style properties of the atoms, it should be build by `self._parse_atoms_style`. 
- atoms_colorscale: str or list, optional - the name of a plotly colorscale or a list of colors. - - Only used if atoms_color is an array of values. - wrap_atoms: function, optional - function that takes the 2D positions of the atoms in the plot and returns a tuple of (args, kwargs), - that are passed to self._atoms_scatter_trace2D. - If not provided self._default_wrap_atoms is used. - nsc: array-like of shape (3,), optional - the number of times the geometry has been tiled in each direction. This is only used to rescale - fractional coordinates. - **kwargs: - passed directly to the atoms scatter trace - """ - wrap_atoms = wrap_atoms or self._default_wrap_atoms1D - - x = self._projected_1Dcoords(self.geometry, self._tiled_coords(atoms), axis=coords_axis) - if data_axis is None: - def data_axis(x): - return np.zeros(x.shape[0]) - - data_axis_name = data_axis.__name__ if callable(data_axis) else 'Data axis' - if callable(data_axis): - data_axis = np.array(data_axis(x)) - - xy = np.array([x, data_axis]).T - - atoms_props = wrap_atoms(atoms, xy, atoms_styles) - atoms_props["size"] *= atoms_scale - - return { - "geometry": self.geometry, "xaxis": coords_axis, "yaxis": data_axis_name, "atoms_props": atoms_props, "bonds_props": [] - } - - def _default_wrap_atoms1D(self, ats, xy, atoms_styles): - - extra_kwargs = {} - - color = atoms_styles["color"][ats] - - try: - color.astype(float) - extra_kwargs["marker_colorscale"] = atoms_styles["colorscale"] - extra_kwargs["text"] = self._tile_atomic_data([f"Color: {c}" for c in color]) - except ValueError: - pass - - return { - "xy": xy, - "text": self._tile_atomic_data([f'{self.geometry[at]}
{at} ({self.geometry.atoms[at].tag})' for at in ats]), - "name": "Atoms", - **{k: self._tile_atomic_data(atoms_styles[k][ats]) for k in ("color", "size", "opacity")}, - **extra_kwargs - } - - #--------------------------------------------------- - # 2D plotting - #--------------------------------------------------- - - def _prepare2D(self, xaxis="x", yaxis="y", - atoms=None, atoms_styles=None, atoms_scale=1., - show_bonds=True, bonds_styles=None, bind_bonds_to_ats=True, - points_per_bond=5, wrap_atoms=None, wrap_bond=None, nsc=(1, 1, 1)): - """Returns a 2D representation of the plot's geometry. - - Parameters - ----------- - xaxis: {"x", "y", "z", "a", "b", "c"} or array-like of shape 3, optional - the direction to be displayed along the X axis. - yaxis: {"x", "y", "z", "a", "b", "c"} or array-like of shape 3, optional - the direction to be displayed along the X axis. - atoms: array-like of int, optional - the indices of the atoms that you want to plot - atoms_styles: dict, optional - dictionary containing all the style properties of the atoms, it should be build by `self._parse_atoms_style`. - atoms_scale: float, optional - a factor to multiply atom sizes by. - atoms_colorscale: str or list, optional - the name of a plotly colorscale or a list of colors. - Only used if atoms_color is an array of values. - show_bonds: boolean, optional - whether bonds should be plotted. - bind_bonds_to_ats: boolean, optional - whether only the bonds that belong to an atom that is present should be displayed. - If False, all bonds are displayed regardless of the `atom` parameter. - bonds_styles: dict, optional - dictionary containing all the style properties of the bonds. - points_per_bond: int, optional - If `bonds_together` is True and you provide a variable color or size (using `wrap_bonds`), this is - the number of points that are used for each bond. See `bonds_together` for more info. 
- wrap_atoms: function, optional - function that recieves the 2D coordinates and returns - the args (array-like) and kwargs (dict) that go into self._atoms_scatter_trace2D() - If not provided, self._default_wrap_atoms2D will be used. - wrap_atom: function, optional - function that recieves the index of an atom and returns - the args (array-like) and kwargs (dict) that go into self._atom_trace3D() - If not provided, self._default_wrap_atoms3D will be used. - wrap_bond: function, optional - function that recieves "a bond" (list of 2 atom indices) and its coordinates ((x1,y1), (x2, y2)). - It should return the args (array-like) and kwargs (dict) that go into `self._bond_trace2D()` - If not provided, self._default_wrap_bond2D will be used. - """ - wrap_atoms = wrap_atoms or self._default_wrap_atoms2D - wrap_bond = wrap_bond or self._default_wrap_bond2D - - # We need to sort the geometry according to depth, because when atoms are drawn they can be one - # on top of the other. The last atoms should be the ones on top. 
- if len(atoms) > 0: - depth_vector = self._cross_product(xaxis, yaxis, self.geometry.cell) - sorted_atoms = np.concatenate(self.geometry.sort(atoms=atoms, vector=depth_vector, ret_atoms=True)[1]) - else: - sorted_atoms = atoms - xy = self._projected_2Dcoords(self.geometry, self._tiled_coords(sorted_atoms), xaxis=xaxis, yaxis=yaxis) - - # Add atoms - atoms_props = wrap_atoms(sorted_atoms, xy, atoms_styles) - atoms_props["size"] *= atoms_scale - - # Add bonds - if show_bonds: - # Define the actual bonds that we are going to draw depending on which - # atoms are requested - bonds = self.bonds - if bind_bonds_to_ats: - bonds = self._get_atoms_bonds(bonds, self._tiled_atoms(atoms)) - - bonds_xyz = np.array([self._tiled_geometry[bond] for bond in bonds]) - if len(bonds_xyz) != 0: - xys = self._projected_2Dcoords(self.geometry, bonds_xyz, xaxis=xaxis, yaxis=yaxis) - - # Try to get the bonds colors (It might be that the user is not setting them) - bonds_props = [wrap_bond(bond, xy, bonds_styles) for bond, xy in zip(bonds, xys)] - else: - bonds_props = [] - else: - bonds_props = [] - - return { - "geometry": self.geometry, "xaxis": xaxis, "yaxis": yaxis, "atoms_props": atoms_props, - "bonds_props": bonds_props, "points_per_bond": points_per_bond, - } - - def _default_wrap_atoms2D(self, ats, xy, atoms_styles): - return self._default_wrap_atoms1D(ats, xy, atoms_styles) - - def _default_wrap_bond2D(self, bond, xys, bonds_styles): - return { - "xys": xys, - **bonds_styles, - } - - #--------------------------------------------------- - # 3D plotting - #--------------------------------------------------- - - def _prepare3D(self, wrap_atoms=None, wrap_bond=None, - atoms=None, atoms_styles=None, bind_bonds_to_ats=True, atoms_scale=1., - show_bonds=True, bonds_styles=None): - """Returns a 3D representation of the plot's geometry. 
- - Parameters - ----------- - wrap_atoms: function, optional - function that recieves the index of the atoms and returns - a dictionary with properties of the atoms. - If not provided, self._default_wrap_atoms3D will be used. - wrap_bond: function, optional - function that recieves "a bond" (list of 2 atom indices) and returns - the args (array-like) and kwargs (dict) that go into self._bond_trace3D() - If not provided, self._default_wrap_bond3D will be used. - show_cell: {'axes', 'box', False}, optional - defines how the unit cell is drawn - atoms: array-like of int, optional - the indices of the atoms that you want to plot - bind_bonds_to_ats: boolean, optional - whether only the bonds that belong to an atom that is present should be displayed. - If False, all bonds are displayed regardless of the `atom` parameter - atoms_vertices: int - the "definition" of the atom sphere, if not in cheap mode. The more vertices, the more defined the sphere - will be. However, it will also be more expensive to render. - atoms_styles: dict, optional - dictionary containing all the style properties of the atoms, it should be build by `self._parse_atoms_style`. 
- """ - wrap_atoms = wrap_atoms or self._default_wrap_atoms3D - wrap_bond = wrap_bond or self._default_wrap_bond3D - - try: - atoms_styles["color"] = np.array(values_to_colors(atoms_styles["color"], atoms_styles["colorscale"])) - except Exception: - pass - - atoms_props = wrap_atoms(atoms, atoms_styles) - atoms_props["size"] *= atoms_scale - - if show_bonds: - # Try to get the bonds colors (It might be that the user is not setting them) - bonds = self.bonds - if bind_bonds_to_ats: - bonds = self._get_atoms_bonds(bonds, self._tiled_atoms(atoms)) - bonds_props = [wrap_bond(bond, bonds_styles) for bond in bonds] - else: - bonds = [] - bonds_props = [] - - return {"geometry": self.geometry, "atoms_props": atoms_props, "bonds_props": bonds_props} + def geometry(self): + return self.nodes.inputs['geometry']._output + +_T = TypeVar("_T", list, tuple, dict) + +def _sites_specs_to_atoms_specs(sites_specs: _T) -> _T: + + if isinstance(sites_specs, dict): + if "sites" in sites_specs: + sites_specs = sites_specs.copy() + sites_specs['atoms'] = sites_specs.pop('sites') + return sites_specs + else: + return type(sites_specs)(_sites_specs_to_atoms_specs(style_spec) for style_spec in sites_specs) + +def sites_plot( + sites_obj: BrillouinZone, + axes: Axes = ["x", "y", "z"], + sites: AtomsArgument = None, + sites_style: Sequence[AtomsStyleSpec] = [], + sites_scale: float = 1., + sites_name: str = "Sites", + sites_colorscale: Optional[str] = None, + drawing_mode: Literal["scatter", "balls", "line", None] = None, + show_cell: Literal["box", "axes", False] = False, + cell_style: StyleSpec = {}, + nsc: Tuple[int, int, int] = (1, 1, 1), + sites_ndim_scale: Tuple[float, float, float] = (1, 1, 1), + dataaxis_1d: Optional[Union[np.ndarray, Callable]] = None, + arrows: Sequence[AtomArrowSpec] = (), + backend="plotly", +) -> Figure: + """Plots sites from an object that can be parsed into a geometry. 
+ + The only differences between this plot and a geometry plot is the naming of the inputs + and the fact that there are no options to plot bonds. + + Parameters + ---------- + sites_obj: + The object to be converted to sites. + axes: + The axes to project the sites to. + sites: + The sites to plot. If None, all sites are plotted. + sites_style: + List of style specifications for the sites. See the showcase notebooks for examples. + sites_scale: + Scaling factor for the size of all sites. + sites_name: + Name to give to the trace that draws the sites. + sites_colorscale: + Colorscale to use for the sites in case the color attribute is an array of values. + If None, the default colorscale is used for each backend. + drawing_mode: + The method used to draw the sites. + show_cell: + Mode to display the reciprocal cell. If False, the cell is not displayed. + cell_style: + Style specification for the reciprocal cell. See the showcase notebooks for examples. + nsc: + Number of unit cells to display in each direction. + sites_ndim_scale: + Scaling factor for the size of the sites for different dimensionalities (1D, 2D, 3D). + dataaxis_1d: + Only meaningful for 1D plots. The data to plot on the Y axis. + arrows: + List of arrow specifications to display. See the showcase notebooks for examples. + backend: + The backend to use to generate the figure. + """ - def _default_wrap_atoms3D(self, ats, atoms_styles): + # INPUTS ARE NOT GETTING PARSED BECAUSE WORKFLOWS RUN GET ON FINAL NODE + # SO PARSING IS DELEGATED TO NODES. 
+ axes = sanitize_axes(axes) + fake_geometry = sites_obj_to_geometry(sites_obj) + sanitized_sites = sanitize_atoms(fake_geometry, atoms=sites) + ndim = get_ndim(axes) + z = get_z(ndim) + + # Process sites + atoms_style = _sites_specs_to_atoms_specs(sites_style) + parsed_sites_style = parse_atoms_style(fake_geometry, atoms_style=atoms_style) + sites_dataset = add_xyz_to_dataset(parsed_sites_style) + filtered_sites = select(sites_dataset, "atom", sanitized_sites) + tiled_sites = tile_data_sc(filtered_sites, nsc=nsc) + sc_sites = stack_sc_data(tiled_sites, newname="sc_atom", dims=["atom"]) + sites_units = get_sites_units(sites_obj) + projected_sites = project_to_axes(sc_sites, axes=axes, sort_by_depth=True, dataaxis_1d=dataaxis_1d, cartesian_units=sites_units) + + sites_scale = _sanitize_scale(sites_scale, ndim, sites_ndim_scale) + final_sites = scale_variable(projected_sites, "size", scale=sites_scale) + sites_mode = _get_atom_mode(drawing_mode, ndim) + site_plottings = draw_xarray_xy( + data=final_sites, x="x", y="y", z=z, width="size", what=sites_mode, colorscale=sites_colorscale, + set_axequal=True, name=sites_name, + ) + + # And now the cell + show_cell = matches(ndim, 1, False, show_cell) + cell_plottings = cell_plot_actions( + cell=fake_geometry, show_cell=show_cell, cell_style=cell_style, + axes=axes, dataaxis_1d=dataaxis_1d + ) + + # And the arrows + atom_arrows = _sites_specs_to_atoms_specs(arrows) + arrow_data = sanitize_arrows(fake_geometry, atom_arrows, atoms=sanitized_sites, ndim=ndim, axes=axes) + arrow_plottings = _get_arrow_plottings(projected_sites, arrow_data, nsc=nsc) - return { - "xyz": self._tiled_coords(ats), - "name": self._tile_atomic_data([f'{at} ({self.geometry.atoms[at].tag})' for at in ats]), - **{k: self._tile_atomic_data(atoms_styles[k][ats]) for k in ("color", "size", "vertices", "opacity")} - } + all_actions = plot_actions.combined(site_plottings, cell_plottings, arrow_plottings, composite_method=None) + + return 
get_figure(backend=backend, plot_actions=all_actions) - def _default_wrap_bond3D(self, bond, bonds_styles): +class SitesPlot(Plot): - return { - "xyz1": self._tiled_geometry[bond[0]], - "xyz2": self._tiled_geometry[bond[1]], - #"r": 15, - **bonds_styles, - } + function = staticmethod(sites_plot) diff --git a/src/sisl/viz/plots/grid.py b/src/sisl/viz/plots/grid.py index ed18fd6bc1..c522eb65f2 100644 --- a/src/sisl/viz/plots/grid.py +++ b/src/sisl/viz/plots/grid.py @@ -1,1600 +1,563 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from collections import defaultdict - -import numpy as np -from scipy.ndimage import affine_transform - -import sisl -from sisl import _array as _a -from sisl._lattice import cell_invert -from sisl.messages import warn -from sisl.viz.plots.geometry import GeometryPlot - -from ..input_fields import ( - Array1DInput, - BoolInput, - ColorInput, - CreatableOptionsInput, - FloatInput, - GeomAxisSelect, - IntegerInput, - OptionsInput, - PlotableInput, - ProgramaticInput, - QueriesInput, - RangeInput, - RangeSliderInput, - SileInput, - SislObjectInput, - SpinSelect, - TextInput, +from typing import Callable, ChainMap, Literal, Optional, Sequence, Tuple, Union + +from sisl.geometry import Geometry +from sisl.grid import Grid + +from ..data import EigenstateData +from ..figure import Figure, get_figure +from ..plot import Plot +from ..plotters.cell import cell_plot_actions +from ..plotters.grid import draw_grid +from ..plotters.plot_actions import combined +from ..processors.axes import sanitize_axes +from ..processors.eigenstate import ( + eigenstate_geometry, + get_eigenstate, + get_grid_nsc, + project_wavefunction, + tile_if_k, ) -from ..plot import Plot, entry_point +from ..processors.grid import ( + apply_transforms, + get_grid_axes, + get_grid_representation, + grid_geometry, + 
grid_to_dataarray, + interpolate_grid, + orthogonalize_grid_if_needed, + reduce_grid, + sub_grid, + tile_grid, +) +from ..types import Axes +from .geometry import geometry_plot + + +def _get_structure_plottings(plot_geom, geometry, axes, nsc, geom_kwargs={},): + if plot_geom: + geom_kwargs = ChainMap(geom_kwargs, {"axes": axes, "geometry": geometry, "nsc": nsc, "show_cell": False}) + plot_actions = geometry_plot(**geom_kwargs).plot_actions + else: + plot_actions = [] + + return plot_actions + +def grid_plot( + grid: Optional[Grid] = None, + axes: Axes = ["z"], + represent: Literal["real", "imag", "mod", "phase", "deg_phase", "rad_phase"] = "real", + transforms: Sequence[Union[str, Callable]] = (), + reduce_method: Literal["average", "sum"] = "average", + boundary_mode: str = "grid-wrap", + nsc: Tuple[int, int, int] = (1, 1, 1), + interp: Tuple[int, int, int] = (1, 1, 1), + isos: Sequence[dict] = [], + smooth: bool = False, + colorscale: Optional[str] = None, + crange: Optional[Tuple[float, float]] = None, + cmid: Optional[float] = None, + show_cell: Literal["box", "axes", False] = "box", + cell_style: dict = {}, + x_range: Optional[Sequence[float]] = None, + y_range: Optional[Sequence[float]] = None, + z_range: Optional[Sequence[float]] = None, + plot_geom: bool = False, + geom_kwargs: dict = {}, + backend: str = "plotly" +) -> Figure: + """Plots a grid, with plentiful of customization options. + + Parameters + ---------- + grid: + The grid to plot. + axes: + The axes to project the grid to. + represent: + The representation of the grid to plot. + transforms: + List of transforms to apply to the grid before plotting. + reduce_method: + The method used to reduce the grid axes that are not displayed. + boundary_mode: + The method used to deal with the boundary conditions. + Only used if the grid is to be orthogonalized. + See scipy docs for more info on the possible values. + nsc: + The number of unit cells to display in each direction. 
+ interp: + The interpolation factor to use for each axis to make the grid smoother. + isos: + List of isosurfaces or isocontours to plot. See the showcase notebooks for examples. + smooth: + Whether to ask the plotting backend to make an attempt at smoothing the grid display. + colorscale: + Colorscale to use for the grid display in the 2D representation. + If None, the default colorscale is used for each backend. + crange: + Min and max values for the colorscale. + cmid: + The value at which the colorscale is centered. + show_cell: + Method used to display the unit cell. If False, the cell is not displayed. + cell_style: + Style specification for the cell. See the showcase notebooks for examples. + x_range: + The range of the x axis to take into account. + Even if the X axis is not displayed! This is important because the reducing + operation will only be applied on this range. + y_range: + The range of the y axis to take into account. + Even if the Y axis is not displayed! This is important because the reducing + operation will only be applied on this range. + z_range: + The range of the z axis to take into account. + Even if the Z axis is not displayed! This is important because the reducing + operation will only be applied on this range. + plot_geom: + Whether to plot the associated geometry (if any). + geom_kwargs: + Keyword arguments to pass to the geometry plot of the associated geometry. + backend: + The backend to use to generate the figure. + + See also + ---------- + scipy.ndimage.affine_transform : method used to orthogonalize the grid if needed. + """ + + axes = sanitize_axes(axes) -class GridPlot(Plot): - """ - Versatile visualization tool for any kind of grid. + geometry = grid_geometry(grid, geometry=None) - Parameters - ------------ - grid: Grid, optional - A sisl.Grid object. If provided, grid_file is ignored. 
- grid_file: cubeSile or rhoSileSiesta or ldosSileSiesta or rhoinitSileSiesta or rhoxcSileSiesta or drhoSileSiesta or baderSileSiesta or iorhoSileSiesta or totalrhoSileSiesta or stsSileSiesta or stmldosSileSiesta or hartreeSileSiesta or neutralatomhartreeSileSiesta or totalhartreeSileSiesta or gridncSileSiesta or ncSileSiesta or fdfSileSiesta or tsvncSileSiesta or chgSileVASP or locpotSileVASP, optional - A filename that can be return a Grid through `read_grid`. - represent: optional - The representation of the grid that should be displayed - transforms: optional - Transformations to apply to the whole grid. It can be a - function, or a string that represents the path to a - function (e.g. "scipy.exp"). If a string that is a single - word is provided, numpy will be assumed to be the module (e.g. - "square" will be converted into "np.square"). Note that - transformations will be applied in the order provided. Some - transforms might not be necessarily commutable (e.g. "abs" and - "cos"). - axes: optional - The axis along you want to see the grid, it will be reduced along the - other ones, according to the the `reduce_method` setting. - zsmooth: optional - Parameter that smoothens how data looks in a heatmap. - 'best' interpolates data, 'fast' interpolates pixels, 'False' - displays the data as is. - interp: array-like, optional - Interpolation factors to make the grid finer on each axis.See the - zsmooth setting for faster smoothing of 2D heatmap. - transform_bc: optional - The boundary conditions when a cell transform is applied to the grid. - Cell transforms are only applied when the grid's cell - doesn't follow the cartesian coordinates and the requested display is - 2D or 1D. - nsc: array-like, optional - Number of times the grid should be repeated - offset: array-like, optional - The offset of the grid along each axis. This is important if you are - planning to match this grid with other geometry related plots. 
- trace_name: str, optional - The name that the trace will show in the legend. Good when merging - with other plots to be able to toggle the trace in the legend - x_range: array-like of shape (2,), optional - Range where the X is displayed. Should be inside the unit cell, - otherwise it will fail. - y_range: array-like of shape (2,), optional - Range where the Y is displayed. Should be inside the unit cell, - otherwise it will fail. - z_range: array-like of shape (2,), optional - Range where the Z is displayed. Should be inside the unit cell, - otherwise it will fail. - crange: array-like of shape (2,), optional - The range of values that the colorbar must enclose. This controls - saturation and hides below threshold values. - cmid: int, optional - The value to set at the center of the colorbar. If not provided, the - color range is used - colorscale: str, optional - A valid plotly colorscale. See https://plotly.com/python/colorscales/ - reduce_method: optional - The method used to reduce the dimensions that will not be displayed - in the plot. - isos: array-like of dict, optional - The isovalues that you want to represent. The way they - will be represented is of course dependant on the type of - representation: - 2D representations: A contour (i.e. - a line) - 3D representations: A surface - Each item is a dict. Structure of the dict: { 'name': The - name of the iso query. Note that you can use $isoval$ as a template - to indicate where the isoval should go. 'val': The iso value. - If not provided, it will be infered from `frac` 'frac': If - val is not provided, this is used to calculate where the isosurface - should be drawn. It calculates them from the - minimum and maximum values of the grid like so: - If iso_frac = 0.3: (min_value----- - ISOVALUE(30%)-----------max_value) Therefore, it - should be a number between 0 and 1. 
- 'step_size': The step size to use to calculate the isosurface in case - it's a 3D representation A bigger step-size can - speed up the process dramatically, specially the rendering part - and the resolution may still be more than satisfactory (try to use - step_size=2). For very big grids your computer - may not even be able to render very fine surfaces, so it's worth - keeping this setting in mind. 'color': - The color of the surface/contour. 'opacity': Opacity of the - surface/contour. Between 0 (transparent) and 1 (opaque). } - plot_geom: bool, optional - If True the geometry associated to the grid will also be plotted - geom_kwargs: dict, optional - Extra arguments that are passed to geom.plot() if plot_geom is set to - True - root_fdf: fdfSileSiesta, optional - Path to the fdf file that is the 'parent' of the results. - results_path: str, optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - entry_points_order: array-like, optional - Order with which entry points will be attempted. - backend: optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - """ + grid_repr = get_grid_representation(grid, represent=represent) - # Define all the class attributes - _plot_type = "Grid" - - _update_methods = { - "read_data": [], - "set_data": ["_prepare1D", "_prepare2D", "_prepare3D"], - "get_figure": [] - } - - _param_groups = ( - { - "key": "grid_shape", - "name": "Grid shape", - "icon": "image_aspect_ratio", - "description": "Settings related to the shape of the grid, including it's dimensionality and how it is reduced if needed." - }, - - { - "key": "grid_values", - "name": "Grid values", - "icon": "image", - "description": "Settings related to the values of the grid. 
They involve both how they are processed and displayed" - }, - ) + tiled_grid = tile_grid(grid_repr, nsc=nsc) + + ort_grid = orthogonalize_grid_if_needed(tiled_grid, axes=axes, mode=boundary_mode) + + grid_axes = get_grid_axes(ort_grid, axes=axes) + + transformed_grid = apply_transforms(ort_grid, transforms) + + subbed_grid = sub_grid(transformed_grid, x_range=x_range, y_range=y_range, z_range=z_range) - _parameters = ( - - PlotableInput( - key="grid", name="Grid", - dtype=sisl.Grid, - default=None, - group="dataread", - help="A sisl.Grid object. If provided, grid_file is ignored." - ), - - SileInput( - key="grid_file", name="Path to grid file", - required_attrs=["read_grid"], - default=None, - params={ - "placeholder": "Write the path to your grid file here..." - }, - group="dataread", - help="A filename that can be return a Grid through `read_grid`." - ), - - OptionsInput( - key="represent", name="Representation of the grid", - default="real", - params={ - 'options': [ - {'label': 'Real part', 'value': "real"}, - {'label': 'Imaginary part', 'value': 'imag'}, - {'label': 'Complex modulus', 'value': "mod"}, - {'label': 'Phase (in rad)', 'value': 'phase'}, - {'label': 'Phase (in deg)', 'value': 'deg_phase'}, - ], - 'isMulti': False, - 'isSearchable': True, - 'isClearable': False - }, - group="grid_values", - help="""The representation of the grid that should be displayed""" - ), - - CreatableOptionsInput( - key="transforms", name="Grid transforms", - default=[], - params={ - 'options': [ - {'label': 'Square', 'value': 'square'}, - {'label': 'Absolute', 'value': 'abs'}, - ], - 'isMulti': True, - 'isSearchable': True, - 'isClearable': True - }, - group="grid_values", - help="""Transformations to apply to the whole grid. - It can be a function, or a string that represents the path - to a function (e.g. "scipy.exp"). If a string that is a single - word is provided, numpy will be assumed to be the module (e.g. - "square" will be converted into "np.square"). 
- Note that transformations will be applied in the order provided. Some - transforms might not be necessarily commutable (e.g. "abs" and "cos").""" - ), - - GeomAxisSelect( - key = "axes", name="Axes to display", - default=["z"], - group="grid_shape", - help = """The axis along you want to see the grid, it will be reduced along the other ones, according to the the `reduce_method` setting.""" - ), - - OptionsInput( - key = "zsmooth", name="2D heatmap smoothing method", - default=False, - params={ - 'options': [ - {'label': 'best', 'value': 'best'}, - {'label': 'fast', 'value': 'fast'}, - {'label': 'False', 'value': False}, - ], - 'isSearchable': True, - 'isClearable': False - }, - group="grid_values", - help = """Parameter that smoothens how data looks in a heatmap.
- 'best' interpolates data, 'fast' interpolates pixels, 'False' displays the data as is.""" - ), - - Array1DInput( - key="interp", name="Interpolation", - default=[1, 1, 1], - params={ - 'inputType': 'number', - 'shape': (3,), - 'extendable': False, - }, - group="grid_shape", - help="Interpolation factors to make the grid finer on each axis.
See the zsmooth setting for faster smoothing of 2D heatmap." - ), - - CreatableOptionsInput(key="transform_bc", name="Transform boundary conditions", - default="wrap", - params={ - 'options': [ - {'label': 'constant', 'value': 'constant'}, - {'label': 'wrap', 'value': 'wrap'}, - ], - }, - group="grid_values", - help="""The boundary conditions when a cell transform is applied to the grid. Cell transforms are only - applied when the grid's cell doesn't follow the cartesian coordinates and the requested display is 2D or 1D. - """ - ), - - Array1DInput( - key="nsc", name="Supercell", - default=[1, 1, 1], - params={ - 'inputType': 'number', - 'shape': (3,), - 'extendable': False, - }, - group="grid_shape", - help="Number of times the grid should be repeated" - ), - - Array1DInput( - key="offset", name="Grid offset", - default=[0, 0, 0], - params={ - 'inputType': 'number', - 'shape': (3,), - 'extendable': False, - }, - help="""The offset of the grid along each axis. This is important if you are planning to match this grid with other geometry related plots.""" - ), - - TextInput( - key="trace_name", name="Trace name", - default=None, - params={ - "placeholder": "Give a name to the trace..." - }, - help="""The name that the trace will show in the legend. Good when merging with other plots to be able to toggle the trace in the legend""" - ), - - RangeSliderInput( - key="x_range", name="X range", - default=None, - params={ - "min": 0 - }, - group="grid_shape", - help="Range where the X is displayed. Should be inside the unit cell, otherwise it will fail.", - ), - - RangeSliderInput( - key="y_range", name="Y range", - default=None, - params={ - "min": 0 - }, - group="grid_shape", - help="Range where the Y is displayed. Should be inside the unit cell, otherwise it will fail.", - ), - - RangeSliderInput( - key="z_range", name="Z range", - default=None, - params={ - "min": 0 - }, - group="grid_shape", - help="Range where the Z is displayed. 
Should be inside the unit cell, otherwise it will fail.", - ), - - RangeInput( - key="crange", name="Colorbar range", - default=[None, None], - group="grid_values", - help="The range of values that the colorbar must enclose. This controls saturation and hides below threshold values." - ), - - IntegerInput( - key="cmid", name="Colorbar center", - default=None, - group="grid_values", - help="""The value to set at the center of the colorbar. If not provided, the color range is used""" - ), - - TextInput( - key="colorscale", name="Color scale", - default=None, - group="grid_values", - help="""A valid plotly colorscale. See https://plotly.com/python/colorscales/""" - ), - - OptionsInput(key="reduce_method", name="Reduce method", - default="average", - params={ - 'options': [ - {'label': 'average', 'value': 'average'}, - {'label': 'sum', 'value': 'sum'}, - ], - }, - group="grid_values", - help="""The method used to reduce the dimensions that will not be displayed in the plot.""" - ), - - QueriesInput(key = "isos", name = "Isosurfaces / contours", - default = [], - group="grid_values", - help = """The isovalues that you want to represent. - The way they will be represented is of course dependant on the type of representation: - - 2D representations: A contour (i.e. a line) - - 3D representations: A surface - """, - queryForm = [ - - TextInput( - key="name", name="Name", - default="Iso=$isoval$", - params={ - "placeholder": "Name of the isovalue..." - }, - help="The name of the iso query. Note that you can use $isoval$ as a template to indicate where the isoval should go." - ), - - FloatInput( - key="val", name="Value", - default=None, - help="The iso value. If not provided, it will be infered from `frac`" - ), - - FloatInput( - key="frac", name="Fraction", - default=0.3, - params={ - "min": 0, - "max": 1, - "step": 0.05 - }, - help="""If val is not provided, this is used to calculate where the isosurface should be drawn. 
- It calculates them from the minimum and maximum values of the grid like so: - If iso_frac = 0.3: - (min_value-----ISOVALUE(30%)-----------max_value) - Therefore, it should be a number between 0 and 1. - """ - ), - - IntegerInput( - key="step_size", name="Step size", - default=1, - help="""The step size to use to calculate the isosurface in case it's a 3D representation - A bigger step-size can speed up the process dramatically, specially the rendering part - and the resolution may still be more than satisfactory (try to use step_size=2). For very big - grids your computer may not even be able to render very fine surfaces, so it's worth keeping - this setting in mind.""" - ), - - ColorInput( - key="color", name="Color", - default=None, - help="The color of the surface/contour." - ), - - FloatInput( - key="opacity", name="Opacity", - default=1, - params={ - "min": 0, - "max": 1, - "step": 0.1 - }, - help="Opacity of the surface/contour. Between 0 (transparent) and 1 (opaque)." - ) - - ] - ), - - BoolInput(key='plot_geom', name='Plot geometry', - default=False, - help="""If True the geometry associated to the grid will also be plotted""" - ), - - ProgramaticInput(key='geom_kwargs', name='Geometry plot extra arguments', - default={}, - dtype=dict, - help="""Extra arguments that are passed to geom.plot() if plot_geom is set to True""" - ), + reduced_grid = reduce_grid(subbed_grid, reduce_method=reduce_method, keep_axes=grid_axes) + interp_grid = interpolate_grid(reduced_grid, interp=interp) + + # Finally, here comes the plotting! 
+ grid_ds = grid_to_dataarray(interp_grid, axes=axes, grid_axes=grid_axes, nsc=nsc) + grid_plottings = draw_grid(data=grid_ds, isos=isos, colorscale=colorscale, crange=crange, cmid=cmid, smooth=smooth) + + # Process the cell as well + cell_plottings = cell_plot_actions( + cell=grid, show_cell=show_cell, cell_style=cell_style, + axes=axes, ) - def _after_init(self): - - self.offsets = defaultdict(lambda: _a.arrayd([0, 0, 0])) - - self._add_shortcuts() - - @entry_point('grid', 0) - def _read_nosource(self, grid): - """ - Reads the grid directly from a sisl grid. - """ - self.grid = grid - - if self.grid is None: - raise ValueError("grid was not set") - - @entry_point('grid file', 1) - def _read_grid_file(self, grid_file): - """ - Reads the grid from any sile that implements `read_grid`. - """ - self.grid = self.get_sile(grid_file).read_grid() - - def _after_read(self): - - #Inform of the new available ranges - range_keys = ("x_range", "y_range", "z_range") - - for ax, key in enumerate(range_keys): - self.modify_param(key, "inputField.params.max", self.grid.cell[ax, ax]) - self.get_param(key, as_dict=False).update_marks() - - def _infer_grid_axes(self, axes, cell, tol=1e-3): - """Returns which are the lattice vectors that correspond to each cartesian direction""" - grid_axes = [] - for ax in axes: - if ax in ("x", "y", "z"): - coord_index = "xyz".index(ax) - lattice_vecs = np.where(cell[:, coord_index] > tol)[0] - if lattice_vecs.shape[0] != 1: - raise ValueError(f"There are {lattice_vecs.shape[0]} lattice vectors that contribute to the {'xyz'[coord_index]} coordinate.") - grid_axes.append(lattice_vecs[0]) - else: - grid_axes.append("abc".index(ax)) - - return grid_axes - - def _is_cartesian_unordered(self, cell, tol=1e-3): - """Whether a cell has cartesian axes as lattice vectors, regardless of their order. - - Parameters - ----------- - cell: np.array of shape (3, 3) - The cell that you want to check. 
- tol: float, optional - Threshold value to consider a component of the cell nonzero. - """ - bigger_than_tol = abs(cell) > tol - return bigger_than_tol.sum() == 3 and bigger_than_tol.any(axis=0).all() and bigger_than_tol.any(axis=1).all() - - def _is_1D_cartesian(self, cell, coord_ax, tol=1e-3): - """Whether a cell contains only one vector that contributes only to a given coordinate. - - That is, one vector follows the direction of the cartesian axis and the other vectors don't - have any component in that direction. - - Parameters - ----------- - cell: np.array of shape (3, 3) - The cell that you want to check. - coord_ax: {"x", "y", "z"} - The cartesian axis that you are looking for in the cell. - tol: float, optional - Threshold value to consider a component of the cell nonzero. - """ - coord_index = "xyz".index(coord_ax) - lattice_vecs = np.where(cell[:, coord_index] > tol)[0] - - is_1D_cartesian = lattice_vecs.shape[0] == 1 - return is_1D_cartesian and (cell[lattice_vecs[0]] > tol).sum() == 1 - - def _set_data(self, axes, nsc, interp, trace_name, transforms, represent, grid_file, - x_range, y_range, z_range, plot_geom, geom_kwargs, transform_bc, reduce_method): - - if trace_name is None and grid_file: - trace_name = grid_file.name - - grid = self.grid.copy() - - self._ndim = len(axes) - self.offsets["origin"] = grid.origin - - # Choose the representation of the grid that we want to display - grid.grid = self._get_representation(grid, represent) - - # We will tile the grid now, as at the moment there's no good way to tile it afterwards - # Note that this means extra computation, as we are transforming (skewed_2d) or calculating - # the isosurfaces (3d) using more than one unit cell (FIND SMARTER WAYS!) - for ax, reps in enumerate(nsc): - grid = grid.tile(reps, ax) - - # Determine whether we should transform the grid to cartesian axes. This will be needed - # if the grid is skewed. 
However, it is never needed for the 3D representation, since we - # compute the coordinates of each point in the isosurface, and we don't need to reduce the - # grid. - should_orthogonalize = ~self._is_cartesian_unordered(grid.cell) and self._ndim < 3 - # We also don't need to orthogonalize if cartesian coordinates are not requested - # (this would mean that axes is a combination of "a", "b" and "c") - should_orthogonalize = should_orthogonalize and bool(set(axes).intersection(["x", "y", "z"])) - - if should_orthogonalize and self._ndim == 1: - # In 1D representations, even if the cell is skewed, we might not need to transform. - # An example of a cell that we don't need to transform is: - # a = [1, 1, 0], b = [1, -1, 0], c = [0, 0, 1] - # If the user wants to display the values on the z coordinate, we can safely reduce the - # first two axes, as they don't contribute in the Z direction. Also, it is required that - # "c" doesn't contribute to any of the other two directions. - should_orthogonalize &= not self._is_1D_cartesian(grid.cell, axes[0]) - - if should_orthogonalize: - grid, self.offsets["cell_transform"] = self._transform_grid_cell( - grid, mode=transform_bc, output_shape=(np.array(interp)*grid.shape).astype(int), cval=np.nan - ) - # The interpolation has already happened, so just set it to [1,1,1] for the rest of the method - interp = [1, 1, 1] - - # Now the grid axes correspond to the cartesian coordinates. - grid_axes = [{"x": 0, "y": 1, "z": 2}[ax] for ax in axes] - elif self._ndim < 3: - # If we are not transforming the grid, we need to get the axes of the grid that contribute to the - # directions we have to plot. 
- grid_axes = self._infer_grid_axes(axes, grid.cell) - elif self._ndim == 3: - grid_axes = [0, 1, 2] - - # Apply all transforms requested by the user - for transform in transforms: - grid = self._transform_grid(grid, transform) - - # Get only the part of the grid that we need - ax_ranges = [x_range, y_range, z_range] - for ax, ax_range in enumerate(ax_ranges): - if ax_range is not None: - # Build an array with the limits - lims = np.zeros((2, 3)) - # If the cell was transformed, then we need to modify - # the range to get what the user wants. - lims[:, ax] = ax_range + self.offsets["cell_transform"][ax] - self.offsets["origin"][ax] - - # Get the indices of those points - indices = np.array([grid.index(lim) for lim in lims], dtype=int) - - # And finally get the subpart of the grid - grid = grid.sub(np.arange(indices[0, ax], indices[1, ax] + 1), ax) - - # Reduce the dimensions that are not going to be displayed - for ax in [0, 1, 2]: - if ax not in grid_axes: - grid = getattr(grid, reduce_method)(ax) - - # Interpolate the grid to a different shape, if needed - interp_factors = np.array([factor if ax in grid_axes else 1 for ax, factor in enumerate(interp)], dtype=int) - interpolate = (interp_factors != 1).any() - if interpolate: - grid = grid.interp((np.array(interp_factors)*grid.shape).astype(int)) - - # Remove the leftover dimensions - values = np.squeeze(grid.grid) - - # Choose which function we need to use to prepare the data - prepare_func = getattr(self, f"_prepare{self._ndim}D") - - # Use it - backend_info = prepare_func(grid, values, axes, grid_axes, nsc, trace_name, showlegend=bool(trace_name) or values.ndim == 3) - - backend_info["ndim"] = self._ndim - - # Add also the geometry if the user requested it - # This should probably not work like this. It should make use - # of MultiplePlot somehow. 
The problem is that right now, the bonds - # are calculated each time this method is called, for example - geom_plot = None - if plot_geom: - geom = getattr(self.grid, 'geometry', None) - if geom is None: - warn('You asked to plot the geometry, but the grid does not contain any geometry') - else: - geom_plot = geom.plot(**{'axes': axes, "nsc": self.get_setting("nsc"), **geom_kwargs}) - - backend_info["geom_plot"] = geom_plot - - # Define the axes titles - backend_info["axes_titles"] = { - f"{ax_name}axis": GeometryPlot._get_ax_title(ax) for ax_name, ax in zip(("x", "y", "z"), axes) - } - if self._ndim == 1: - backend_info["axes_titles"]["yaxis"] = "Values" - - return backend_info - - def _get_ax_range(self, grid, ax, nsc): - if isinstance(ax, int) or ax in ("a", "b", "c"): - ax = {"a": 0, "b": 1, "c": 2}.get(ax, ax) - ax_vals = np.linspace(0, nsc[ax], grid.shape[ax]) - else: - offset = self._get_offset(grid, ax) - - ax = {"x": 0, "y": 1, "z": 2}[ax] - - ax_vals = np.arange(0, grid.cell[ax, ax], grid.dcell[ax, ax]) + offset - - if len(ax_vals) == grid.shape[ax] + 1: - ax_vals = ax_vals[:-1] - - return ax_vals - - def _get_offset(self, grid, ax, offset, x_range, y_range, z_range): - if isinstance(ax, int) or ax in ("a", "b", "c"): - return 0 - else: - coord_range = {"x": x_range, "y": y_range, "z": z_range}[ax] - grid_offset = _a.asarrayd(offset) + self.offsets["vacuum"] - - coord_index = "xyz".index(ax) - # Now let's get the offset due to the minimum value of the axis range - if coord_range is not None: - offset = coord_range[0] - else: - # If a range was specified, the cell_transform and origo offsets were applied - # when subbing the grid. Otherwise they have not been applied yet. 
- offset = self.offsets["cell_transform"][coord_index] + self.offsets["origin"][coord_index] - - return offset + grid_offset[coord_index] - - def _get_offsets(self, grid, display_axes=[0, 1, 2]): - return np.array([self._get_offset(grid, ax) for ax in display_axes]) - - @staticmethod - def _transform_grid(grid, transform): - - if isinstance(transform, str): - - # Since this may come from the GUI, there might be extra spaces - transform = transform.strip() - - # If is a string with no dots, we will assume it is a numpy function - if len(transform.split(".")) == 1: - transform = f"numpy.{transform}" - - return grid.apply(transform) - - @staticmethod - def _get_representation(grid, represent): - """Returns a representation of the grid - - Parameters - ------------ - grid: sisl.Grid - the grid for which we want return - represent: {"real", "imag", "mod", "phase", "deg_phase", "rad_phase"} - the type of representation. "phase" is equivalent to "rad_phase" - - Returns - ------------ - np.ndarray of shape = grid.shape - """ - if represent == 'real': - values = grid.grid.real - elif represent == 'imag': - values = grid.grid.imag - elif represent == 'mod': - values = np.absolute(grid.grid) - elif represent in ['phase', 'rad_phase', 'deg_phase']: - values = np.angle(grid.grid, deg=represent.startswith("deg")) - else: - raise ValueError(f"'{represent}' is not a valid value for the `represent` argument") - - return values - - def _prepare1D(self, grid, values, display_axes, grid_axes, nsc, name, **kwargs): - """Takes care of preparing the values to plot in 1D""" - display_ax = display_axes[0] - - return {"ax": display_ax, "values": values, "ax_range": self._get_ax_range(grid, display_ax, nsc), "name": name} - - def _prepare2D(self, grid, values, display_axes, grid_axes, nsc, name, crange, cmid, colorscale, zsmooth, isos, **kwargs): - """Takes care of preparing the values to plot in 2D""" - from skimage.measure import find_contours - xaxis = display_axes[0] - yaxis = 
display_axes[1] - - if grid_axes[0] < grid_axes[1]: - values = values.T - - if crange is None: - crange = [None, None] - cmin, cmax = crange - - if cmid is None and cmin is None and cmax is None: - real_vals = values[~np.isnan(values)] - if np.any(real_vals > 0) and np.any(real_vals < 0): - cmid = 0 - - xs = self._get_ax_range(grid, xaxis, nsc) - ys = self._get_ax_range(grid, yaxis, nsc) - - # Draw the contours (if any) - if len(isos) > 0: - offsets = self._get_offsets(grid, display_axes) - isos_param = self.get_param("isos") - minval = np.nanmin(values) - maxval = np.nanmax(values) - - if set(display_axes).intersection(["x", "y", "z"]): - coord_indices = ["xyz".index(ax) for ax in display_axes] - - def _indices_to_2Dspace(contour_coords): - return contour_coords.dot(grid.dcell[grid_axes, :])[:, coord_indices] - else: - def _indices_to_2Dspace(contour_coords): - return contour_coords / (np.array(grid.shape) / nsc)[grid_axes] - - isos_to_draw = [] - for iso in isos: - - iso = isos_param.complete_query(iso) - - # Infer the iso value either from val or from frac - isoval = iso.get("val") - if isoval is None: - frac = iso.get("frac") - if frac is None: - raise ValueError(f"You are providing an iso query without 'val' and 'frac'. 
There's no way to know the isovalue!\nquery: {iso}") - isoval = minval + (maxval-minval)*frac - - # Find contours at a constant value of 0.8 - contours = find_contours(values, isoval) - - contour_xs = [] - contour_ys = [] - for contour in contours: - # Swap the first and second columns so that we have [x,y] for each - # contour point (instead of [row, col], which means [y, x]) - contour_coords = contour[:, [1, 0]] - # Then convert from indices to coordinates in the 2D space - contour_coords = _indices_to_2Dspace(contour_coords) + offsets - contour_xs = [*contour_xs, None, *contour_coords[:, 0]] - contour_ys = [*contour_ys, None, *contour_coords[:, 1]] - - # Add the information about this isoline to the list of isolines - isos_to_draw.append({ - "x": contour_xs, "y": contour_ys, - "color": iso.get("color"), "opacity": iso.get("opacity"), - "name": iso.get("name", "").replace("$isoval$", f"{isoval:.4f}") - }) - - return { - "values": values, "x": xs, "y": ys, "zsmooth": zsmooth, - "xaxis": xaxis, "yaxis": yaxis, - "cmin": cmin, "cmax": cmax, "cmid": cmid, "colorscale": colorscale, - "name": name, "contours": isos_to_draw - } - - @staticmethod - def _transform_grid_cell(grid, cell=np.eye(3), output_shape=None, mode="constant", order=1, **kwargs): - """ - Applies a linear transformation to the grid to get it relative to arbitrary cell. - - This method can be used, for example to get the values of the grid with respect to - the standard basis, so that you can easily visualize it or overlap it with other grids - (e.g. to perform integrals). - - Parameters - ----------- - cell: array-like of shape (3,3) - these cell represent the directions that you want to use as references for - the new grid. - - The length of the axes does not have any effect! They will be rescaled to create - the minimum bounding box necessary to accomodate the unit cell. - output_shape: array-like of int of shape (3,), optional - the shape of the final output. 
If not provided, the current shape of the grid - will be used. - - Notice however that if the transformation applies a big shear to the image (grid) - you will probably need to have a bigger output_shape. - mode: str, optional - determines how to handle borders. See scipy docs for more info on the possible values. - order : int 0-5, optional - the order of the spline interpolation to calculate the values (since we are applying - a transformation, we don't actually have values for the new locations and we need to - interpolate them) - 1 means linear, 2 quadratic, etc... - **kwargs: - the rest of keyword arguments are passed directly to `scipy.ndimage.affine_transform` - - See also - ---------- - scipy.ndimage.affine_transform : method used to apply the linear transformation. - """ - # Take the current shape of the grid if no output shape was provided - if output_shape is None: - output_shape = grid.shape - - # Get the current cell in coordinates of the destination axes - inv_cell = cell_invert(cell).T - projected_cell = grid.cell.dot(inv_cell) - - # From that, infere how long will the bounding box of the cell be - lengths = abs(projected_cell).sum(axis=0) - - # Create the transformation matrix. Since we want to control the shape - # of the output, we can not use grid.dcell directly, we need to modify it. - scales = output_shape / lengths - forward_t = (grid.dcell.dot(inv_cell)*scales).T - - # Scipy's affine transform asks for the inverse transformation matrix, to - # map from output pixels to input pixels. By taking the inverse of our - # transformation matrix, we get exactly that. - tr = cell_invert(forward_t).T - - # Calculate the offset of the image so that all points of the grid "fall" inside - # the output array. 
- # For this we just calculate the centers of the input and output images - center_input = 0.5 * (_a.asarrayd(grid.shape) - 1) - center_output = 0.5 * (_a.asarrayd(output_shape) - 1) - - # And then make sure that the input center that is interpolated from the output - # falls in the actual input's center - offset = center_input - tr.dot(center_output) - - # We pass all the parameters to scipy's affine_transform - transformed_image = affine_transform(grid.grid, tr, order=1, offset=offset, - output_shape=output_shape, mode=mode, **kwargs) - - # Create a new grid with the new shape and the new cell (notice how the cell - # is rescaled from the input cell to fit the actual coordinates of the system) - new_grid = grid.__class__((1, 1, 1), lattice=cell*lengths.reshape(3, 1)) - new_grid.grid = transformed_image - - # Find the offset between the origin before and after the transformation - return new_grid, new_grid.dcell.dot(forward_t.dot(offset)) - - def _prepare3D(self, grid, values, display_axes, grid_axes, nsc, name, isos, **kwargs): - """Takes care of preparing the values to plot in 3D""" - # The minimum and maximum values might be needed at some places - minval, maxval = np.min(values), np.max(values) - - # Get the isos input field. It will be used to get the default fraction - # value and to complete queries - isos_param = self.get_param("isos") - - # If there are no iso queries, we are going to create 2 isosurfaces. 
- if len(isos) == 0 and maxval != minval: - - default_iso_frac = isos_param["frac"].default - - # If the default frac is 0.3, they will be displayed at 0.3 and 0.7 - isos = [ - {"frac": default_iso_frac}, - {"frac": 1-default_iso_frac} - ] - - isos_to_draw = [] - # Go through each iso query to prepare the isosurface - for iso in isos: - - iso = isos_param.complete_query(iso) - - if not iso.get("active", True): - continue - - # Infer the iso value either from val or from frac - isoval = iso.get("val") - if isoval is None: - frac = iso.get("frac") - if frac is None: - raise ValueError(f"You are providing an iso query without 'val' and 'frac'. There's no way to know the isovalue!\nquery: {iso}") - isoval = minval + (maxval-minval)*frac - - # Calculate the isosurface - vertices, faces, normals, intensities = grid.isosurface(isoval, iso.get("step_size", 1)) - - vertices = vertices + self._get_offsets(grid) + self.offsets["origin"] - - # Add all the isosurface info to the list that will be passed to the drawer - isos_to_draw.append({ - "vertices": vertices, "faces": faces, - "color": iso.get("color"), "opacity": iso.get("opacity"), - "name": iso.get("name", "").replace("$isoval$", f"{isoval:.4f}") - }) - - return {"isosurfaces": isos_to_draw} + # And maybe plot the strucuture + geom_plottings = _get_structure_plottings(plot_geom=plot_geom, geometry=geometry, geom_kwargs=geom_kwargs, axes=axes, nsc=nsc) + + all_plottings = combined(grid_plottings, cell_plottings, geom_plottings, composite_method=None) + + return get_figure(backend=backend, plot_actions=all_plottings) + +def wavefunction_plot( + eigenstate: EigenstateData, + i: int = 0, + geometry: Optional[Geometry] = None, + grid_prec: float = 0.2, + # All grid inputs. 
+ grid: Optional[Grid] = None, + axes: Axes = ["z"], + represent: Literal["real", "imag", "mod", "phase", "deg_phase", "rad_phase"] = "real", + transforms: Sequence[Union[str, Callable]] = (), + reduce_method: Literal["average", "sum"] = "average", + boundary_mode: str = "grid-wrap", + nsc: Tuple[int, int, int] = (1, 1, 1), + interp: Tuple[int, int, int] = (1, 1, 1), + isos: Sequence[dict] = [], + smooth: bool = False, + colorscale: Optional[str] = None, + crange: Optional[Tuple[float, float]] = None, + cmid: Optional[float] = None, + show_cell: Literal["box", "axes", False] = "box", + cell_style: dict = {}, + x_range: Optional[Sequence[float]] = None, + y_range: Optional[Sequence[float]] = None, + z_range: Optional[Sequence[float]] = None, + plot_geom: bool = False, + geom_kwargs: dict = {}, + backend: str = "plotly" +) -> Figure: + """Plots a wavefunction in real space. + + Parameters + ---------- + eigenstate: + The eigenstate object containing information about eigenstates. + i: + The index of the eigenstate to plot. + geometry: + Geometry to use to project the eigenstate to real space. + If None, the geometry associated with the eigenstate is used. + grid_prec: + The precision of the grid where the wavefunction is projected. + grid: + The grid to plot. + axes: + The axes to project the grid to. + represent: + The representation of the grid to plot. + transforms: + List of transforms to apply to the grid before plotting. + reduce_method: + The method used to reduce the grid axes that are not displayed. + boundary_mode: + The method used to deal with the boundary conditions. + Only used if the grid is to be orthogonalized. + See scipy docs for more info on the possible values. + nsc: + The number of unit cells to display in each direction. + interp: + The interpolation factor to use for each axis to make the grid smoother. + isos: + List of isosurfaces or isocontours to plot. See the showcase notebooks for examples. 
+ smooth: + Whether to ask the plotting backend to make an attempt at smoothing the grid display. + colorscale: + Colorscale to use for the grid display in the 2D representation. + If None, the default colorscale is used for each backend. + crange: + Min and max values for the colorscale. + cmid: + The value at which the colorscale is centered. + show_cell: + Method used to display the unit cell. If False, the cell is not displayed. + cell_style: + Style specification for the cell. See the showcase notebooks for examples. + x_range: + The range of the x axis to take into account. + Even if the X axis is not displayed! This is important because the reducing + operation will only be applied on this range. + y_range: + The range of the y axis to take into account. + Even if the Y axis is not displayed! This is important because the reducing + operation will only be applied on this range. + z_range: + The range of the z axis to take into account. + Even if the Z axis is not displayed! This is important because the reducing + operation will only be applied on this range. + plot_geom: + Whether to plot the associated geometry (if any). + geom_kwargs: + Keyword arguments to pass to the geometry plot of the associated geometry. + backend: + The backend to use to generate the figure. + + See also + ---------- + scipy.ndimage.affine_transform : method used to orthogonalize the grid if needed. + """ + + # Create a grid with the wavefunction in it. 
+ i_eigenstate = get_eigenstate(eigenstate, i) + geometry = eigenstate_geometry(eigenstate, geometry=geometry) - def _add_shortcuts(self): + tiled_geometry = tile_if_k(geometry=geometry, nsc=nsc, eigenstate=i_eigenstate) + grid_nsc = get_grid_nsc(nsc=nsc, eigenstate=i_eigenstate) + grid = project_wavefunction(eigenstate=i_eigenstate, grid_prec=grid_prec, grid=grid, geometry=tiled_geometry) - axes = ["x", "y", "z"] + # Grid processing + axes = sanitize_axes(axes) - for ax in axes: + grid_repr = get_grid_representation(grid, represent=represent) - self.add_shortcut(f'{ax.lower()}+enter', f"Show {ax} axis", self.update_settings, axes=[ax]) + tiled_grid = tile_grid(grid_repr, nsc=grid_nsc) - self.add_shortcut(f'{ax.lower()} {ax.lower()}', f"Duplicate {ax} axis", self.tile, 2, ax) + ort_grid = orthogonalize_grid_if_needed(tiled_grid, axes=axes, mode=boundary_mode) + + grid_axes = get_grid_axes(ort_grid, axes=axes) - self.add_shortcut(f'{ax.lower()}+-', f"Substract a unit cell along {ax}", self.tighten, 1, ax) - - self.add_shortcut(f'{ax.lower()}++', f"Add a unit cell along {ax}", self.tighten, -1, ax) - - for xaxis in axes: - for yaxis in [ax for ax in axes if ax != xaxis]: - self.add_shortcut( - f'{xaxis.lower()}+{yaxis.lower()}', f"Show {xaxis} and {yaxis} axes", - self.update_settings, axes=[xaxis, yaxis] - ) - - def tighten(self, steps, ax): - """ - Makes the supercell tighter by a number of unit cells - - Parameters - --------- - steps: int or array-like - Number of unit cells that you want to substract. - If there are not enough unit cells to substract, one unit cell will remain. - - If you provide multiple steps, it needs to match the number of axes provided. - ax: int or array-like - Axis along which to tighten the supercell. - - If you provide multiple axes, the number of different steps must match the number of axes or be a single int. 
- """ - if isinstance(ax, int): - ax = [ax] - if isinstance(steps, int): - steps = [steps]*len(ax) - - nsc = list(self.get_setting("nsc")) - - for a, step in zip(ax, steps): - nsc[a] = max(1, nsc[a]-step) - - return self.update_settings(nsc=nsc) - - def tile(self, tiles, ax): - """ - Tile a given axis to display more unit cells in the plot - - Parameters - ---------- - tiles: int or array-like - factor by which the supercell will be multiplied along axes `ax`. - - If you provide multiple tiles, it needs to match the number of axes provided. - ax: int or array-like - axis that you want to tile. - - If you provide multiple axes, the number of different tiles must match the number of axes or be a single int. - """ - if isinstance(ax, int): - ax = [ax] - if isinstance(tiles, int): - tiles = [tiles]*len(ax) - - nsc = [*self.get_setting("nsc")] - - for a, tile in zip(ax, tiles): - nsc[a] *= tile - - return self.update_settings(nsc=nsc) - - def scan(self, along, start=None, stop=None, step=None, num=None, breakpoints=None, mode="moving_slice", animation_kwargs=None, **kwargs): - """ - Returns an animation containing multiple frames scaning along an axis. - - Parameters - ----------- - along: {"x", "y", "z"} - the axis along which the scan is performed. If not provided, it will scan along the axes that are not displayed. - start: float, optional - the starting value for the scan (in Angstrom). - Make sure this value is inside the range of the unit cell, otherwise it will fail. - stop: float, optional - the last value of the scan (in Angstrom). - Make sure this value is inside the range of the unit cell, otherwise it will fail. - step: float, optional - the distance between steps in Angstrom. - - If not provided and `num` is also not provided, it will default to 1 Ang. - num: int , optional - the number of steps that you want the scan to consist of. - - If `step` is passed, this argument is ignored. 
- - Note that the grid is only stored once, so having a big number of steps is not that big of a deal. - breakpoints: array-like, optional - the discrete points of the scan. To be used if you don't want regular steps. - If the last step is exactly the length of the cell, it will be moved one dcell back to avoid errors. - - Note that if this parameter is passed, both `step` and `num` are ignored. - mode: {"moving_slice", "as_is"}, optional - the type of scan you want to see. - "moving_slice" renders a volumetric scan where a slice moves through the grid. - "as_is" renders each part of the scan as an animation frame. - (therefore, "as_is" SUPPORTS SCANNING 1D, 2D AND 3D REPRESENTATIONS OF THE GRID, e.g. display the volume data for different ranges of z) - animation_kwargs: dict, optional - dictionary whose keys and values are directly passed to the animated method as kwargs and therefore - end up being passed to animation initialization. - **kwargs: - the rest of settings that you want to apply to overwrite the existing ones. - - This settings apply to each plot and go directly to their initialization. - - Returns - ------- - sisl.viz.plotly.Animation - An animation representation of the scan - """ - # Do some checks on the args provided - if sum(1 for arg in (step, num, breakpoints) if arg is not None) > 1: - raise ValueError(f"Only one of ('step', 'num', 'breakpoints') should be passed.") - - axes = self.get_setting('axes') - if mode == "as_is" and set(axes) - set(["x", "y", "z"]): - raise ValueError("To perform a scan, the axes need to be cartesian. 
Please set the axes to a combination of 'x', 'y' and 'z'.") - - if self.grid.lattice.is_cartesian(): - grid = self.grid - else: - transform_bc = kwargs.pop("transform_bc", self.get_setting("transform_bc")) - grid, transform_offset = self._transform_grid_cell( - self.grid, mode=transform_bc, output_shape=self.grid.shape, cval=np.nan - ) - - kwargs["offset"] = transform_offset + kwargs.get("offset", self.get_setting("offset")) - - # We get the key that needs to be animated (we will divide the full range in frames) - range_key = f"{along}_range" - along_i = {"x": 0, "y": 1, "z": 2}[along] - - # Get the full range - if start is not None and stop is not None: - along_range = [start, stop] - else: - along_range = self.get_setting(range_key) - if along_range is None: - range_param = self.get_param(range_key) - along_range = [range_param[f"inputField.params.{lim}"] for lim in ["min", "max"]] - if start is not None: - along_range[0] = start - if stop is not None: - along_range[1] = stop - - if breakpoints is None: - if step is None and num is None: - step = 1.0 - if step is None: - step = (along_range[1] - along_range[0]) / num - else: - num = (along_range[1] - along_range[0]) // step - - # np.linspace will use the last point as a step (and we don't want it) - # therefore we will add an extra step - breakpoints = np.linspace(*along_range, int(num) + 1) - - if breakpoints[-1] == grid.cell[along_i, along_i]: - breakpoints[-1] = grid.cell[along_i, along_i] - grid.dcell[along_i, along_i] - - if mode == "moving_slice": - return self._moving_slice_scan(grid, along_i, breakpoints) - elif mode == "as_is": - return self._asis_scan(grid, range_key, breakpoints, animation_kwargs=animation_kwargs, **kwargs) - - def _asis_scan(self, grid, range_key, breakpoints, animation_kwargs=None, **kwargs): - """ - Returns an animation containing multiple frames scaning along an axis. 
- - Parameters - ----------- - range_key: {'x_range', 'y_range', 'z_range'} - the key of the setting that is to be animated through the scan. - breakpoints: array-like - the discrete points of the scan - animation_kwargs: dict, optional - dictionary whose keys and values are directly passed to the animated method as kwargs and therefore - end up being passed to animation initialization. - **kwargs: - the rest of settings that you want to apply to overwrite the existing ones. - - This settings apply to each plot and go directly to their initialization. - - Returns - ---------- - scan: sisl Animation - An animation representation of the scan - """ - # Generate the plot using self as a template so that plots don't need - # to read data, just process it and show it differently. - # (If each plot read the grid, the memory requirements would be HUGE) - scan = self.animated( - { - range_key: [[bp, breakpoints[i+1]] for i, bp in enumerate(breakpoints[:-1])] - }, - fixed={**{key: val for key, val in self.settings.items() if key != range_key}, **kwargs, "grid": grid}, - frame_names=[f'{bp:2f}' for bp in breakpoints], - **(animation_kwargs or {}) - ) - - # Set all frames to the same colorscale, if it's a 2d representation - if len(self.get_setting("axes")) == 2: - cmin = 10**6; cmax = -10**6 - for scan_im in scan: - c = getattr(scan_im.data[0], "value", scan_im.data[0].z) - cmin = min(cmin, np.min(c)) - cmax = max(cmax, np.max(c)) - for scan_im in scan: - scan_im.update_settings(crange=[cmin, cmax]) - - scan.get_figure() - - scan.layout = self.layout - - return scan - - def _moving_slice_scan(self, grid, along_i, breakpoints): - import plotly.graph_objs as go - ax = along_i - displayed_axes = [i for i in range(3) if i != ax] - shape = np.array(grid.shape)[displayed_axes] - cmin = np.min(grid.grid) - cmax = np.max(grid.grid) - x_ax, y_ax = displayed_axes - x = np.linspace(0, grid.cell[x_ax, x_ax], grid.shape[x_ax]) - y = np.linspace(0, grid.cell[y_ax, y_ax], grid.shape[y_ax]) 
- - fig = go.Figure(frames=[go.Frame(data=go.Surface( - x=x, y=y, - z=(bp * np.ones(shape)).T, - surfacecolor=np.squeeze(grid.cross_section(grid.index(bp, ax), ax).grid).T, - cmin=cmin, cmax=cmax, - ), - name=f'{bp:.2f}' - ) - for bp in breakpoints]) - - # Add data to be displayed before animation starts - fig.add_traces(fig.frames[0].data) - - def frame_args(duration): - return { - "frame": {"duration": duration}, - "mode": "immediate", - "fromcurrent": True, - "transition": {"duration": duration, "easing": "linear"}, - } - - sliders = [ - { - "pad": {"b": 10, "t": 60}, - "len": 0.9, - "x": 0.1, - "y": 0, - "steps": [ - { - "args": [[f.name], frame_args(0)], - "label": str(k), - "method": "animate", - } - for k, f in enumerate(fig.frames) - ], - } - ] - - def ax_title(ax): return f'{["X", "Y", "Z"][ax]} axis [Ang]' - - # Layout - fig.update_layout( - title=f'Grid scan along {["X", "Y", "Z"][ax]} axis', - width=600, - height=600, - scene=dict( - xaxis=dict(title=ax_title(x_ax)), - yaxis=dict(title=ax_title(y_ax)), - zaxis=dict(autorange=True, title=ax_title(ax)), - aspectmode="data", - ), - updatemenus = [ - { - "buttons": [ - { - "args": [None, frame_args(50)], - "label": "▶", # play symbol - "method": "animate", - }, - { - "args": [[None], frame_args(0)], - "label": "◼", # pause symbol - "method": "animate", - }, - ], - "direction": "left", - "pad": {"r": 10, "t": 70}, - "type": "buttons", - "x": 0.1, - "y": 0, - } - ], - sliders=sliders - ) - - # We need to add an invisible trace so that the z axis stays with the correct range - fig.add_trace({"type": "scatter3d", "mode": "markers", "marker_size": 0.001, "x": [0, 0], "y": [0, 0], "z": [0, grid.cell[ax, ax]]}) - - return fig + transformed_grid = apply_transforms(ort_grid, transforms) + subbed_grid = sub_grid(transformed_grid, x_range=x_range, y_range=y_range, z_range=z_range) -class WavefunctionPlot(GridPlot): - """ - An extension of GridPlot specifically tailored for plotting wavefunctions + reduced_grid = 
reduce_grid(subbed_grid, reduce_method=reduce_method, keep_axes=grid_axes) - Parameters - ----------- - eigenstate: EigenstateElectron, optional - The eigenstate that contains the coefficients of the wavefunction. - Note that an eigenstate can contain coefficients for multiple states. - wfsx_file: wfsxSileSiesta, optional - Siesta WFSX file to directly read the coefficients from. - If the root_fdf file is provided but the wfsx one isn't, we will try - to find it as SystemLabel.WFSX. - geometry: Geometry, optional - Necessary to generate the grid and to plot the wavefunctions, since - the basis orbitals are needed. If you provide a - hamiltonian, the geometry is probably inside the hamiltonian, so you - don't need to provide it. However, this field is - compulsory if you are providing the eigenstate directly. - k: array-like, optional - If the eigenstates need to be calculated from a hamiltonian, the k - point for which you want them to be calculated - spin: optional - The spin component where the eigenstate should be calculated. - Only meaningful if the state needs to be calculated from the - hamiltonian. - grid_prec: float, optional - The spacing between points of the grid where the wavefunction will be - projected (in Ang). If you are plotting a 3D - representation, take into account that a very fine and big grid could - result in your computer crashing on render. If it's the - first time you are using this function, assess the - capabilities of your computer by first using a low-precision grid and - increase it gradually. - i: int, optional - The index of the wavefunction - grid: Grid, optional - A sisl.Grid object. If provided, grid_file is ignored. 
- grid_file: cubeSile or rhoSileSiesta or ldosSileSiesta or rhoinitSileSiesta or rhoxcSileSiesta or drhoSileSiesta or baderSileSiesta or iorhoSileSiesta or totalrhoSileSiesta or stsSileSiesta or stmldosSileSiesta or hartreeSileSiesta or neutralatomhartreeSileSiesta or totalhartreeSileSiesta or gridncSileSiesta or ncSileSiesta or fdfSileSiesta or tsvncSileSiesta or chgSileVASP or locpotSileVASP, optional - A filename that can be return a Grid through `read_grid`. - represent: optional - The representation of the grid that should be displayed - transforms: optional - Transformations to apply to the whole grid. It can be a - function, or a string that represents the path to a - function (e.g. "scipy.exp"). If a string that is a single - word is provided, numpy will be assumed to be the module (e.g. - "square" will be converted into "np.square"). Note that - transformations will be applied in the order provided. Some - transforms might not be necessarily commutable (e.g. "abs" and - "cos"). - axes: optional - The axis along you want to see the grid, it will be reduced along the - other ones, according to the the `reduce_method` setting. - zsmooth: optional - Parameter that smoothens how data looks in a heatmap. - 'best' interpolates data, 'fast' interpolates pixels, 'False' - displays the data as is. - interp: array-like, optional - Interpolation factors to make the grid finer on each axis.See the - zsmooth setting for faster smoothing of 2D heatmap. - transform_bc: optional - The boundary conditions when a cell transform is applied to the grid. - Cell transforms are only applied when the grid's cell - doesn't follow the cartesian coordinates and the requested display is - 2D or 1D. - nsc: array-like, optional - Number of times the grid should be repeated - offset: array-like, optional - The offset of the grid along each axis. This is important if you are - planning to match this grid with other geometry related plots. 
- trace_name: str, optional - The name that the trace will show in the legend. Good when merging - with other plots to be able to toggle the trace in the legend - x_range: array-like of shape (2,), optional - Range where the X is displayed. Should be inside the unit cell, - otherwise it will fail. - y_range: array-like of shape (2,), optional - Range where the Y is displayed. Should be inside the unit cell, - otherwise it will fail. - z_range: array-like of shape (2,), optional - Range where the Z is displayed. Should be inside the unit cell, - otherwise it will fail. - crange: array-like of shape (2,), optional - The range of values that the colorbar must enclose. This controls - saturation and hides below threshold values. - cmid: int, optional - The value to set at the center of the colorbar. If not provided, the - color range is used - colorscale: str, optional - A valid plotly colorscale. See https://plotly.com/python/colorscales/ - reduce_method: optional - The method used to reduce the dimensions that will not be displayed - in the plot. - isos: array-like of dict, optional - The isovalues that you want to represent. The way they - will be represented is of course dependant on the type of - representation: - 2D representations: A contour (i.e. - a line) - 3D representations: A surface - Each item is a dict. Structure of the dict: { 'name': The - name of the iso query. Note that you can use $isoval$ as a template - to indicate where the isoval should go. 'val': The iso value. - If not provided, it will be infered from `frac` 'frac': If - val is not provided, this is used to calculate where the isosurface - should be drawn. It calculates them from the - minimum and maximum values of the grid like so: - If iso_frac = 0.3: (min_value----- - ISOVALUE(30%)-----------max_value) Therefore, it - should be a number between 0 and 1. 
- 'step_size': The step size to use to calculate the isosurface in case - it's a 3D representation A bigger step-size can - speed up the process dramatically, specially the rendering part - and the resolution may still be more than satisfactory (try to use - step_size=2). For very big grids your computer - may not even be able to render very fine surfaces, so it's worth - keeping this setting in mind. 'color': - The color of the surface/contour. 'opacity': Opacity of the - surface/contour. Between 0 (transparent) and 1 (opaque). } - plot_geom: bool, optional - If True the geometry associated to the grid will also be plotted - geom_kwargs: dict, optional - Extra arguments that are passed to geom.plot() if plot_geom is set to - True - root_fdf: fdfSileSiesta, optional - Path to the fdf file that is the 'parent' of the results. - results_path: str, optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - entry_points_order: array-like, optional - Order with which entry points will be attempted. - backend: optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - """ + interp_grid = interpolate_grid(reduced_grid, interp=interp) - _plot_type = 'Wavefunction' - - _parameters = ( - - PlotableInput(key="eigenstate", name="Electron eigenstate", - default=None, - dtype=sisl.EigenstateElectron, - help="""The eigenstate that contains the coefficients of the wavefunction. - Note that an eigenstate can contain coefficients for multiple states. - """ - ), - - SileInput(key='wfsx_file', name='Path to WFSX file', - dtype=sisl.io.siesta.wfsxSileSiesta, - default=None, - help="""Siesta WFSX file to directly read the coefficients from. - If the root_fdf file is provided but the wfsx one isn't, we will try to find it - as SystemLabel.WFSX. 
- """ - ), - - SislObjectInput(key='geometry', name='Geometry', - default=None, - dtype=sisl.Geometry, - help="""Necessary to generate the grid and to plot the wavefunctions, since the basis orbitals are needed. - If you provide a hamiltonian, the geometry is probably inside the hamiltonian, so you don't need to provide it. - However, this field is compulsory if you are providing the eigenstate directly.""" - ), - - Array1DInput(key='k', name='K point', - default=(0, 0, 0), - help="""If the eigenstates need to be calculated from a hamiltonian, the k point for which you want them to be calculated""" - ), - - SpinSelect(key='spin', name="Spin", - default=0, - help="""The spin component where the eigenstate should be calculated. - Only meaningful if the state needs to be calculated from the hamiltonian.""", - only_if_polarized=True, - ), - - FloatInput(key='grid_prec', name='Grid precision', - default=0.2, - help="""The spacing between points of the grid where the wavefunction will be projected (in Ang). - If you are plotting a 3D representation, take into account that a very fine and big grid could result in - your computer crashing on render. If it's the first time you are using this function, - assess the capabilities of your computer by first using a low-precision grid and increase - it gradually. - """ - ), - - IntegerInput(key='i', name='Wavefunction index', - default=0, - help="The index of the wavefunction" - ), + # Finally, here comes the plotting! 
+ grid_ds = grid_to_dataarray(interp_grid, axes=axes, grid_axes=grid_axes, nsc=grid_nsc) + grid_plottings = draw_grid(data=grid_ds, isos=isos, colorscale=colorscale, crange=crange, cmid=cmid, smooth=smooth) + # Process the cell as well + cell_plottings = cell_plot_actions( + cell=grid, show_cell=show_cell, cell_style=cell_style, + axes=axes, ) - _overwrite_defaults = { - 'axes': "xyz", - 'plot_geom': True - } - - @entry_point('eigenstate', 0) - def _read_nosource(self, eigenstate): - """ - Uses an already calculated Eigenstate object to generate the wavefunctions. - """ - if eigenstate is None: - raise ValueError('No eigenstate was provided') - - self.eigenstate = eigenstate - - @entry_point('wfsx file', 1) - def _read_from_WFSX_file(self, wfsx_file, k, spin, root_fdf): - """Reads the wavefunction coefficients from a SIESTA WFSX file""" - # Try to read the geometry - fdf = self.get_sile(root_fdf or "root_fdf") - if fdf is None: - raise ValueError("The setting 'root_fdf' needs to point to an fdf file with a geometry") - geometry = fdf.read_geometry(output=True) - - # Get the WFSX file. If not provided, it is inferred from the fdf. - wfsx = self.get_sile(wfsx_file or "wfsx_file") - if not wfsx.file.is_file(): - raise ValueError(f"File '{wfsx.file}' does not exist.") - - sizes = wfsx.read_sizes() - H = sisl.Hamiltonian(geometry, dim=sizes.nspin) - - wfsx = sisl.get_sile(wfsx.file, parent=H) - - # Try to find the eigenstate that we need - self.eigenstate = wfsx.read_eigenstate(k=k, spin=spin[0]) - if self.eigenstate is None: - # We have not found it. - raise ValueError(f"A state with k={k} was not found in file {wfsx.file}.") - - @entry_point('hamiltonian', 2) - def _read_from_H(self, k, spin): - """ - Calculates the eigenstates from a Hamiltonian and then generates the wavefunctions. - """ - self.setup_hamiltonian() - - self.eigenstate = self.H.eigenstate(k, spin=spin[0]) - - def _after_read(self): - # Just avoid here GridPlot's _after_grid. 
Note that we are - # calling it later in _set_data - pass - - def _get_eigenstate(self, i): - - if "index" in self.eigenstate.info: - wf_i = np.nonzero(self.eigenstate.info["index"] == i)[0] - if len(wf_i) == 0: - raise ValueError(f"Wavefunction with index {i} is not present in the eigenstate. Available indices: {self.eigenstate.info['index']}." - f"Entry point used: {self.source._name}") - wf_i = wf_i[0] - else: - max_index = len(self.eigenstate) - if i > max_index: - raise ValueError(f"Wavefunction with index {i} is not present in the eigenstate. Available range: [0, {max_index}]." - f"Entry point used: {self.source._name}") - wf_i = i - - return self.eigenstate[wf_i] - - def _set_data(self, i, geometry, grid, k, grid_prec, nsc): - - if geometry is not None: - self.geometry = geometry - elif isinstance(self.eigenstate.parent, sisl.Geometry): - self.geometry = self.eigenstate.parent - else: - self.geometry = getattr(self.eigenstate.parent, "geometry", None) - if self.geometry is None: - raise ValueError('No geometry was provided and we need it the basis orbitals to build the wavefunctions from the coefficients!') - - # Get the spin class for which the eigenstate was calculated. - spin = sisl.Spin() - if self.eigenstate.parent is not None: - spin = getattr(self.eigenstate.parent, "spin", None) - - # Check that number of orbitals match - no = self.eigenstate.shape[1] * (1 if spin.is_diagonal else 2) - if self.geometry.no != no: - raise ValueError(f"Number of orbitals in the state ({no}) and the geometry ({self.geometry.no}) don't match." - " This is most likely because the geometry doesn't contain the appropiate basis.") - - # Move all atoms inside the unit cell, otherwise the wavefunction is not - # properly displayed. - self.geometry = self.geometry.copy() - self.geometry.xyz = (self.geometry.fxyz % 1).dot(self.geometry.cell) - - # If we are calculating the wavefunction for any point other than gamma, - # the periodicity of the WF will be bigger than the cell. 
Therefore, if - # the user wants to see more than the unit cell, we need to generate the - # wavefunction for all the supercell. Here we intercept the `nsc` setting - # with this objective. - tiled_geometry = self.geometry - nsc = list(nsc) - for ax, sc_i in enumerate(nsc): - if k[ax] != 0: - tiled_geometry = tiled_geometry.tile(sc_i, ax) - nsc[ax] = 1 - - is_gamma = (np.array(k) == 0).all() - if grid is None: - dtype = np.float64 if is_gamma else np.complex128 - self.grid = sisl.Grid(grid_prec, geometry=tiled_geometry, dtype=dtype) - grid = self.grid - - # GridPlot's after_read basically sets the x_range, y_range and z_range options - # which need to know what the grid is, that's why we are calling it here - super()._after_read() - - # Get the particular WF that we want from the eigenstate object - wf_state = self._get_eigenstate(i) - - # Ensure we are dealing with the R gauge - wf_state.change_gauge('R') - - # Finally, insert the wavefunction values into the grid. - sisl.physics.electron.wavefunction( - wf_state.state, grid, geometry=tiled_geometry, - k=k, spinor=0, spin=spin - ) - - return super()._set_data(nsc=nsc, trace_name=f"WF {i} ({wf_state.eig[0]:.2f} eV)") - -GridPlot.backends.register_child(WavefunctionPlot.backends) + # And maybe plot the strucuture + geom_plottings = _get_structure_plottings(plot_geom=plot_geom, geometry=tiled_geometry, geom_kwargs=geom_kwargs, axes=axes, nsc=grid_nsc) + + all_plottings = combined(grid_plottings, cell_plottings, geom_plottings, composite_method=None) + + return get_figure(backend=backend, plot_actions=all_plottings) + +class GridPlot(Plot): + + function = staticmethod(grid_plot) + +class WavefunctionPlot(GridPlot): + + function = staticmethod(wavefunction_plot) + +# The following commented code is from the old viz module, where the GridPlot had a scan method. 
+# It looks very nice, but probably should be reimplemented as a standalone function that plots a grid slice, +# and then merge those grid slices to create a scan. + +# def scan(self, along, start=None, stop=None, step=None, num=None, breakpoints=None, mode="moving_slice", animation_kwargs=None, **kwargs): +# """ +# Returns an animation containing multiple frames scaning along an axis. + +# Parameters +# ----------- +# along: {"x", "y", "z"} +# the axis along which the scan is performed. If not provided, it will scan along the axes that are not displayed. +# start: float, optional +# the starting value for the scan (in Angstrom). +# Make sure this value is inside the range of the unit cell, otherwise it will fail. +# stop: float, optional +# the last value of the scan (in Angstrom). +# Make sure this value is inside the range of the unit cell, otherwise it will fail. +# step: float, optional +# the distance between steps in Angstrom. + +# If not provided and `num` is also not provided, it will default to 1 Ang. +# num: int , optional +# the number of steps that you want the scan to consist of. + +# If `step` is passed, this argument is ignored. + +# Note that the grid is only stored once, so having a big number of steps is not that big of a deal. +# breakpoints: array-like, optional +# the discrete points of the scan. To be used if you don't want regular steps. +# If the last step is exactly the length of the cell, it will be moved one dcell back to avoid errors. + +# Note that if this parameter is passed, both `step` and `num` are ignored. +# mode: {"moving_slice", "as_is"}, optional +# the type of scan you want to see. +# "moving_slice" renders a volumetric scan where a slice moves through the grid. +# "as_is" renders each part of the scan as an animation frame. +# (therefore, "as_is" SUPPORTS SCANNING 1D, 2D AND 3D REPRESENTATIONS OF THE GRID, e.g. 
display the volume data for different ranges of z) +# animation_kwargs: dict, optional +# dictionary whose keys and values are directly passed to the animated method as kwargs and therefore +# end up being passed to animation initialization. +# **kwargs: +# the rest of settings that you want to apply to overwrite the existing ones. + +# This settings apply to each plot and go directly to their initialization. + +# Returns +# ------- +# sisl.viz.plotly.Animation +# An animation representation of the scan +# """ +# # Do some checks on the args provided +# if sum(1 for arg in (step, num, breakpoints) if arg is not None) > 1: +# raise ValueError(f"Only one of ('step', 'num', 'breakpoints') should be passed.") + +# axes = self.inputs['axes'] +# if mode == "as_is" and set(axes) - set(["x", "y", "z"]): +# raise ValueError("To perform a scan, the axes need to be cartesian. Please set the axes to a combination of 'x', 'y' and 'z'.") + +# if self.grid.lattice.is_cartesian(): +# grid = self.grid +# else: +# transform_bc = kwargs.pop("transform_bc", self.get_setting("transform_bc")) +# grid, transform_offset = self._transform_grid_cell( +# self.grid, mode=transform_bc, output_shape=self.grid.shape, cval=np.nan +# ) + +# kwargs["offset"] = transform_offset + kwargs.get("offset", self.get_setting("offset")) + +# # We get the key that needs to be animated (we will divide the full range in frames) +# range_key = f"{along}_range" +# along_i = {"x": 0, "y": 1, "z": 2}[along] + +# # Get the full range +# if start is not None and stop is not None: +# along_range = [start, stop] +# else: +# along_range = self.get_setting(range_key) +# if along_range is None: +# range_param = self.get_param(range_key) +# along_range = [range_param[f"inputField.params.{lim}"] for lim in ["min", "max"]] +# if start is not None: +# along_range[0] = start +# if stop is not None: +# along_range[1] = stop + +# if breakpoints is None: +# if step is None and num is None: +# step = 1.0 +# if step is None: +# 
step = (along_range[1] - along_range[0]) / num +# else: +# num = (along_range[1] - along_range[0]) // step + +# # np.linspace will use the last point as a step (and we don't want it) +# # therefore we will add an extra step +# breakpoints = np.linspace(*along_range, int(num) + 1) + +# if breakpoints[-1] == grid.cell[along_i, along_i]: +# breakpoints[-1] = grid.cell[along_i, along_i] - grid.dcell[along_i, along_i] + +# if mode == "moving_slice": +# return self._moving_slice_scan(grid, along_i, breakpoints) +# elif mode == "as_is": +# return self._asis_scan(grid, range_key, breakpoints, animation_kwargs=animation_kwargs, **kwargs) + +# def _asis_scan(self, grid, range_key, breakpoints, animation_kwargs=None, **kwargs): +# """ +# Returns an animation containing multiple frames scaning along an axis. + +# Parameters +# ----------- +# range_key: {'x_range', 'y_range', 'z_range'} +# the key of the setting that is to be animated through the scan. +# breakpoints: array-like +# the discrete points of the scan +# animation_kwargs: dict, optional +# dictionary whose keys and values are directly passed to the animated method as kwargs and therefore +# end up being passed to animation initialization. +# **kwargs: +# the rest of settings that you want to apply to overwrite the existing ones. + +# This settings apply to each plot and go directly to their initialization. + +# Returns +# ---------- +# scan: sisl Animation +# An animation representation of the scan +# """ +# # Generate the plot using self as a template so that plots don't need +# # to read data, just process it and show it differently. 
+# # (If each plot read the grid, the memory requirements would be HUGE) +# scan = self.animated( +# { +# range_key: [[bp, breakpoints[i+1]] for i, bp in enumerate(breakpoints[:-1])] +# }, +# fixed={**{key: val for key, val in self.settings.items() if key != range_key}, **kwargs, "grid": grid}, +# frame_names=[f'{bp:2f}' for bp in breakpoints], +# **(animation_kwargs or {}) +# ) + +# # Set all frames to the same colorscale, if it's a 2d representation +# if len(self.get_setting("axes")) == 2: +# cmin = 10**6; cmax = -10**6 +# for scan_im in scan: +# c = getattr(scan_im.data[0], "value", scan_im.data[0].z) +# cmin = min(cmin, np.min(c)) +# cmax = max(cmax, np.max(c)) +# for scan_im in scan: +# scan_im.update_settings(crange=[cmin, cmax]) + +# scan.get_figure() + +# scan.layout = self.layout + +# return scan + +# def _moving_slice_scan(self, grid, along_i, breakpoints): +# import plotly.graph_objs as go +# ax = along_i +# displayed_axes = [i for i in range(3) if i != ax] +# shape = np.array(grid.shape)[displayed_axes] +# cmin = np.min(grid.grid) +# cmax = np.max(grid.grid) +# x_ax, y_ax = displayed_axes +# x = np.linspace(0, grid.cell[x_ax, x_ax], grid.shape[x_ax]) +# y = np.linspace(0, grid.cell[y_ax, y_ax], grid.shape[y_ax]) + +# fig = go.Figure(frames=[go.Frame(data=go.Surface( +# x=x, y=y, +# z=(bp * np.ones(shape)).T, +# surfacecolor=np.squeeze(grid.cross_section(grid.index(bp, ax), ax).grid).T, +# cmin=cmin, cmax=cmax, +# ), +# name=f'{bp:.2f}' +# ) +# for bp in breakpoints]) + +# # Add data to be displayed before animation starts +# fig.add_traces(fig.frames[0].data) + +# def frame_args(duration): +# return { +# "frame": {"duration": duration}, +# "mode": "immediate", +# "fromcurrent": True, +# "transition": {"duration": duration, "easing": "linear"}, +# } + +# sliders = [ +# { +# "pad": {"b": 10, "t": 60}, +# "len": 0.9, +# "x": 0.1, +# "y": 0, +# "steps": [ +# { +# "args": [[f.name], frame_args(0)], +# "label": str(k), +# "method": "animate", +# } +# for k, 
f in enumerate(fig.frames) +# ], +# } +# ] + +# def ax_title(ax): return f'{["X", "Y", "Z"][ax]} axis [Ang]' + +# # Layout +# fig.update_layout( +# title=f'Grid scan along {["X", "Y", "Z"][ax]} axis', +# width=600, +# height=600, +# scene=dict( +# xaxis=dict(title=ax_title(x_ax)), +# yaxis=dict(title=ax_title(y_ax)), +# zaxis=dict(autorange=True, title=ax_title(ax)), +# aspectmode="data", +# ), +# updatemenus = [ +# { +# "buttons": [ +# { +# "args": [None, frame_args(50)], +# "label": "▶", # play symbol +# "method": "animate", +# }, +# { +# "args": [[None], frame_args(0)], +# "label": "◼", # pause symbol +# "method": "animate", +# }, +# ], +# "direction": "left", +# "pad": {"r": 10, "t": 70}, +# "type": "buttons", +# "x": 0.1, +# "y": 0, +# } +# ], +# sliders=sliders +# ) + +# # We need to add an invisible trace so that the z axis stays with the correct range +# fig.add_trace({"type": "scatter3d", "mode": "markers", "marker_size": 0.001, "x": [0, 0], "y": [0, 0], "z": [0, grid.cell[ax, ax]]}) + +# return fig diff --git a/src/sisl/viz/plots/merged.py b/src/sisl/viz/plots/merged.py new file mode 100644 index 0000000000..a9640ffb82 --- /dev/null +++ b/src/sisl/viz/plots/merged.py @@ -0,0 +1,36 @@ +from typing import Literal, Optional + +from ..figure import Figure, get_figure +from ..plot import Plot +from ..plotters.plot_actions import combined + + +def merge_plots(*figures: Figure, + composite_method: Optional[Literal["multiple", "subplots", "multiple_x", "multiple_y", "animation"]] = None, + backend: Literal["plotly", "matplotlib", "py3dmol", "blender"] = "plotly", + **kwargs +) -> Figure: + """Combines multiple plots into a single figure. + + Parameters + ---------- + *figures : Figure + The figures (or plots) to combine. + composite_method : {"multiple", "subplots", "multiple_x", "multiple_y", "animation", None}, optional + The method to use to combine the plots. None is the same as multiple. 
+ backend : {"plotly", "matplotlib", "py3dmol", "blender"}, optional + The backend to use for the merged figure. + **kwargs + Additional arguments that will be passed to the `_init_figure_*` method of the Figure class. + The arguments accepted here are basically backend specific, but for subplots all backends should + support `rows` and `cols` to specify the number of rows and columns of the subplots, and `arrange` + which controls the arrangement ("rows", "cols" or "square"). + """ + + plot_actions = combined( + *[fig.plot_actions for fig in figures], + composite_method=composite_method, + **kwargs + ) + + return get_figure(plot_actions=plot_actions, backend=backend) diff --git a/src/sisl/viz/plots/orbital_groups_plot.py b/src/sisl/viz/plots/orbital_groups_plot.py new file mode 100644 index 0000000000..4f26884409 --- /dev/null +++ b/src/sisl/viz/plots/orbital_groups_plot.py @@ -0,0 +1,216 @@ +from ..plot import Plot + + +class OrbitalGroupsPlot(Plot): + """Contains methods to manipulate an input accepting groups of orbitals. + + Plots that need this functionality should inherit from this class. + """ + + _orbital_manager_key: str = "orbital_manager" + _orbital_groups_input_key: str = "groups" + + def _matches_group(self, group, query, iReq=None): + """Checks if a query matches a group.""" + if isinstance(query, (int, str)): + query = [query] + + if len(query) == 0: + return True + + return ("name" in group and group.get("name") in query) or iReq in query + + def groups(self, *i_or_names): + """Gets the groups that match your query + + Parameters + ---------- + *i_or_names: str, int + a string (to match the name) or an integer (to match the index), + You can pass as many as you want. + + Note that if you have a list of them you can go like `remove_group(*mylist)` + to spread it and use all items in your list as args. 
+ + If no query is provided, all the groups will be matched + """ + return [req for i, req in enumerate(self.get_input(self._orbital_groups_input_key)) if self._matches_group(req, i_or_names, i)] + + def add_group(self, group = {}, clean=False, **kwargs): + """Adds a new orbitals group. + + The new group can be passed as a dict or as keyword arguments. + The keyword arguments will overwrite what has been passed as a dict if there is conflict. + + Parameters + --------- + group: dict, optional + the new group as a dictionary + clean: boolean, optional + whether the plot should be cleaned before drawing the group. + If `False`, the group will be drawn on top of what is already there. + **kwargs: + parameters of the group can be passed as keyword arguments too. + They will overwrite the values in req + """ + group = {**group, **kwargs} + + groups = [group] if clean else [*self.get_input(self._orbital_groups_input_key), group] + return self.update_inputs(**{self._orbital_groups_input_key: groups}) + + def remove_groups(self, *i_or_names, all=False): + """Removes orbital groups. + + Parameters + ------ + *i_or_names: str, int + a string (to match the name) or an integer (to match the index), + You can pass as many as you want. + + Note that if you have a list of them you can go like `remove_groups(*mylist)` + to spread it and use all items in your list as args + + If no query is provided, all the groups will be matched + """ + if all: + groups = [] + else: + groups = [req for i, req in enumerate(self.get_input(self._orbital_groups_input_key)) if not self._matches_group(req, i_or_names, i)] + + return self.update_inputs(**{self._orbital_groups_input_key: groups}) + + def update_groups(self, *i_or_names, **kwargs): + """Updates existing groups. + + Parameters + ------- + i_or_names: str or int + a string (to match the name) or an integer (to match the index) + this will be used to find the group that you need to update. 
+ + Note that if you have a list of them you can go like `update_groups(*mylist)` + to spread it and use all items in your list as args + + If no query is provided, all the groups will be matched + **kwargs: + keyword arguments containing the values that you want to update + + """ + # We create a new list, otherwise we would be modifying the current one (not good) + groups = list(self.get_input(self._orbital_groups_input_key)) + for i, group in enumerate(groups): + if self._matches_group(group, i_or_names, i): + groups[i] = {**group, **kwargs} + + return self.update_inputs(**{self._orbital_groups_input_key: groups}) + + def split_groups(self, *i_or_names, on="species", only=None, exclude=None, remove=True, clean=False, ignore_constraints=False, **kwargs): + """Splits the orbital groups into multiple groups. + + Parameters + -------- + *i_or_names: str, int + a string (to match the name) or an integer (to match the index), + You can pass as many as you want. + + Note that if you have a list of them you can go like `split_groups(*mylist)` + to spread it and use all items in your list as args + + If no query is provided, all the groups will be matched + on: str, {"species", "atoms", "Z", "orbitals", "n", "l", "m", "zeta", "spin"}, or list of str + the parameter to split along. + + Note that you can combine parameters with a "+" to split along multiple parameters + at the same time. You can get the same effect also by passing a list. See examples. + only: array-like, optional + if desired, the only values that should be plotted out of + all of the values that come from the splitting. + exclude: array-like, optional + values of the splitting that should not be plotted + remove: + whether the splitted groups should be removed. + clean: boolean, optional + whether the plot should be cleaned before drawing. + If False, all the groups that come from the method will + be drawn on top of what is already there. 
+ ignore_constraints: boolean or array-like, optional + determines whether constraints (imposed by the group to be splitted) + on the parameters that we want to split along should be taken into consideration. + + If `False`: all constraints considered. + If `True`: no constraints considered. + If array-like: parameters contained in the list ignore their constraints. + **kwargs: + keyword arguments that go directly to each group. + + This is useful to add extra filters. For example: + If you had a group called "C": + `plot.split_group("C", on="orbitals", spin=[0])` + will split the PDOS on the different orbitals but will take + only the contributions from spin up. + + Examples + ----------- + + >>> # Split groups 0 and 1 along n and l + >>> plot.split_groups(0, 1, on="n+l") + >>> # The same, but this time even if groups 0 or 1 had defined values for "l" + >>> # just ignore them and use all possible values for l. + >>> plot.split_groups(0, 1, on="n+l", ignore_constraints=["l"]) + """ + queries_manager = getattr(self.nodes, self._orbital_manager_key).get() + + old_groups = self.get_input(self._orbital_groups_input_key) + + if len(i_or_names) == 0: + groups = queries_manager.generate_queries( + split=on, only=only, exclude=exclude, **kwargs + ) + else: + reqs = self.groups(*i_or_names) + + groups = [] + for req in reqs: + new_groups = queries_manager._split_query( + req, on=on, only=only, exclude=exclude, + ignore_constraints=ignore_constraints, **kwargs + ) + + groups.extend(new_groups) + + if remove: + old_groups = [req for i, req in enumerate(old_groups) if not self._matches_group(req, i_or_names, i)] + + if not clean: + groups = [*old_groups, *groups] + + return self.update_inputs(**{self._orbital_groups_input_key: groups}) + + def split_orbs(self, on="species", only=None, exclude=None, clean=True, **kwargs): + """ + Splits the orbitals into different groups. 
+ + Parameters + -------- + on: str, {"species", "atoms", "Z", "orbitals", "n", "l", "m", "zeta", "spin"}, or list of str + the parameter to split along. + Note that you can combine parameters with a "+" to split along multiple parameters + at the same time. You can get the same effect also by passing a list. + only: array-like, optional + if desired, the only values that should be plotted out of + all of the values that come from the splitting. + exclude: array-like, optional + values that should not be plotted + clean: boolean, optional + whether the plot should be cleaned before drawing. + If False, all the requests that come from the method will + be drawn on top of what is already there. + **kwargs: + keyword arguments that go directly to each request. + + This is useful to add extra filters. For example: + `plot.split_orbs(on="orbitals", species=["C"])` + will split on the different orbitals but will take + only those that belong to carbon atoms. + """ + return self.split_groups(on=on, only=only, exclude=exclude, clean=clean, **kwargs) diff --git a/src/sisl/viz/plots/pdos.py b/src/sisl/viz/plots/pdos.py index 169c95742b..2d6a8b90d4 100644 --- a/src/sisl/viz/plots/pdos.py +++ b/src/sisl/viz/plots/pdos.py @@ -1,889 +1,80 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-import numpy as np - -import sisl -from sisl.messages import warn +from __future__ import annotations -from ..input_fields import ( - Array1DInput, - BoolInput, - ColorInput, - DistributionInput, - ErangeInput, - FloatInput, - GeometryInput, - IntegerInput, - OptionsInput, - OrbitalQueries, - SileInput, - TextInput, -) -from ..plot import Plot, entry_point -from ..plotutils import find_files, random_color +from typing import Any, Literal, Optional, Sequence, Tuple -try: - import pathos - _do_parallel_calc = True -except Exception: - _do_parallel_calc = False +import numpy as np - -class PdosPlot(Plot): - """ - Plot representation of the projected density of states. +from sisl.viz.types import OrbitalStyleQuery + +from ..data import PDOSData +from ..figure import Figure, get_figure +from ..plot import Plot +from ..plotters.xarray import draw_xarray_xy +from ..processors.data import accept_data +from ..processors.logic import matches, swap +from ..processors.orbital import get_orbital_queries_manager, reduce_orbital_data +from ..processors.xarray import filter_energy_range, scale_variable +from .orbital_groups_plot import OrbitalGroupsPlot + + +def pdos_plot( + pdos_data: PDOSData, + groups: Sequence[OrbitalStyleQuery]=[{"name": "DOS"}], + Erange: Tuple[float, float] = (-2, 2), + E_axis: Literal["x", "y"] = "x", + line_mode: Literal["line", "scatter", "area_line"] = "line", + line_scale: float = 1., + backend: str = "plotly", +) -> Figure: + """Plot the projected density of states. Parameters - ------------- - pdos_file: pdosSileSiesta, optional - This parameter explicitly sets a .PDOS file. Otherwise, the PDOS file - is attempted to read from the fdf file - tbt_nc: tbtncSileTBtrans, optional - This parameter explicitly sets a .TBT.nc file. Otherwise, the PDOS - file is attempted to read from the fdf file - wfsx_file: wfsxSileSiesta, optional - The WFSX file to get the eigenstates. 
In standard SIESTA - nomenclature, this should probably be the *.fullBZ.WFSX file, as it - is the one that contains the eigenstates from the full - brillouin zone. - geometry: Geometry or sile (or path to file) that contains a geometry, optional - If this is passed, the geometry that has been read is ignored and - this one is used instead. - Erange: array-like of shape (2,), optional - Energy range where PDOS is displayed. - distribution: dict, optional - The distribution used for the smearing of the PDOS if calculated by - sisl. It accepts the same types of values as the - `distribution` argument of `EigenstateElectron.PDOS`. - Additionally, it accepts a dictionary containing arguments that are - passed directly to - `sisl.physics.distribution.get_distribution`. E.g.: {"method": - "gaussian", "smearing": 0.01, "x0": 0.0} - Structure of the dict: { 'method': 'smearing': - 'x0': } - nE: int, optional - If calculating the PDOS from a hamiltonian, the number of energy - points used - kgrid: array-like, optional - The number of kpoints in each reciprocal direction. A - Monkhorst-Pack grid will be generated to calculate the PDOS. - If not provided, it will be set to 3 for the periodic directions - and 1 for the non-periodic ones. - kgrid_displ: array-like, optional - Displacement of the Monkhorst-Pack grid - E0: float, optional - The energy to which all energies will be referenced (including - Erange). - requests: array-like of dict, optional - Here you can ask for the specific PDOS that you need. - TIP: Queries can be activated and deactivated. Each item is a - dict. Structure of the dict: { 'name': 'species': - 'atoms': Structure of the dict: { 'index': Structure of - the dict: { 'in': } 'fx': 'fy': - 'fz': 'x': 'y': 'z': 'Z': - 'neighbours': Structure of the dict: { 'range': - 'R': 'neigh_tag': } 'tag': 'seq': } - 'orbitals': 'spin': 'normalize': 'color': - 'linewidth': 'dash': 'split_on': 'scale': - The final DOS will be multiplied by this number. 
} - root_fdf: fdfSileSiesta, optional - Path to the fdf file that is the 'parent' of the results. - results_path: str, optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. - entry_points_order: array-like, optional - Order with which entry points will be attempted. - backend: optional - Directory where the files with the simulations results are - located. This path has to be relative to the root fdf. + ---------- + pdos_data: + The object containing the raw PDOS data (individual PDOS for each orbital/spin). + groups: + List of orbital specifications to filter and accumulate the PDOS. + The contribution of each group will be displayed in a different line. + See showcase notebook for examples. + Erange: + The energy range to plot. + E_axis: + Axis to project the energies. + line_mode: + Mode used to draw the PDOS lines. + line_scale: + Scaling factor for the width of all lines. + backend: + The backend to generate the figure. """ + pdos_data = accept_data(pdos_data, cls=PDOSData, check=True) - #Define all the class attributes - _plot_type = "PDOS" - - _param_groups = ( - - { - "key": "Hparams", - "name": "Hamiltonian related", - "icon": "apps", - "description": "This parameters are meaningful only if you are calculating the PDOS from a Hamiltonian" - }, - - ) - - _parameters = ( - - SileInput( - key = "pdos_file", name = "Path to PDOS file", - dtype=sisl.io.siesta.pdosSileSiesta, - group="dataread", - params = { - "placeholder": "Write the path to your PDOS file here...", - }, - help = """This parameter explicitly sets a .PDOS file. Otherwise, the PDOS file is attempted to read from the fdf file """ - ), - - SileInput( - key = "tbt_nc", name = "Path to the TBT.nc file", - dtype=sisl.io.tbtrans.tbtncSileTBtrans, - group="dataread", - params = { - "placeholder": "Write the path to your TBT.nc file here...", - }, - help = """This parameter explicitly sets a .TBT.nc file. 
Otherwise, the PDOS file is attempted to read from the fdf file """ - ), - - SileInput(key='wfsx_file', name='Path to WFSX file', - dtype=sisl.io.siesta.wfsxSileSiesta, - default=None, - help="""The WFSX file to get the eigenstates. - In standard SIESTA nomenclature, this should probably be the *.fullBZ.WFSX file, as it is the one - that contains the eigenstates from the full brillouin zone. - """ - ), - - GeometryInput( - key = "geometry", name = "Geometry to force on the plot", - group="dataread", - help = """If this is passed, the geometry that has been read is ignored and this one is used instead.""" - ), - - ErangeInput( - key="Erange", - default=[-2, 2], - help = "Energy range where PDOS is displayed." - ), - - DistributionInput( - key="distribution", name="distribution", - default={"method": "gaussian", "smearing": 0.01, "x0": 0.0}, - group="Hparams", - help="""The distribution used for the smearing of the PDOS if calculated by sisl. - It accepts the same types of values as the `distribution` argument of `EigenstateElectron.PDOS`. - Additionally, it accepts a dictionary containing arguments that are passed directly - to `sisl.physics.distribution.get_distribution`. E.g.: {"method": "gaussian", - "smearing": 0.01, "x0": 0.0} - """ - ), - - IntegerInput( - key="nE", name="Number of energy points", - group="Hparams", - default=100, - help="""If calculating the PDOS from a hamiltonian, the number of energy points used""" - ), - - Array1DInput(key="kgrid", name="Monkhorst-Pack grid", - default=None, - group="Hparams", - params={ - "shape": (3,) - }, - help="""The number of kpoints in each reciprocal direction. - A Monkhorst-Pack grid will be generated to calculate the PDOS. 
- If not provided, it will be set to 3 for the periodic directions - and 1 for the non-periodic ones.""" - ), - - Array1DInput(key="kgrid_displ", name="Monkhorst-Pack grid displacement", - default=[0, 0, 0], - group="Hparams", - help="""Displacement of the Monkhorst-Pack grid""" - ), - - FloatInput(key="E0", name="Reference energy", - default=0, - help="""The energy to which all energies will be referenced (including Erange).""" - ), - - OrbitalQueries( - key = "requests", name = "PDOS queries", - default = [{"active": True, "name": "DOS", "species": None, "atoms": None, "orbitals": None, "spin": None, "normalize": False, "color": "black", "linewidth": 1}], - help = """Here you can ask for the specific PDOS that you need. -
TIP: Queries can be activated and deactivated.""", - queryForm = [ - - TextInput( - key="name", name="Name", - default="DOS", - params={ - "placeholder": "Name of the line..." - }, - ), - - 'species', 'atoms', 'orbitals', 'spin', - - BoolInput( - key="normalize", name="Normalize", - default=False, - params={ - "offLabel": "No", - "onLabel": "Yes" - } - ), - - ColorInput( - key="color", name="Line color", - default=None, - ), - - FloatInput( - key="linewidth", name="Line width", - default=1, - ), - - OptionsInput( - key="dash", name="Line style", - default="solid", - params={ - "isMulti": False, - "isClearable": False, - "isSearchable": True, - "options": [{"value": option, "label": option} for option in ("solid", "dot", "dash", "longdash", "dashdot", "longdashdot")] - } - ), - - OptionsInput( - key="split_on", name="Split", - default=None, - params={ - "isMulti": True, - "isSearchable": True, - "options": [{"value": option, "label": option} for option in ("species", "atoms", "Z", "orbitals", "spin", "n", "l", "m", "zeta")] - } - ), - - FloatInput( - key="scale", name="Scale", - default=1, - params={"min": None}, - help="The final DOS will be multiplied by this number." 
- ) - ] - ), + E_PDOS = filter_energy_range(pdos_data, Erange=Erange, E0=0) + orbital_manager = get_orbital_queries_manager(pdos_data) + groups_data = reduce_orbital_data( + E_PDOS, groups=groups, orb_dim="orb", spin_dim="spin", sanitize_group=orbital_manager, + group_vars=('color', 'size', 'dash'), groups_dim="group", drop_empty=True, + spin_reduce=np.sum, ) - _shortcuts = { - - } - - @classmethod - def _default_animation(self, wdir = None, frame_names = None, **kwargs): - - pdos_files = find_files(wdir, "*.PDOS.xml", sort = True) - - if not pdos_files: - pdos_files = find_files(wdir, "*.PDOS", sort = True) - - def _get_frame_names(self): - - return [child_plot.get_setting("pdos_file").name for child_plot in self.child_plots] - - return PdosPlot.animated("pdos_file", pdos_files, frame_names = _get_frame_names, wdir = wdir, **kwargs) - - def _after_init(self): - - self._add_shortcuts() - - def _add_shortcuts(self): - - self.add_shortcut( - "o", "Split on orbitals", - self.split_DOS, on="orbitals", - _description="Split the total DOS along the different orbitals" - ) - - self.add_shortcut( - "s", "Split on species", - self.split_DOS, on="species", - _description="Split the total DOS along the different species" - ) - - self.add_shortcut( - "a", "Split on atoms", - self.split_DOS, on="atoms", - _description="Split the total DOS along the different atoms" - ) - - self.add_shortcut( - "p", "Split on spin", - self.split_DOS, on="spin", - _description="Split the total DOS along the different spin" - ) - - @entry_point('siesta output', 0) - def _read_siesta_output(self, pdos_file): - """ - Reads the pdos from a SIESTA .PDOS file. - """ - #Get the info from the .PDOS file - self.geometry, self.E, self.PDOS = self.get_sile(pdos_file or "pdos_file").read_data() - - @entry_point("TB trans", 2) - def _read_TBtrans(self, root_fdf, tbt_nc): - """ - Reads the PDOS from a *.TBT.nc file coming from a TBtrans run. 
- """ - #Get the info from the .PDOS file - tbt_sile = self.get_sile("tbt_nc") - self.PDOS = tbt_sile.DOS(sum=False).data.T - self.E = tbt_sile.E - - read_geometry_kwargs = {} - # Try to get the basis information from the root_fdf, if possible - try: - read_geometry_kwargs["atom"] = self.get_sile("root_fdf").read_geometry(output=True).atoms - except (FileNotFoundError, TypeError): - pass - - # Read the geometry from the TBT.nc file and get only the device part - self.geometry = tbt_sile.read_geometry(**read_geometry_kwargs).sub(tbt_sile.a_dev) - - @entry_point('wfsx file', 3) - def _read_from_wfsx(self, root_fdf, wfsx_file, Erange, nE, E0, distribution): - """Generates the PDOS values from a file containing eigenstates.""" - # Read the hamiltonian. We need it because we need the overlap matrix. - if not hasattr(self, "H"): - self.setup_hamiltonian() - - if self.H is None: - raise ValueError("No hamiltonian found, and we need the overlap matrix to calculate the PDOS.") - - # Get the wfsx file - wfsx_sile = self.get_sile(wfsx_file or "wfsx_file", parent=self.H) - - # Read the sizes of the file, which contain the number of spin channels - # and the number of orbitals and the number of k points. - sizes = wfsx_sile.read_sizes() - # Check that spin sizes of hamiltonian and wfsx file match - assert self.H.spin.size == sizes.nspin, \ - f"Hamiltonian has spin size {self.H.spin.size} while file has spin size {sizes.nspin}" - # Get the size of the spin channel. The size returned might be 8 if it is a spin-orbit - # calculation, but we need only 4 spin channels (total, x, y and z), same as with non-colinear - nspin = min(4, sizes.nspin) - - # Get the energies for which we need to calculate the PDOS. - self.E = np.linspace(Erange[0], Erange[-1], nE) + E0 - - # Initialize the PDOS array - self.PDOS = np.zeros((nspin, sizes.no_u, self.E.shape[0]), dtype=np.float64) - - # Loop through eigenstates in the WFSX file and add their contribution to the PDOS. 
- # Note that we pass the hamiltonian as the parent here so that the overlap matrix - # for each point can be calculated by eigenstate.PDOS() - for eigenstate in wfsx_sile.yield_eigenstate(): - if nspin == 2: - spin = eigenstate.info.get("spin", 0) - else: - spin = slice(None) - - self.PDOS[spin] += eigenstate.PDOS(self.E, distribution=distribution) * eigenstate.info.get("weight", 1) - - if nspin == 2: - # Convert from spin components to total and z contributions. - total = self.PDOS[0] + self.PDOS[1] - z = self.PDOS[0] - self.PDOS[1] - - self.PDOS[0] = total - self.PDOS[1] = z - - @entry_point('hamiltonian', 4) - def _read_from_H(self, kgrid, kgrid_displ, Erange, nE, E0, distribution): - """ - Calculates the PDOS from a sisl Hamiltonian. - """ - if not hasattr(self, "H"): - self.setup_hamiltonian() - - if self.H is None: - raise ValueError("No hamiltonian found.") - - # Get the kgrid or generate a default grid by checking the interaction between cells - # This should probably take into account how big the cell is. - if kgrid is None: - kgrid = [3 if nsc > 1 else 1 for nsc in self.H.geometry.nsc] - - if Erange is None: - raise ValueError('You need to provide an energy range to calculate the PDOS from the Hamiltonian') - - self.E = np.linspace(Erange[0], Erange[-1], nE) + E0 - - self.bz = sisl.MonkhorstPack(self.H, kgrid, kgrid_displ) - - # Define the available spins - spin_indices = [0] - if self.H.spin.is_polarized: - spin_indices = [0, 1] - - # Calculate the PDOS for all available spins - PDOS = [] - for spin in spin_indices: - with self.bz.apply(pool=_do_parallel_calc) as parallel: - spin_PDOS = parallel.average.eigenstate( - spin=spin, - wrap=lambda eig: eig.PDOS(self.E, distribution=distribution) - ) - - PDOS.append(spin_PDOS) - - if len(spin_indices) == 1: - PDOS = PDOS[0] - else: - # Convert from spin components to total and z contributions. 
- total = PDOS[0] + PDOS[1] - z = PDOS[0] - PDOS[1] - - PDOS = np.concatenate([total, z]) - - self.PDOS = PDOS - - def _after_read(self, geometry): - """ - Creates the PDOS dataarray and updates the "requests" input field. - """ - from xarray import DataArray - - if self.PDOS.ndim == 2: - # Add an extra axis for spin at the beggining if the array only has dimensions for orbitals and energy. - self.PDOS = self.PDOS[None, ...] - - # Check if the PDOS contains spin resolution (there should be three dimensions, - # and the first one should be the spin components) - self.spin = sisl.Spin({ - 1: sisl.Spin.UNPOLARIZED, - 2: sisl.Spin.POLARIZED, - 4: sisl.Spin.NONCOLINEAR - }[self.PDOS.shape[0]]) - - # Set the geometry. - if geometry is not None: - if geometry.no != self.PDOS.shape[1]: - raise ValueError(f"The geometry provided contains {geometry.no} orbitals, while we have PDOS information of {self.PDOS.shape[1]}.") - self.geometry = geometry - - self.get_param('requests').update_options(self.geometry, self.spin) - - dims = ('spin', 'orb', 'E') - coords = {'E': self.E} - if self.spin.is_polarized: - coords['spin'] = ['total', 'z'] - elif not self.spin.is_diagonal: - coords['spin'] = self.get_param('requests').get_options("spin") - - self.PDOS = DataArray(self.PDOS, coords=coords, dims=dims) - - def _set_data(self, requests, E0, Erange): - - # Get only the energies we are interested in - Erange = np.array(Erange) - if Erange is None: - Emin, Emax = [min(self.PDOS.E.values), max(self.PDOS.E.values)] - else: - Emin, Emax = Erange + E0 - - # Get only the part of the arra - E_PDOS = self.PDOS.where( - (self.PDOS.E > Emin) & (self.PDOS.E < Emax), drop=True) - - # Build the dictionary that will be passed to the backend - for_backend = {"Es": E_PDOS.E.values - E0, "PDOS_values": {}, "request_metadata": {}} - - # Go request by request and extract the corresponding PDOS contribution - for request in requests: - self._get_request_PDOS(request, E_PDOS, 
values_storage=for_backend["PDOS_values"], metadata_storage=for_backend["request_metadata"]) - - return for_backend - - @staticmethod - def _select_spin(dataarray, spin): - nspin = len(dataarray.spin) - - if nspin == 1: - if spin is not None: - warn(f"There is no spin information but the spin request is {spin}, ignoring spin request.") - return dataarray - - if spin is None: - spin = "total" - - if nspin == 4: - # Non colinear spin calculation, just select from the dataarray - dataarray = dataarray.sel(spin=spin) - elif nspin == 2: - # Spin polarized calculation. The information is stored as cartesian, but the user - # might request spin components. - if not isinstance(spin, (int, str)): - assert len(spin) == 1 - spin = spin[0] - - if spin in ('total', 'z'): - dataarray = dataarray.sel(spin=spin) - else: - total = dataarray.sel(spin="total") - z = dataarray.sel(spin="z") - - if spin == 0: - dataarray = (total + z) / 2 - elif spin == 1: - dataarray = (total - z) / 2 - else: - raise ValueError(f"Incorrect spin request for spin polarized data: {spin}") - - return dataarray - - def _get_request_PDOS(self, request, E_PDOS=None, values_storage=None, metadata_storage=None): - """Extracts the PDOS values that correspond to a specific request. - - This has been made a function so that it can call itself recursively - to support splitting individual requests. - Parameters - -------------- - request: dict - the request to process - E_PDOS: DataArray - the part of the PDOS dataarray that falls in the energy range that we want to draw. - If not, provided the full PDOS data stored in `self.PDOS` is used. - values_storage: dict, optional - a dictionary where the PDOS values will be stored using the request's name as the key. - metadata_storage: dict, optional - a dictionary where metadata for the request will be stored using the request's name as the key. 
- Returns - ---------- - np.ndarray - PDOS values obtained from the request - """ - - # Get the full PDOS data if a filtered PDOS has not been provided - if E_PDOS is None: - E_PDOS = self.PDOS - - # Get the requests parameter, which will be needed to retrieve available options - # and get the list of orbitals that correspond to a given request. - requests_param = self.get_param("requests") - - request = self._new_request(**request) - - # If the request has an split_on parameter that is not None, - # we are going to split the request in place. Note that you can also - # split requests or the full DOS using the `split_requests` and `split_DOS` - # methods, but this may be more convenient for the GUI. - if request["split_on"]: - - # We are going to give a different dash style to each obtained request - dash_options = requests_param["dash"].options - n_dash_options = len(dash_options) - def query_gen(i=[-1], **kwargs): - i[0] += 1 - return self._new_request(**{**kwargs, "dash": dash_options[i[0] % n_dash_options]}) - - # And ensure they all have the same color (if the color is None, - # each request will show up with a different color) - request["color"] = request["color"] or random_color() - - # Now, get all the requests that emerge from splitting the "parent" request - # Note that we need to set split_on to None for the new requests, otherwise the - # cycle would be infinite - splitted_request = requests_param._split_query(request, on=request["split_on"], split_on=None, query_gen=query_gen, vary="dash") - # Now that we have them, process them - for req in splitted_request: - self._get_request_PDOS(req, E_PDOS, values_storage=values_storage, metadata_storage=metadata_storage) - # Since we have already drawn all the requests, we don't need to do anything else - # This would not be true if we wanted to represent the "total request" as well, but we - # don't give that option yet. 
Just removing the return would draw the total - return + # Determine what goes on each axis + x = matches(E_axis, "x", ret_true="E", ret_false="PDOS") + y = matches(E_axis, "y", ret_true="E", ret_false="PDOS") - # From now on, the code focuses on actually extracting the PDOS values for the request + dependent_axis = swap(E_axis, ("x", "y")) - # Use only the active requests - if not request["active"]: - return + # A PlotterNode gets the processed data and creates abstract actions (backend agnostic) + # that should be performed on the figure. The output of this node + # must be fed to a figure (backend specific). + final_groups_data = scale_variable(groups_data, var="size", scale=line_scale, default_value=1) + plot_actions = draw_xarray_xy(data=final_groups_data, x=x, y=y, width="size", what=line_mode, dependent_axis=dependent_axis) - orb = requests_param.get_orbitals(request) + return get_figure(backend=backend, plot_actions=plot_actions) - if len(orb) == 0: - # This request does not match any possible orbital - return +class PdosPlot(OrbitalGroupsPlot): - req_PDOS = E_PDOS.sel(orb=orb) - req_PDOS = self._select_spin(req_PDOS, request['spin']) - - reduce_coords = set(["orb", "spin"]).intersection(req_PDOS.dims) - - if request["normalize"]: - req_PDOS = req_PDOS.mean(reduce_coords) - else: - req_PDOS = req_PDOS.sum(reduce_coords) - - # Finally, multiply the values by the scale factor - values = req_PDOS.values * request["scale"] - req_name = request["name"] - - if values_storage is not None: - if req_name in values_storage: - raise ValueError(f"There are multiple requests that are named '{req_name}'") - values_storage[req_name] = values - - if metadata_storage is not None: - # Build the dictionary that contains metadata for this request. 
- metadata = { - "style": { - "line": {'width': request["linewidth"], "color": request["color"], "dash": request["dash"]} - } - } - - metadata_storage[req_name] = metadata - - return values - - # ---------------------------------- - # CONVENIENCE METHODS - # ---------------------------------- - - def _matches_request(self, request, query, iReq=None): - """ - Checks if a query matches a PDOS request - """ - if isinstance(query, (int, str)): - query = [query] - - if len(query) == 0: - return True - - return ("name" in request and request.get("name") in query) or iReq in query - - def _new_request(self, **kwargs): - - complete_req = self.get_param("requests").complete_query - - if "spin" not in kwargs and not self.spin.is_diagonal: - if "spin" not in kwargs.get("split_on", ""): - kwargs["spin"] = ["total"] - - return complete_req({"name": str(len(self.settings["requests"])), **kwargs}) - - def requests(self, *i_or_names): - """ - Gets the requests that match your query - - Parameters - ---------- - *i_or_names: str, int - a string (to match the name) or an integer (to match the index), - You can pass as many as you want. - - Note that if you have a list of them you can go like `remove_request(*mylist)` - to spread it and use all items in your list as args. - - If no query is provided, all the requests will be matched - """ - return [req for i, req in enumerate(self.get_setting("requests")) if self._matches_request(req, i_or_names, i)] - - def add_request(self, req = {}, clean=False, **kwargs): - """ - Adds a new PDOS request. The new request can be passed as a dict or as a list of keyword arguments. - The keyword arguments will overwrite what has been passed as a dict if there is conflict. - - Parameters - --------- - req: dict, optional - the new request as a dictionary - clean: boolean, optional - whether the plot should be cleaned before drawing the request. - If `False`, the request will be drawn on top of what is already there. 
- **kwargs: - parameters of the request can be passed as keyword arguments too. - They will overwrite the values in req - """ - request = self._new_request(**{**req, **kwargs}) - - try: - requests = [request] if clean else [*self.settings["requests"], request] - self.update_settings(requests=requests) - except Exception as e: - warn("There was a problem with your new request ({}): \n\n {}".format(request, e)) - self.undo_settings() - - return self - - def remove_requests(self, *i_or_names, all=False, update_fig=True): - """ - Removes requests from the PDOS plot - - Parameters - ------ - *i_or_names: str, int - a string (to match the name) or an integer (to match the index), - You can pass as many as you want. - - Note that if you have a list of them you can go like `remove_requests(*mylist)` - to spread it and use all items in your list as args - - If no query is provided, all the requests will be matched - """ - if all: - requests = [] - else: - requests = [req for i, req in enumerate(self.get_setting("requests", copy=False)) if not self._matches_request(req, i_or_names, i)] - - return self.update_settings(run_updates=update_fig, requests=requests) - - def update_requests(self, *i_or_names, **kwargs): - """ - Updates an existing request - - Parameters - ------- - i_or_names: str or int - a string (to match the name) or an integer (to match the index) - this will be used to find the request that you need to update. 
- - Note that if you have a list of them you can go like `update_requests(*mylist)` - to spread it and use all items in your list as args - - If no query is provided, all the requests will be matched - **kwargs: - keyword arguments containing the values that you want to update - - """ - # We create a new list, otherwise we would be modifying the current one (not good) - requests = list(self.get_setting("requests", copy=False)) - for i, request in enumerate(requests): - if self._matches_request(request, i_or_names, i): - requests[i] = {**request, **kwargs} - - return self.update_settings(requests=requests) - - def _NOT_WORKING_merge_requests(self, *i_or_names, remove=True, clean=False, **kwargs): - """ - Merge multiple requests into one. - - Parameters - ------ - *i_or_names: str, int - a string (to match the name) or an integer (to match the index), - You can pass as many as you want. - - Note that if you have a list of them you can go like `merge_requests(*mylist)` - to spread it and use all items in your list as args - - If no query is provided, all the requests will be matched - remove: boolean, optional - whether the merged requests should be removed. - If False, they will be kept in the plot - clean: boolean, optional - whether all requests should be removed before drawing the merged request - **kwargs: - keyword arguments that go directly to the new request. - - You can use them to set other attributes to the request. For example: - `plot.merge_requests(on="orbitals", species=["C"])` - will split the PDOS on the different orbitals but will take - only those that belong to carbon atoms. - """ - keys = ["atoms", "Z", "orbitals", "species", "spin", "n", "l", "m", "zeta"] - - # Merge all the requests (nice tree I built here, isn't it? 
:) ) - new_request = {key: [] for key in keys} - for i, request in enumerate(self.get_setting("requests", copy=False)): - if self._matches_request(request, i_or_names, i): - for key in keys: - if request.get(key, None) is not None: - val = request[key] - if key == "atoms": - val = self.geometry._sanitize_atoms(val) - val = np.atleast_1d(val) - new_request[key] = [*new_request[key], *val] - - # Remove duplicate values for each key - # and if it's an empty list set it to None (empty list returns no PDOS) - for key in keys: - new_request[key] = list(set(new_request[key])) or None - - # Remove the merged requests if desired - if remove: - self.remove_requests(*i_or_names, update_fig=False) - - return self.add_request(**new_request, **kwargs, clean=clean) - - def split_requests(self, *i_or_names, on="species", only=None, exclude=None, remove=True, clean=False, ignore_constraints=False, **kwargs): - """ - Splits the desired requests into multiple requests - - Parameters - -------- - *i_or_names: str, int - a string (to match the name) or an integer (to match the index), - You can pass as many as you want. - - Note that if you have a list of them you can go like `split_requests(*mylist)` - to spread it and use all items in your list as args - - If no query is provided, all the requests will be matched - on: str, {"species", "atoms", "Z", "orbitals", "n", "l", "m", "zeta", "spin"}, or list of str - the parameter to split along. - - Note that you can combine parameters with a "+" to split along multiple parameters - at the same time. You can get the same effect also by passing a list. See examples. - only: array-like, optional - if desired, the only values that should be plotted out of - all of the values that come from the splitting. - exclude: array-like, optional - values of the splitting that should not be plotted - remove: - whether the splitted requests should be removed. - clean: boolean, optional - whether the plot should be cleaned before drawing. 
- If False, all the requests that come from the method will - be drawn on top of what is already there. - ignore_constraints: boolean or array-like, optional - determines whether constraints (imposed by the request to be splitted) - on the parameters that we want to split along should be taken into consideration. - - If `False`: all constraints considered. - If `True`: no constraints considered. - If array-like: parameters contained in the list ignore their constraints. - **kwargs: - keyword arguments that go directly to each request. - - This is useful to add extra filters. For example: - If you had a request called "C": - `plot.split_request("C", on="orbitals", spin=[0])` - will split the PDOS on the different orbitals but will take - only the contributions from spin up. - - Examples - ----------- - - >>> plot = H.plot.pdos(requests=[...]) - >>> - >>> # Split requests 0 and 1 along n and l - >>> plot.split_requests(0, 1, on="n+l") - >>> # The same, but this time even if requests 0 or 1 had defined values for "l" - >>> # just ignore them and use all possible values for l. - >>> plot.split_requests(0, 1, on="n+l", ignore_constraints=["l"]) - """ - reqs = self.requests(*i_or_names) - - requests = [] - for req in reqs: - - new_requests = self.get_param("requests")._split_query( - req, on=on, only=only, exclude=exclude, req_gen=self._new_request, - ignore_constraints=ignore_constraints, **kwargs - ) - - requests.extend(new_requests) - - if remove: - self.remove_requests(*i_or_names, update_fig=False) - - if not clean: - requests = [*self.get_setting("requests", copy=False), *requests] - - return self.update_settings(requests=requests) + function = staticmethod(pdos_plot) def split_DOS(self, on="species", only=None, exclude=None, clean=True, **kwargs): """ @@ -922,11 +113,5 @@ def split_DOS(self, on="species", only=None, exclude=None, clean=True, **kwargs) >>> # be replaced by the value of n. 
>>> plot.split_DOS(on="n+l", species=["Au"], name="Au $ns") """ - requests = self.get_param('requests')._generate_queries( - on=on, only=only, exclude=exclude, query_gen=self._new_request, **kwargs) - - # If the user doesn't want to clean the plot, we will just add the requests to the existing ones - if not clean: - requests = [*self.get_setting("requests", copy=False), *requests] - - return self.update_settings(requests=requests) + return self.split_groups(on=on, only=only, exclude=exclude, clean=clean, **kwargs) + diff --git a/src/sisl/viz/plots/tests/.coverage b/src/sisl/viz/plots/tests/.coverage new file mode 100644 index 0000000000..1859d98b1b Binary files /dev/null and b/src/sisl/viz/plots/tests/.coverage differ diff --git a/src/sisl/viz/plots/tests/__init__.py b/src/sisl/viz/plots/tests/__init__.py index 448bb8652d..e69de29bb2 100644 --- a/src/sisl/viz/plots/tests/__init__.py +++ b/src/sisl/viz/plots/tests/__init__.py @@ -1,3 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. diff --git a/src/sisl/viz/plots/tests/conftest.py b/src/sisl/viz/plots/tests/conftest.py deleted file mode 100644 index 1c44f5f86c..0000000000 --- a/src/sisl/viz/plots/tests/conftest.py +++ /dev/null @@ -1,87 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-import importlib -import os.path as osp - -import pytest - - -@pytest.fixture(scope="session") -def importables(): - # Find out which packages are impo - importables_info = {True: [], False: []} - for modname in ("pandas", "xarray", "bpy", "pathos", "dill", "tqdm", - "skimage", "plotly", "matplotlib"): - try: - importlib.import_module(modname) - importables_info[True].append(modname) - except ImportError: - importables_info[False].append(modname) - - return importables_info - - -@pytest.fixture(scope="session") -def siesta_test_files(sisl_files): - - def _siesta_test_files(path): - return sisl_files(osp.join('sisl', 'io', 'siesta', path)) - - return _siesta_test_files - - -@pytest.fixture(scope="session") -def vasp_test_files(sisl_files): - - def _siesta_test_files(path): - return sisl_files(osp.join('sisl', 'io', 'vasp', path)) - - return _siesta_test_files - - -class _TestPlot: - - @pytest.fixture(scope="class", params=[None]) - def backend(self, request): - return request.param - - @pytest.fixture(scope="class") - def plot(self, backend, init_func_and_attrs, importables): - """Initializes the plot using the initializin function. - - If the plot can't be initialized it skips all tests for that plot. - """ - init_func = init_func_and_attrs[0] - - msg = "" - if importables[False]: - msg = ", ".join(importables[False]) + " is/are not importable" - if importables[True]: - msg = f'{", ".join(importables[True]) + " is/are importable"}; {msg}' - - try: - yield init_func(backend=backend, _debug=True) - except Exception as e: - pytest.xfail(f"Plot was not initialized. Error: {e}. \n\n{msg}") - - # If we are testing with the matplotlib backend, close all the figures that - # might have been created. - if backend == "matplotlib": - import matplotlib.pyplot as plt - - plt.close("all") - - @pytest.fixture(scope="class") - def test_attrs(self, init_func_and_attrs): - """Checks that all the attributes required for testing have been passed. - - Otherwise, tests are skipped. 
- """ - attrs = init_func_and_attrs[1] - - missing_attrs = set(getattr(self, "_required_attrs", [])) - set(attrs) - if len(missing_attrs) > 0: - pytest.skip(f"Tests could not be ran because some testing attributes are missing: {missing_attrs}") - - return attrs diff --git a/src/sisl/viz/plots/tests/test_bands.py b/src/sisl/viz/plots/tests/test_bands.py index 70f7f4bc74..ae52cc45f3 100644 --- a/src/sisl/viz/plots/tests/test_bands.py +++ b/src/sisl/viz/plots/tests/test_bands.py @@ -1,242 +1,21 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -""" - -Tests specific functionality of the bands plot. - -Different inputs are tested (siesta .bands and sisl Hamiltonian). - -""" -import itertools -from functools import partial - -import numpy as np import pytest -import sisl -from sisl.viz import BandsPlot -from sisl.viz.plots.tests.conftest import _TestPlot - -pytestmark = [pytest.mark.viz, pytest.mark.plotly] - - -class TestBandsPlot(_TestPlot): - - _required_attrs = [ - "bands_shape", # Tuple specifying the shape of the bands dataarray - "gap", # Float. The value of the gap in eV - "ticklabels", # Array-like with the tick labels - "tickvals", # Array-like with the expected positions of the ticks - "spin_texture", # Whether spin texture should be possible to draw or not. 
- "spin", # The spin class of the calculation - ] - - @pytest.fixture(scope="class", params=[None, *BandsPlot.get_class_param("backend").options]) - def backend(self, request): - return request.param - - @pytest.fixture(scope="class", params=[ - # From .bands file - "siesta_output", - # From a hamiltonian - "sisl_H_unpolarized", "sisl_H_polarized", "sisl_H_noncolinear", "sisl_H_spinorbit", - "sisl_H_path_unpolarized", - # From a .bands.WFSX file - "wfsx_file" - ]) - def init_func_and_attrs(self, request, siesta_test_files): - name = request.param - - if name == "siesta_output": - # From a siesta .bands file - init_func = sisl.get_sile(siesta_test_files("SrTiO3.bands")).plot - attrs = { - "bands_shape": (150, 72), - "ticklabels": ('Gamma', 'X', 'M', 'Gamma', 'R', 'X'), - "tickvals": [0.0, 0.429132, 0.858265, 1.465149, 2.208428, 2.815313], - "gap": 1.677, - "spin_texture": False, - "spin": sisl.Spin("") - } - elif name == "wfsx_file": - # From the SIESTA .bands.WFSX file - fdf = sisl.get_sile(siesta_test_files("bi2se3_3ql.fdf")) - wfsx = siesta_test_files("bi2se3_3ql.bands.WFSX") - init_func = partial(fdf.plot.bands, wfsx_file=wfsx, E0=-51.68, entry_points_order=["wfsx file"]) - attrs = { - "bands_shape": (16, 8), - "ticklabels": None, - "tickvals": None, - "gap": 0.0575, - "spin_texture": False, - "spin": sisl.Spin("nc") - } - - elif name.startswith("sisl_H"): - gr = sisl.geom.graphene() - H = sisl.Hamiltonian(gr) - H.construct([(0.1, 1.44), (0, -2.7)]) - - spin_type = name.split("_")[-1] - n_spin, H = { - "unpolarized": (0, H), - "polarized": (2, H.transform(spin=sisl.Spin.POLARIZED)), - "noncolinear": (0, H.transform(spin=sisl.Spin.NONCOLINEAR)), - "spinorbit": (0, H.transform(spin=sisl.Spin.SPINORBIT)) - }.get(spin_type) - - n_states = 2 - if not H.spin.is_diagonal: - n_states *= 2 - - # Let's create the same graphene bands plot using the hamiltonian - # from two different prespectives - if name.startswith("sisl_H_path"): - # Passing a list of points (as if we 
were interacting from a GUI) - # We want 6 points in total. This is the path that we want to get: - # [0,0,0] --2-- [2/3, 1/3, 0] --1-- [1/2, 0, 0] - path = [{"active": True, "x": x, "y": y, "z": z, "divisions": 3, - "name": tick} for tick, (x, y, z) in zip(["Gamma", "M", "K"], [[0, 0, 0], [2/3, 1/3, 0], [1/2, 0, 0]])] - path[-1]['divisions'] = 2 - - init_func = partial(H.plot.bands, band_structure=path) - else: - # Directly creating a BandStructure object - bz = sisl.BandStructure(H, [[0, 0, 0], [2/3, 1/3, 0], [1/2, 0, 0]], 6, ["Gamma", "M", "K"]) - init_func = bz.plot - - attrs = { - "bands_shape": (6, n_spin, n_states) if n_spin != 0 else (6, n_states), - "ticklabels": ["Gamma", "M", "K"], - "tickvals": [0., 1.70309799, 2.55464699], - "gap": 0, - "spin_texture": not H.spin.is_diagonal, - "spin": H.spin - } - - return init_func, attrs - - def _check_bands_array(self, bands, spin, expected_shape): - pytest.importorskip("xarray") - from xarray import DataArray - - assert isinstance(bands, DataArray) - - if spin.is_polarized: - expected_coords = ('k', 'spin', 'band') - else: - expected_coords = ('k', 'band') - - assert set(bands.dims) == set(expected_coords) - assert bands.transpose(*expected_coords).shape == expected_shape - - def test_bands_dataarray(self, plot, test_attrs): - """ - Check that the data array was created and contains the correct information. - """ - # Check that there is a bands attribute - assert hasattr(plot, 'bands') - - self._check_bands_array(plot.bands, test_attrs["spin"], test_attrs['bands_shape']) - - def test_bands_filtered(self, plot, test_attrs): - # Check that we can correctly filter the bands to draw. - plot.update_settings(bands_range=[0, 1], Erange=None) - - # Check that the filtered bands are correctly passed to the backend - assert "draw_bands" in plot._for_backend - assert "filtered_bands" in plot._for_backend["draw_bands"] - - # Check that everything is fine with the dimensions of the filtered bands. 
Since we filtered, - # it should contain only one band - filtered_bands = plot._for_backend["draw_bands"]["filtered_bands"] - self._check_bands_array(filtered_bands, test_attrs["spin"], (*test_attrs["bands_shape"][:-1], 1)) - - def test_gap(self, plot, test_attrs): - # Check that we can calculate the gap correctly - # Allow for a small variability just in case there - # are precision differences - assert abs(plot.gap - test_attrs['gap']) < 0.01 - - def test_gap_to_backend(self, plot, test_attrs): - # Check that the gap is correctly transmitted to the backend - plot.update_settings(gap=False, custom_gaps=[]) - assert len(plot._for_backend["gaps"]) == 0 - - plot.update_settings(gap=True) - assert len(plot._for_backend["gaps"]) > 0 - for gap in plot._for_backend["gaps"]: - assert len(set(["ks", "Es", "color", "name"]) - set(gap)) == 0 - assert abs(np.diff(gap["Es"]) - test_attrs['gap']) < 0.01 - - def test_custom_gaps_to_backend(self, plot, test_attrs): - if test_attrs['ticklabels'] is None: - return - - plot.update_settings(gap=False, custom_gaps=[]) - assert len(plot._for_backend["gaps"]) == 0 - - gaps = list(itertools.combinations(test_attrs['ticklabels'], 2)) - - plot.update_settings(custom_gaps=[{"from": gap[0], "to": gap[1], "spin": [0]} for gap in gaps]) - - assert len(plot._for_backend["gaps"]) + len(gaps) - for gap in plot._for_backend["gaps"]: - assert len(set(["ks", "Es", "color", "name"]) - set(gap)) == 0 - - def test_custom_gaps_correct(self, plot, test_attrs): - if test_attrs['ticklabels'] is None: - return - - # Generate custom gaps from labels - gaps = list(itertools.combinations(test_attrs['ticklabels'], 2)) - plot.update_settings(custom_gaps=[{"from": gap[0], "to": gap[1]} for gap in gaps]) - - gaps_from_labels = np.unique([np.diff(gap["Es"]) for gap in plot._for_backend["gaps"]]) - - # Generate custom gaps from k values - gaps = list(itertools.combinations(test_attrs['tickvals'], 2)) - plot.update_settings(custom_gaps=[{"from": gap[0], "to": gap[1]} 
for gap in gaps]) - - # Check that we get the same values for the gaps - assert abs(gaps_from_labels.sum() - np.unique([np.diff(gap["Es"]) for gap in plot._for_backend["gaps"]]).sum()) < 0.03 - - # We have finished with all the gaps tests here, so just clean up before continuing - plot.update_settings(custom_gaps=[], gap=False) - - def test_spin_moments(self, plot, test_attrs): - if not test_attrs["spin_texture"]: - return - pytest.importorskip("xarray") - from xarray import DataArray - - # Check that spin moments have been calculated - assert hasattr(plot, "spin_moments") - - # Check that it is a dataarray containing the right information - spin_moments = plot.spin_moments - assert isinstance(spin_moments, DataArray) - assert set(spin_moments.dims) == set(('k', 'band', 'axis')) - assert spin_moments.shape == (test_attrs['bands_shape'][0], 3, test_attrs['bands_shape'][-1]) - - def test_spin_texture(self, plot, test_attrs): - assert plot._for_backend["draw_bands"]["spin_texture"]["show"] is False - - if not test_attrs["spin_texture"]: - return +from sisl import Spin +from sisl.viz.data import BandsData +from sisl.viz.plots import bands_plot - plot.update_settings(spin="x", bands_range=[0, 1], Erange=None) - spin_texture = plot._for_backend["draw_bands"]["spin_texture"] - assert spin_texture["show"] is True - assert "colorscale" in spin_texture - assert "values" in spin_texture +@pytest.fixture(scope="module", params=["unpolarized", "polarized", "noncolinear", "spinorbit"]) +def spin(request): + return Spin(request.param) - spin_texture_arr = spin_texture["values"] +@pytest.fixture(scope="module") +def gap(): + return 2.5 - self._check_bands_array(spin_texture_arr, test_attrs["spin"], (*test_attrs["bands_shape"][:-1], 1)) - assert "axis" in spin_texture_arr.coords - assert str(spin_texture_arr.axis.values) == "x" +@pytest.fixture(scope="module") +def bands_data(spin, gap): + return BandsData.toy_example(spin=spin, gap=gap) - plot.update_settings(spin=None) +def 
test_bands_plot(bands_data): + bands_plot(bands_data) diff --git a/src/sisl/viz/plots/tests/test_bondlengthmap.py b/src/sisl/viz/plots/tests/test_bondlengthmap.py deleted file mode 100644 index ad57f0e522..0000000000 --- a/src/sisl/viz/plots/tests/test_bondlengthmap.py +++ /dev/null @@ -1,52 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from functools import partial - -import numpy as np -import pytest - -import sisl -from sisl.viz.plots.tests.test_geometry import TestGeometry as _TestGeometry - -pytestmark = [pytest.mark.viz, pytest.mark.plotly] - -# ------------------------------------------------------------ -# Build a generic tester for bond length plot -# ------------------------------------------------------------ - - -class TestBondLengthMap(_TestGeometry): - - _required_attrs = ["has_strain_ref"] - - @pytest.fixture(scope="class", params=[None, *sisl.viz.BondLengthMap.get_class_param("backend").options]) - def backend(self, request): - return request.param - - @pytest.fixture(scope="class", params=["sisl_geom", "sisl_geom_strain"]) - def init_func_and_attrs(self, request): - name = request.param - - if name.startswith("sisl_geom"): - geometry = sisl.geom.graphene(orthogonal=True, bond=1.35) - - if name.endswith("strain"): - kwargs = {"strain_ref": sisl.geom.graphene(orthogonal=True)} - attrs = {"has_strain_ref": True} - else: - kwargs = {} - attrs = {"has_strain_ref": False} - - init_func = partial(geometry.plot.bondlengthmap, **kwargs) - - return init_func, attrs - - def test_strain_ref(self, plot, test_attrs): - if test_attrs["has_strain_ref"]: - plot.update_settings(axes=[0, 1, 2], strain=True, show_bonds=True) - - strains = [bond["color"] for bond in plot._for_backend["bonds_props"]] - - plot.update_settings(strain=False) - assert not np.allclose([bond["color"] for bond in 
plot._for_backend["bonds_props"]], strains) diff --git a/src/sisl/viz/plots/tests/test_fatbands.py b/src/sisl/viz/plots/tests/test_fatbands.py deleted file mode 100644 index bd59c616ef..0000000000 --- a/src/sisl/viz/plots/tests/test_fatbands.py +++ /dev/null @@ -1,197 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -""" - -Tests specific functionality of a fatbands plot - -""" -from functools import partial - -import numpy as np -import pytest - -import sisl -from sisl.viz.plots.tests.test_bands import TestBandsPlot as _TestBandsPlot - -pytestmark = [pytest.mark.viz, pytest.mark.plotly] - -# ------------------------------------------------------------ -# Build a generic tester for the bands plot -# ------------------------------------------------------------ - - -class TestFatbandsPlot(_TestBandsPlot): - - _required_attrs = [ - *_TestBandsPlot._required_attrs, - "weights_shape", # Tuple. 
The shape that self.weights dataarray is expected to have - ] - - @pytest.fixture(scope="class", params=[None, *sisl.viz.FatbandsPlot.get_class_param("backend").options]) - def backend(self, request): - return request.param - - @pytest.fixture(scope="class", params=[ - "sisl_H_unpolarized", "sisl_H_polarized", "sisl_H_noncolinear", "sisl_H_spinorbit", - "sisl_H_unpolarized_jump", - "wfsx file", - ]) - def init_func_and_attrs(self, request, siesta_test_files): - name = request.param - - if name.startswith("sisl_H"): - gr = sisl.geom.graphene() - H = sisl.Hamiltonian(gr) - H.construct([(0.1, 1.44), (0, -2.7)]) - - spin_type = name.split("_")[2] - n_spin, H = { - "unpolarized": (1, H), - "polarized": (2, H.transform(spin=sisl.Spin.POLARIZED)), - "noncolinear": (1, H.transform(spin=sisl.Spin.NONCOLINEAR)), - "spinorbit": (1, H.transform(spin=sisl.Spin.SPINORBIT)) - }.get(spin_type) - - n_states = 2 - if H.spin.is_spinorbit or H.spin.is_noncolinear: - n_states *= 2 - - # Directly creating a BandStructure object - if name.endswith("jump"): - names = ["Gamma", "M", "M", "K"] - bz = sisl.BandStructure(H, [[0, 0, 0], [2/3, 1/3, 0], None, [2/3, 1/3, 0], [1/2, 0, 0]], 6, names) - nk = 7 - tickvals = [0., 1.70309799, 1.83083034, 2.68237934] - else: - names = ["Gamma", "M", "K"] - bz = sisl.BandStructure(H, [[0, 0, 0], [2/3, 1/3, 0], [1/2, 0, 0]], 6, names) - nk = 6 - tickvals = [0., 1.70309799, 2.55464699] - init_func = bz.plot.fatbands - - attrs = { - "bands_shape": (nk, n_spin, n_states) if H.spin.is_polarized else (nk, n_states), - "weights_shape": (n_spin, nk, n_states, 2) if H.spin.is_polarized else (nk, n_states, 2), - "ticklabels": names, - "tickvals": tickvals, - "gap": 0, - "spin_texture": not H.spin.is_diagonal, - "spin": H.spin - } - elif name == "wfsx file": - # From a siesta bands.WFSX file - # Since there is no hamiltonian for bi2se3_3ql.fdf, we create a dummy one - wfsx = sisl.get_sile(siesta_test_files("bi2se3_3ql.bands.WFSX")) - - geometry = 
sisl.get_sile(siesta_test_files("bi2se3_3ql.fdf")).read_geometry() - geometry = sisl.Geometry(geometry.xyz, atoms=wfsx.read_basis()) - - H = sisl.Hamiltonian(geometry, dim=4) - - init_func = partial(H.plot.fatbands, wfsx_file=wfsx, E0=-51.68, entry_points_order=["wfsx file"]) - attrs = { - "bands_shape": (16, 8), - "weights_shape": (16, 8, 195), - "ticklabels": None, - "tickvals": None, - "gap": 0.0575, - "spin_texture": False, - "spin": sisl.Spin("nc") - } - - return init_func, attrs - - def test_weights_dataarray_avail(self, plot, test_attrs): - """ - Check that the data array was created and contains the correct information. - """ - pytest.importorskip("xarray") - from xarray import DataArray - - # Check that there is a weights attribute - assert hasattr(plot, "weights") - - # Check that it is a dataarray containing the right information - weights = plot.weights - assert isinstance(weights, DataArray) - - if test_attrs["spin"].is_polarized: - expected_dims = ("spin", "k", "band", "orb") - else: - expected_dims = ("k", "band", "orb") - assert weights.dims == expected_dims - assert weights.shape == test_attrs["weights_shape"] - - def test_group_weights(self, plot): - pytest.importorskip("xarray") - from xarray import DataArray - - total_weights = plot._get_group_weights({}) - - assert isinstance(total_weights, DataArray) - assert set(total_weights.dims) == set(("spin", "band", "k")) - - def test_weights_values(self, plot, test_attrs): - # Check that all states are normalized. - assert np.allclose(plot.weights.dropna("k", "all").sum("orb"), 1, atol=0.05), "Weight values do not sum 1 for all states." - - # If we have all the bands of the system, assert that orbitals are also "normalized". 
- factor = 2 if not test_attrs["spin"].is_diagonal else 1 - if len(plot.weights.band) * factor == len(plot.weights.orb): - assert np.allclose(plot.weights.dropna("k", "all").sum("band"), factor) - - def test_groups(self, plot, test_attrs): - """ - Check that we can request groups - """ - pytest.importorskip("xarray") - from xarray import DataArray - - color = "green" - name = "Nice group" - - plot.update_settings( - groups=[{"atoms": [1], "color": color, "name": name}], - bands_range=None, Erange=None - ) - - assert "groups_weights" in plot._for_backend - assert len(plot._for_backend["groups_weights"]) == 1 - assert name in plot._for_backend["groups_weights"] - - group_weights = plot._for_backend["groups_weights"][name] - assert isinstance(group_weights, DataArray) - assert set(group_weights.dims) == set(("spin", "k", "band")) - group_weights_shape = test_attrs["weights_shape"][:-1] - if not test_attrs["spin"].is_polarized: - group_weights_shape = (1, *group_weights_shape) - assert group_weights.transpose("spin", "k", "band").shape == group_weights_shape - - assert "groups_metadata" in plot._for_backend - assert len(plot._for_backend["groups_metadata"]) == 1 - assert name in plot._for_backend["groups_metadata"] - assert plot._for_backend["groups_metadata"][name]["style"]["line"]["color"] == color - - @pytest.mark.parametrize("request_atoms", [None, {"index": 0}]) - def _test_split_groups(self, plot, constraint_atoms): - - # Number of groups that each splitting should give - expected_splits = [ - ('species', len(plot.geometry.atoms.atom)), - ('atoms', plot.geometry.na), - ('orbitals', plot.geometry.no) - ] - - plot.update_settings(groups=[]) - # Check that there are no groups - assert len(plot._for_backend["groups_weights"]) == 0 - assert len(plot._for_backend["groups_metadata"]) == 0 - - # Check that each splitting works as expected - for group_by, n_groups in expected_splits: - plot.split_groups(group_by, atoms=constraint_atoms) - if constraint_atoms is None: - 
err_message = f'Not correctly grouping by {group_by}' - assert len(plot._for_backend["groups_weights"]) == n_groups, err_message - assert len(plot._for_backend["groups_metadata"]) == n_groups, err_message diff --git a/src/sisl/viz/plots/tests/test_geometry.py b/src/sisl/viz/plots/tests/test_geometry.py deleted file mode 100644 index 2e5dacd86b..0000000000 --- a/src/sisl/viz/plots/tests/test_geometry.py +++ /dev/null @@ -1,228 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -""" - -Tests specific functionality of the bands plot. - -Different inputs are tested (siesta .bands and sisl Hamiltonian). - -""" -import numpy as np -import pytest - -import sisl -from sisl.messages import SislWarning -from sisl.viz.plots.geometry import GeometryPlot -from sisl.viz.plots.tests.conftest import _TestPlot - -pytestmark = [pytest.mark.viz, pytest.mark.plotly] - - -def test_cross_product(): - cell = np.eye(3) * 2 - z_dir = np.array([0, 0, 1]) - - products = [ - ["x", "y", z_dir], ["-x", "y", -z_dir], ["-x", "-y", z_dir], - ["b", "c", cell[0]], ["c", "b", -cell[0]], - np.eye(3) - ] - - for v1, v2, result in products: - assert np.all(GeometryPlot._cross_product(v1, v2, cell) == result) - - -class TestGeometry(_TestPlot): - - @pytest.fixture(scope="class", params=["sisl_geom", "ghost_atoms"]) - def init_func_and_attrs(self, request): - name = request.param - - if name == "sisl_geom": - init_func = sisl.geom.graphene(orthogonal=True).plot - elif name == "ghost_atoms": - init_func = sisl.Geometry([[0, 0, 1], [1, 0, 0]], atoms=[sisl.Atom(6), sisl.Atom(-6)]).plot - - attrs = {} - - return init_func, attrs - - @pytest.fixture(scope="class", params=[None, *sisl.viz.GeometryPlot.get_class_param("backend").options]) - def backend(self, request): - return request.param - - @pytest.fixture(params=[1, 2, 3]) - def ndim(self, request, 
backend): - if backend == "matplotlib" and request.param == 3: - pytest.skip("Matplotlib 3D representations are not available yet") - return request.param - - @pytest.fixture(params=["cartesian", "lattice", "explicit"]) - def axes(self, request, ndim): - if request.param == "cartesian": - return {1: "x", 2: "x-y", 3: "xyz"}[ndim] - elif request.param == "lattice": - # We don't test the 3D case because it doesn't work - if ndim == 3: - pytest.skip("3D view doesn't support fractional coordinates") - return {1: "a", 2: "a-b"}[ndim] - elif request.param == "explicit": - if ndim == 3: - pytest.skip("3D view doesn't support explicit directions") - return { - 1: [[1, 1, 0]], - 2: [[1, 1, 0], [0, 1, 1]], - }[ndim] - - @pytest.fixture(params=["Unit cell", "supercell"]) - def nsc(self, request): - return {"Unit cell": [1, 1, 1], "supercell": [2, 1, 1]}[request.param] - - def _check_all_atomic_props_shape(self, backend_info, na, nsc_val): - na_sc = na*nsc_val[0]*nsc_val[1]*nsc_val[2] - - for key, value in backend_info["atoms_props"].items(): - if not isinstance(value, np.ndarray): - continue - - assert value.shape[0] == na_sc, f"'{key}' doesn't have the appropiate shape" - - if key == "xy": - assert value.shape[1] == 2 - elif key == "xyz": - assert value.shape[1] == 3 - - @pytest.mark.parametrize("atoms, na", [([], 0), (0, 1), (None, "na")]) - def test_atoms(self, plot, axes, nsc, atoms, na): - plot.update_settings(axes=axes, nsc=nsc, show_bonds=False, show_cell=False, atoms=atoms) - - if na == "na": - na = plot.geometry.na - - backend_info = plot._for_backend - self._check_all_atomic_props_shape(backend_info, na, nsc) - - @pytest.mark.parametrize("show_bonds", [False, True]) - def test_toggle_bonds(self, plot, axes, ndim, nsc, show_bonds, test_attrs): - plot.update_settings(axes=axes, nsc=nsc, show_bonds=show_bonds, bind_bonds_to_ats=True, show_cell=False, atoms=[]) - assert len(plot._for_backend["bonds_props"]) == 0 - - plot.update_settings(bind_bonds_to_ats=False) - - 
backend_info = plot._for_backend - bonds_props = backend_info["bonds_props"] - if not test_attrs.get("no_bonds", False): - n_bonds = len(bonds_props) - if show_bonds and ndim > 1: - assert n_bonds > 0 - if ndim == 2: - assert bonds_props[0]["xys"].shape == (2, 2) - elif ndim == 3: - assert bonds_props[0]["xyz1"].shape == (3,) - assert bonds_props[0]["xyz2"].shape == (3,) - else: - assert n_bonds == 0 - - @pytest.mark.parametrize("show_cell", [False, "box", "axes"]) - def test_cell(self, plot, axes, show_cell): - plot.update_settings(axes=axes, show_cell=show_cell) - - assert plot._for_backend["show_cell"] == show_cell - - @pytest.mark.parametrize("show_cell", [False, "box", "axes"]) - def test_cell_styles(self, plot, axes, show_cell): - cell_style = {"color": "red", "width": 2, "opacity": 0.6} - plot.update_settings(axes=axes, show_cell=show_cell, cell_style=cell_style) - - assert plot._for_backend["cell_style"] == cell_style - - def test_atoms_sorted_2d(self, plot): - plot.update_settings(atoms=None, axes="yz", nsc=[1, 1, 1]) - - # Check that atoms are sorted along x - assert np.allclose(plot.geometry.xyz[:, 1:][plot.geometry.xyz[:, 0].argsort()], plot._for_backend["atoms_props"]["xy"]) - - def test_atoms_style(self, plot, axes, ndim, nsc): - plot.update_settings(atoms=None, axes=axes, nsc=nsc) - - rand_values = np.random.random(plot.geometry.na) - atoms_style = {"color": rand_values, "size": rand_values, "opacity": rand_values} - - new_atoms_style = {"atoms": 0, "color": 2, "size": 2, "opacity": 0.3} - - if ndim == 2: - depth_vector = plot._cross_product(*plot.get_setting("axes"), plot.geometry.cell) - sorted_atoms = np.concatenate(plot.geometry.sort(vector=depth_vector, ret_atoms=True)[1]) - else: - sorted_atoms = plot.geometry._sanitize_atoms(None) - - # Try both passing a dictionary and a list with one dictionary - for i, atoms_style_val in enumerate((atoms_style, [atoms_style], [atoms_style, new_atoms_style])): - 
plot.update_settings(atoms_style=atoms_style_val) - - backend_info = plot._for_backend - self._check_all_atomic_props_shape(backend_info, plot.geometry.na, nsc) - - if i != 2: - for key in atoms_style: - if not (ndim == 3 and key == "color"): - assert np.allclose( - backend_info["atoms_props"][key].astype(float), - np.tile(atoms_style[key][sorted_atoms], nsc[0]*nsc[1]*nsc[2]) - ) - else: - for key in atoms_style: - if not (ndim == 3 and key == "color"): - assert np.isclose( - backend_info["atoms_props"][key].astype(float), - np.tile(atoms_style[key][sorted_atoms], nsc[0]*nsc[1]*nsc[2]) - ).sum() == (plot.geometry.na - 1) * nsc[0]*nsc[1]*nsc[2] - - def test_bonds_style(self, plot, axes, ndim, nsc): - if ndim == 1: - return - - bonds_style = {"width": 2, "opacity": 0.6} - - plot.update_settings(atoms=None, axes=axes, nsc=nsc, bonds_style=bonds_style) - - bonds_props = plot._for_backend["bonds_props"] - - assert bonds_props[0]["width"] == 2 - assert bonds_props[0]["opacity"] == 0.6 - - plot.update_settings(bonds_style={}) - - def test_arrows(self, plot, axes, ndim, nsc): - # Check that arrows accepts both a dictionary and a list and the data is properly transferred - for arrows in ({"data": [0, 0, 2]}, [{"data": [0, 0, 2]}]): - plot.update_settings(axes=axes, arrows=arrows, atoms=None, nsc=nsc, atoms_style=[]) - arrow_data = plot._for_backend["arrows"][0]["data"] - assert arrow_data.shape == (plot.geometry.na * nsc[0]*nsc[1]*nsc[2], ndim) - assert not np.isnan(arrow_data).any() - - # Now check that atom selection works - plot.update_settings(arrows=[{"atoms": 0, "data": [0, 0, 2]}]) - arrow_data = plot._for_backend["arrows"][0]["data"] - assert arrow_data.shape == (plot.geometry.na * nsc[0]*nsc[1]*nsc[2], ndim) - assert np.isnan(arrow_data).any() - assert not np.isnan(arrow_data[0]).any() - - # Check that if atoms is provided, data is only stored for those atoms that are going to be - # displayed - plot.update_settings(atoms=0, arrows=[{"atoms": 0, "data": [0, 0, 
2]}]) - arrow_data = plot._for_backend["arrows"][0]["data"] - assert arrow_data.shape == (nsc[0]*nsc[1]*nsc[2], ndim) - assert not np.isnan(arrow_data).any() - - # Check that if no data is provided for the atoms that are displayed, arrow data is not stored - # We also check that a warning is being raised because we are providing arrow data for atoms that - # are not being displayed. - with pytest.warns(SislWarning): - plot.update_settings(atoms=1, arrows=[{"atoms": 0, "data": [0, 0, 2]}]) - assert len(plot._for_backend["arrows"]) == 0 - - # Finally, check that multiple arrows are passed to the backend - plot.update_settings(atoms=None, arrows=[{"data": [0, 0, 2]}, {"data": [1, 0, 0]}]) - assert len(plot._for_backend["arrows"]) == 2 diff --git a/src/sisl/viz/plots/tests/test_grid.py b/src/sisl/viz/plots/tests/test_grid.py deleted file mode 100644 index b1817a3002..0000000000 --- a/src/sisl/viz/plots/tests/test_grid.py +++ /dev/null @@ -1,266 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -""" - -Tests specific functionality of the grid plot. - -Different inputs are tested (siesta .RHO and sisl Hamiltonian). 
- -""" -import os.path as osp -from typing import ChainMap - -import numpy as np -import pytest - -import sisl -from sisl.viz import Animation, GridPlot -from sisl.viz.plots.tests.conftest import _TestPlot - -pytestmark = [pytest.mark.viz, pytest.mark.plotly] - -try: - import skimage - skip_skimage = pytest.mark.skipif(False, reason="scikit-image (skimage) not available") -except ImportError: - skip_skimage = pytest.mark.skipif(True, reason="scikit-image (skimage) not available") - -try: - import plotly - skip_plotly = pytest.mark.skipif(False, reason="plotly not available") -except ImportError: - skip_plotly = pytest.mark.skipif(True, reason="plotly not available") - - -class TestGridPlot(_TestPlot): - - _required_attrs = [ - "grid_shape" # Tuple indicating the grid shape - ] - - @pytest.fixture(scope="class", params=["siesta_RHO", "VASP CHGCAR", "complex_grid"]) - def init_func_and_attrs(self, request, siesta_test_files, vasp_test_files): - name = request.param - - if name == "siesta_RHO": - init_func = sisl.get_sile(siesta_test_files("SrTiO3.RHO")).plot - attrs = {"grid_shape": (48, 48, 48)} - if name == "VASP CHGCAR": - init_func = sisl.get_sile(vasp_test_files(osp.join("graphene", "CHGCAR"))).plot.grid - attrs = {"grid_shape": (24, 24, 100)} - elif name == "complex_grid": - complex_grid_shape = (8, 10, 10) - np.random.seed(1) - values = np.random.random(complex_grid_shape).astype(np.complex128) + np.random.random(complex_grid_shape) * 1j - complex_grid = sisl.Grid(complex_grid_shape, lattice=1) - complex_grid.grid = values - - init_func = complex_grid.plot - attrs = {"grid_shape": complex_grid_shape} - - return init_func, attrs - - @pytest.fixture(scope="class", params=[None, *sisl.viz.GridPlot.get_class_param("backend").options]) - def backend(self, request): - return request.param - - @pytest.fixture(scope="function", params=["imag", "mod", "rad_phase", "deg_phase", "real"]) - def grid_representation(self, request, plot): - """ - Fixture that returns all 
possible grid representations of the grid that we are testing. - - To be used only in methods of test classes. - - Returns - ----------- - sisl.Grid: - a new grid that contains the specific representation of the plot's grid - str: - the name of the representation that we are returning - """ - - representation = request.param - - # Copy the plot's grid - new_grid = plot.grid.copy() - - # Substitute the values by the appropiate representations - new_grid.grid = GridPlot._get_representation(new_grid, representation) - - return (new_grid, representation) - - @pytest.fixture(scope="class", params=[1, 2, 3]) - def ndim(self, request, backend): - if backend == "matplotlib" and request.param == 3: - pytest.skip("Matplotlib 3D representations are not available yet") - if request.param > 1: - pytest.importorskip("skimage") - return request.param - - @pytest.fixture(scope="class", params=["cartesian", "lattice"]) - def axes(self, request, ndim): - if request.param == "cartesian": - return {1: "x", 2: "xy", 3: "xyz"}[ndim] - elif request.param == "lattice": - # We don't test the 3D case because it doesn't work - if ndim == 3: - pytest.skip("3D view doesn't support fractional coordinates") - return {1: "a", 2: "ab"}[ndim] - - @pytest.fixture(scope="class") - def lattice_axes(self, ndim): - return {1: [0], 2: [0, 1], 3: [0, 1, 2]}[ndim] - - @pytest.fixture(params=["Unit cell", "supercell"]) - def nsc(self, request): - return {"Unit cell": [1, 1, 1], "supercell": [2, 2, 2]}[request.param] - - def _get_plotted_values(self, plot): - - ndim = len(plot.get_setting("axes")) - if ndim < 3: - values = plot._for_backend["values"] - if ndim == 2: - values = values.T - return values - elif ndim == 3: - return plot._for_backend["isosurfaces"][0]["vertices"] - - def test_values(self, plot, ndim, axes, nsc): - plot.update_settings(axes=axes, nsc=nsc) - - if ndim < 3: - assert "values" in plot._for_backend - assert plot._for_backend["values"].ndim == ndim - - elif ndim == 3: - 
plot.update_settings(isos=[]) - assert "isosurfaces" in plot._for_backend - - assert len(plot._for_backend["isosurfaces"]) == 2 - - for iso in plot._for_backend["isosurfaces"]: - assert set(("vertices", "faces", "color", "opacity", "name")) == set(iso) - assert iso["vertices"].shape[1] == 3 - assert iso["faces"].shape[1] == 3 - - def test_ax_ranges(self, plot, axes, ndim, nsc): - if ndim == 3: - return - - plot.update_settings(axes=axes, nsc=nsc) - values = plot._for_backend["values"] - - if ndim == 1: - assert values.shape == plot._for_backend["ax_range"].shape - if ndim == 2: - assert (values.shape[1], ) == plot._for_backend["x"].shape - assert (values.shape[0], ) == plot._for_backend["y"].shape - - plot.update_settings(nsc=[1, 1, 1]) - - def test_representation(self, plot, lattice_axes, grid_representation): - - kwargs = {"isos": [], "reduce_method": "average"} - - ndim = len(lattice_axes) - if ndim == 3: - kwargs["isos"] = [{"frac": 0.5}] - - new_grid, representation = grid_representation - - if new_grid.grid.min() == new_grid.grid.max() and ndim == 3: - return - - plot.update_settings(axes=lattice_axes, represent=representation, nsc=[1, 1, 1], **kwargs) - new_plot = new_grid.plot(**ChainMap(plot.settings, dict(axes=lattice_axes, represent="real", grid_file=None))) - - assert np.allclose( - self._get_plotted_values(plot), self._get_plotted_values(plot=new_plot) - ), f"'{representation}' representation of the {ndim}D plot is not correct" - - def test_grid(self, plot, test_attrs): - grid = plot.grid - - assert isinstance(grid, sisl.Grid) - assert grid.shape == test_attrs["grid_shape"] - - @skip_skimage - @skip_plotly - def test_scan(self, plot, backend): - import plotly.graph_objs as go - plot.update_settings(axes="xy") - # AS_IS SCAN - # Provide number of steps - if backend == "plotly": - scanned = plot.scan("z", num=2, mode="as_is") - assert isinstance(scanned, Animation) - assert len(scanned.frames) == 2 - - # Provide step in Ang - step = plot.grid.cell[2, 
2]/2 - scanned = plot.scan(along="z", step=step, mode="as_is") - assert len(scanned.frames) == 2 - - # Provide breakpoints - breakpoints = [plot.grid.cell[2, 2]*frac for frac in [1/3, 2/3, 3/3]] - scanned = plot.scan(along="z", breakpoints=breakpoints, mode="as_is") - assert len(scanned.frames) == 2 - - # Check that it doesn't accept step and breakpoints at the same time - with pytest.raises(ValueError): - plot.scan(along="z", step=4.5, breakpoints=breakpoints, mode="as_is") - - # 3D SCAN - breakpoints = [plot.grid.cell[0, 0]*frac for frac in [1/3, 2/3, 3/3]] - scanned = plot.scan(along="z", mode="moving_slice", breakpoints=breakpoints) - - assert isinstance(scanned, go.Figure) - assert len(scanned.frames) == 3 # One cross section for each breakpoint - - @skip_skimage - def test_supercell(self, plot): - plot.update_settings(axes=[0, 1], interp=[1, 1, 1], nsc=[1, 1, 1]) - - # Check that the shapes for the unit cell are right - uc_shape = plot._for_backend["values"].shape - assert uc_shape == (plot.grid.shape[1], plot.grid.shape[0]) - - # Check that the supercell is displayed - plot.update_settings(nsc=[2, 1, 1]) - sc_shape = plot._for_backend["values"].shape - assert sc_shape[1] == 2*uc_shape[1] - assert sc_shape[0] == uc_shape[0] - - plot.update_settings(nsc=[1, 1, 1]) - - @pytest.mark.parametrize("reduce_method", ["sum", "average"]) - def test_reduce_method(self, plot, reduce_method, lattice_axes, grid_representation): - new_grid, representation = grid_representation - - # If this is a 3D plot, no dimension is reduced, therefore it makes no sense - if len(lattice_axes) == 3: - return - - numpy_func = getattr(np, reduce_method) - - plot.update_settings(axes=lattice_axes, reduce_method=reduce_method, represent=representation, transforms=[]) - - assert np.allclose( - self._get_plotted_values(plot), numpy_func(new_grid.grid, axis=tuple(ax for ax in [0, 1, 2] if ax not in lattice_axes)) - ) - - def test_transforms(self, plot, lattice_axes, grid_representation): - - if 
len(lattice_axes) == 3: - return - - new_grid, representation = grid_representation - - plot.update_settings(axes=lattice_axes, reduce_method="average", transforms=["cos"], represent=representation, nsc=[1, 1, 1]) - - # Check that transforms = ["cos"] applies np.cos - assert np.allclose( - self._get_plotted_values(plot), np.cos(new_grid.grid).mean(axis=tuple(ax for ax in [0, 1, 2] if ax not in lattice_axes)) - ) diff --git a/src/sisl/viz/plots/tests/test_pdos.py b/src/sisl/viz/plots/tests/test_pdos.py deleted file mode 100644 index 65ccaf24eb..0000000000 --- a/src/sisl/viz/plots/tests/test_pdos.py +++ /dev/null @@ -1,211 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -""" - -Tests specific functionality of the PDOS plot. - -Different inputs are tested (siesta .PDOS and sisl Hamiltonian). - -""" -from functools import partial - -import pytest - -import sisl -from sisl.viz.plots.tests.conftest import _TestPlot - -pytestmark = [pytest.mark.viz, pytest.mark.plotly] - - -@pytest.fixture(params=[True, False], ids=["inplace_split", "method_splitting"]) -def inplace_split(request): - return request.param - - -class TestPdosPlot(_TestPlot): - - _required_attrs = [ - "na", # int, number of atoms in the geometry - "no", # int, number of orbitals in the geometry - "n_spin", # int, number of spin components for the PDOS - "species", # array-like of str. The names of the species. 
- ] - - @pytest.fixture(scope="class", params=[ - # From a siesta PDOS file - "siesta_PDOS_file_unpolarized", "siesta_PDOS_file_polarized", "siesta_PDOS_file_noncollinear", - # From a sisl hamiltonian - "sisl_H_unpolarized", "sisl_H_polarized", "sisl_H_noncolinear", "sisl_H_spinorbit", - # From a WFSX file and the overlap matrix - "wfsx_file" - - ]) - def init_func_and_attrs(self, request, siesta_test_files): - name = request.param - - if name.startswith("siesta_PDOS_file"): - - spin_type = name.split("_")[-1] - - n_spin, filename = { - "unpolarized": (1, "SrTiO3.PDOS"), - "polarized": (2, "SrTiO3_polarized.PDOS"), - "noncollinear": (4, "SrTiO3_noncollinear.PDOS") - }[spin_type] - - init_func = sisl.get_sile(siesta_test_files(filename)).plot - attrs = { - "na": 5, - "no": 72, - "n_spin": n_spin, - "species": ('Sr', 'Ti', 'O') - } - elif name.startswith("sisl_H"): - gr = sisl.geom.graphene() - H = sisl.Hamiltonian(gr) - H.construct([(0.1, 1.44), (0, -2.7)]) - - spin_type = name.split("_")[-1] - - n_spin, H = { - "unpolarized": (1, H), - "polarized": (2, H.transform(spin=sisl.Spin.POLARIZED)), - "noncolinear": (4, H.transform(spin=sisl.Spin.NONCOLINEAR)), - "spinorbit": (4, H.transform(spin=sisl.Spin.SPINORBIT)) - }[spin_type] - - init_func = partial(H.plot.pdos, Erange=[-5, 5]) - attrs = { - "na": 2, - "no": 2, - "n_spin": n_spin, - "species": ('C',) - } - elif name == "wfsx_file": - # From a siesta .WFSX file - # Since there is no hamiltonian for bi2se3_3ql.fdf, we create a dummy one - wfsx = sisl.get_sile(siesta_test_files("bi2se3_3ql.bands.WFSX")) - - geometry = sisl.get_sile(siesta_test_files("bi2se3_3ql.fdf")).read_geometry() - geometry = sisl.Geometry(geometry.xyz, atoms=wfsx.read_basis()) - - H = sisl.Hamiltonian(geometry, dim=4) - - init_func = partial( - H.plot.pdos, wfsx_file=wfsx, - entry_points_order=["wfsx file"] - ) - - attrs = { - "na": 15, - "no": 195, - "n_spin": 4, - "species": ('Bi', 'Se') - } - - return init_func, attrs - - 
@pytest.fixture(scope="class", params=[None, *sisl.viz.PdosPlot.get_class_param("backend").options]) - def backend(self, request): - return request.param - - def test_dataarray(self, plot, test_attrs): - pytest.importorskip("xarray") - from xarray import DataArray - - PDOS = plot.PDOS - geom = plot.geometry - - assert isinstance(PDOS, DataArray) - assert isinstance(geom, sisl.Geometry) - - # Check if we have the correct number of orbitals - assert len(PDOS.orb) == test_attrs["no"] == geom.no - - def test_request_PDOS(self, plot): - total_DOS = plot._get_request_PDOS({}) - - assert total_DOS.ndim == 1 - assert total_DOS.shape == (plot.PDOS.E.shape) - - def test_splitDOS(self, plot, test_attrs, inplace_split): - if inplace_split: - def split_DOS(on, **kwargs): - return plot.update_settings(requests=[{"split_on": on, **kwargs}]) - else: - split_DOS = plot.split_DOS - - unique_orbs = plot.get_param('requests')['orbitals'].options - - expected_splits = { - "species": (len(test_attrs["species"]), test_attrs["species"][0]), - "atoms": (test_attrs["na"], 1), - "orbitals": (len(unique_orbs), unique_orbs[0]), - "spin": (test_attrs["n_spin"], None), - } - - # Test all splittings - for on, (n, toggle_val) in expected_splits.items(): - err_message = f'Error splitting DOS based on {on}' - assert len(split_DOS(on=on)._for_backend["PDOS_values"]) == n, err_message - if toggle_val is not None and not inplace_split: - assert len(split_DOS(on=on, only=[toggle_val])._for_backend["PDOS_values"]) == 1, err_message - assert len(split_DOS(on=on, exclude=[toggle_val])._for_backend["PDOS_values"]) == n - 1, err_message - - def test_composite_splitting(self, plot, inplace_split): - - if inplace_split: - def split_DOS(on, **kwargs): - return plot.update_settings(requests=[{"split_on": on, **kwargs}]) - else: - split_DOS = plot.split_DOS - - split_DOS(on="species+orbitals", name="This is $species") - - first_trace_name = list(plot._for_backend["PDOS_values"].keys())[0] - assert "This is " in 
first_trace_name, "Composite splitting not working" - assert "species" not in first_trace_name, "Name templating not working in composite splitting" - assert "orbitals=" in first_trace_name, "Name templating not working in composite splitting" - - @pytest.mark.parametrize("request_atoms", [0, {"index": 0}]) - def test_request_splitting(self, plot, inplace_split, request_atoms): - - # Here we are just checking that, when splitting a request - # the plot understands that it has constrains - plot.update_settings(requests=[{"atoms": request_atoms}]) - prev_len = len(plot._for_backend["PDOS_values"]) - - # Even if there are more atoms, the plot should understand - # that it is constrained to the values of the current request - - if inplace_split: - plot.update_settings(requests=[{"atoms": request_atoms, "split_on": "atoms"}]) - else: - plot.split_requests(0, on="atoms") - - assert len(plot._for_backend["PDOS_values"]) == prev_len - - def test_request_management(self, plot, test_attrs): - - plot.update_settings(requests=[]) - assert len(plot._for_backend["PDOS_values"]) == 0 - - sel_species = test_attrs["species"][0] - plot.add_request({"species": [sel_species]}) - assert len(plot._for_backend["PDOS_values"]) == 1 - - # Try to split this request in multiple ones - plot.split_requests(0, on="orbitals") - # Get the number of orbitals with unique names - species_no = len(set(orb.name() for orb in plot.geometry.atoms[sel_species].orbitals)) - assert len(plot._for_backend["PDOS_values"]) == species_no - - # Then try to merge - # if species_no >= 2: - # plot.merge_requests(species_no - 1, species_no - 2) - # assert len(plot.data) == species_no - 1 - - # And try to remove one request - prev = len(plot._for_backend["PDOS_values"]) - assert len(plot.remove_requests(0)._for_backend["PDOS_values"]) == prev - 1 diff --git a/src/sisl/viz/plots/tests/test_plots.py b/src/sisl/viz/plots/tests/test_plots.py deleted file mode 100644 index 4a6b77ce92..0000000000 --- 
a/src/sisl/viz/plots/tests/test_plots.py +++ /dev/null @@ -1,47 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -""" - -These tests check that all plot subclasses fulfill at least the most basic stuff - -More tests should be run on each plot, but these are the most basic ones to -ensure that at least they do not break basic plot functionality. -""" -import pytest - -from sisl.viz.plots import * -from sisl.viz.plotutils import get_plot_classes -from sisl.viz.tests.test_plot import _TestPlotClass - -pytestmark = [pytest.mark.viz, pytest.mark.plotly] - -# Test all plot subclasses with the subclass tester - -# The following function basically tells pytest to run TestPlotSubClass -# once for each plot class. It takes care of setting the _cls attribute -# to the corresponding plot class. - - -@pytest.fixture(autouse=True, scope="class", params=get_plot_classes()) -def plot_class(request): - request.cls._cls = request.param - - -class TestPlotSubClass(_TestPlotClass): - - def test_compulsory_methods(self): - - assert hasattr(self._cls, "_set_data") - assert callable(self._cls._set_data) - - assert hasattr(self._cls, "_plot_type") - assert isinstance(self._cls._plot_type, str) - - def test_param_groups(self): - - plot = self._init_plot_without_warnings() - - for group in plot.param_groups: - for key in ("key", "name", "icon", "description"): - assert key in group, f'{self._cls.__name__} is missing {key} in parameters group {group}' diff --git a/src/sisl/viz/plots/tests/test_wavefunction.py b/src/sisl/viz/plots/tests/test_wavefunction.py deleted file mode 100644 index 819ad50131..0000000000 --- a/src/sisl/viz/plots/tests/test_wavefunction.py +++ /dev/null @@ -1,27 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. 
If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from functools import partial - -import pytest - -import sisl -from sisl.viz.plots.tests.test_grid import TestGridPlot as _TestGridPlot - - -class TestWavefunctionPlot(_TestGridPlot): - - @pytest.fixture(scope="class", params=["wfsx file"]) - def init_func_and_attrs(self, request, siesta_test_files): - name = request.param - - if name == "wfsx file": - pytest.skip("Basis for bi2se3_3ql.fdf is not available in the test files.") - fdf = sisl.get_sile(siesta_test_files("bi2se3_3ql.fdf")) - wfsx = siesta_test_files("bi2se3_3ql.bands.WFSX") - init_func = partial( - fdf.plot.wavefunction, wfsx_file=wfsx, k=(0.003, 0.003, 0), - entry_points_order=["wfsx file"]) - - attrs = {"grid_shape": (48, 48, 48)} - return init_func, attrs diff --git a/src/sisl/viz/plotters/__init__.py b/src/sisl/viz/plotters/__init__.py new file mode 100644 index 0000000000..9a5ac8cc0b --- /dev/null +++ b/src/sisl/viz/plotters/__init__.py @@ -0,0 +1,3 @@ +"""Functions that generate plot actions to be passed to figures.""" + +from . 
import plot_actions \ No newline at end of file diff --git a/src/sisl/viz/plotters/cell.py b/src/sisl/viz/plotters/cell.py new file mode 100644 index 0000000000..0ed597f7e0 --- /dev/null +++ b/src/sisl/viz/plotters/cell.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import Literal, Sequence + +from ..processors.cell import cell_to_lines, gen_cell_dataset +from ..processors.coords import project_to_axes +from ..types import Axes, CellLike +from .xarray import draw_xarray_xy + + +def get_ndim(axes: Axes) -> int: + return len(axes) + +def get_z(ndim: int) -> Literal["z", False]: + if ndim == 3: + z = "z" + else: + z = False + return z + +def cell_plot_actions(cell: CellLike = None, show_cell: Literal[False, "box", "axes"] = "box", axes=["x", "y", "z"], + name: str = "Unit cell", cell_style={}, dataaxis_1d=None): + if show_cell == False: + cell_plottings = [] + else: + cell_ds = gen_cell_dataset(cell) + cell_lines = cell_to_lines(cell_ds, show_cell, cell_style) + projected_cell_lines = project_to_axes(cell_lines, axes=axes, dataaxis_1d=dataaxis_1d) + + ndim = get_ndim(axes) + z = get_z(ndim) + cell_plottings = draw_xarray_xy(data=projected_cell_lines, x="x", y="y", z=z, set_axequal=ndim > 1, name=name) + + return cell_plottings \ No newline at end of file diff --git a/src/sisl/viz/plotters/grid.py b/src/sisl/viz/plotters/grid.py new file mode 100644 index 0000000000..40fd413175 --- /dev/null +++ b/src/sisl/viz/plotters/grid.py @@ -0,0 +1,56 @@ +import sisl.viz.plotters.plot_actions as plot_actions +from sisl.viz.processors.grid import get_isos + + +def draw_grid(data, isos=[], colorscale=None, crange=None, cmid=None, smooth=False): + + to_plot = [] + + ndim = data.ndim + + if ndim == 1: + to_plot.append( + plot_actions.draw_line(x=data.x, y=data.values) + ) + elif ndim == 2: + transposed = data.transpose("y", "x") + + cmin, cmax = crange if crange is not None else (None, None) + + to_plot.append( + plot_actions.init_coloraxis(name="grid_color", 
cmin=cmin, cmax=cmax, cmid=cmid, colorscale=colorscale) + ) + + + to_plot.append( + plot_actions.draw_heatmap(values=transposed.values, x=data.x, y=data.y, name="HEAT", zsmooth="best" if smooth else False, coloraxis="grid_color") + ) + + dx = data.x[1] - data.x[0] + dy = data.y[1] - data.y[0] + + iso_lines = get_isos(transposed, isos) + for iso_line in iso_lines: + iso_line['line'] = { + "color": iso_line.pop("color", None), + "opacity": iso_line.pop("opacity", None), + "width": iso_line.pop("width", None), + **iso_line.get("line", {}) + } + to_plot.append( + plot_actions.draw_line(**iso_line) + ) + elif ndim == 3: + isosurfaces = get_isos(data, isos) + + for isosurface in isosurfaces: + to_plot.append( + plot_actions.draw_mesh_3D(**isosurface) + ) + + if ndim > 1: + to_plot.append( + plot_actions.set_axes_equal() + ) + + return to_plot \ No newline at end of file diff --git a/src/sisl/viz/plotters/plot_actions.py b/src/sisl/viz/plotters/plot_actions.py new file mode 100644 index 0000000000..69f224f60a --- /dev/null +++ b/src/sisl/viz/plotters/plot_actions.py @@ -0,0 +1,45 @@ +"""Contains all the individual actions that can be performed on a figure.""" + +import functools +import inspect +import sys +from typing import Literal, Optional, Type + +from ..figure import Figure + + +def _register_actions(figure_cls: Type[Figure]): + + # Take all actions possible from the Figure class + module = sys.modules[__name__] + + actions = inspect.getmembers(figure_cls, predicate=lambda x: inspect.isfunction(x) and not x.__name__.startswith("_")) + + for name, function in actions: + + sig = inspect.signature(function) + + @functools.wraps(function) + def a(*args, __method_name__=function.__name__, **kwargs): + return dict(method=__method_name__, args=args, kwargs=kwargs) + + a.__signature__ = sig.replace(parameters=list(sig.parameters.values())[1:]) + a.__module__ = module + + setattr(module, name, a) + +_register_actions(Figure) + +def combined(*plotters, + composite_method: 
Optional[Literal["multiple", "subplots", "multiple_x", "multiple_y", "animation"]] = None, + provided_list: bool = False, + **kwargs +): + if provided_list: + plotters = plotters[0] + + return { + "composite_method": composite_method, + "plot_actions": plotters, + "init_kwargs": kwargs + } \ No newline at end of file diff --git a/src/sisl/viz/plotters/tests/test_xarray.py b/src/sisl/viz/plotters/tests/test_xarray.py new file mode 100644 index 0000000000..42efe34816 --- /dev/null +++ b/src/sisl/viz/plotters/tests/test_xarray.py @@ -0,0 +1,22 @@ +import xarray as xr + +from sisl.viz.plotters.xarray import draw_xarray_xy + + +def test_empty_dataset(): + + ds = xr.Dataset({"x": ("dim", []), "y": ("dim", [])}) + + drawings = draw_xarray_xy(ds, x="x", y="y") + + assert isinstance(drawings, list) + assert len(drawings) == 0 + +def test_empty_dataarray(): + + arr = xr.DataArray([], name="values", dims=['x']) + + drawings = draw_xarray_xy(arr, x="x") + + assert isinstance(drawings, list) + assert len(drawings) == 0 \ No newline at end of file diff --git a/src/sisl/viz/plotters/xarray.py b/src/sisl/viz/plotters/xarray.py new file mode 100644 index 0000000000..c5922eba4f --- /dev/null +++ b/src/sisl/viz/plotters/xarray.py @@ -0,0 +1,299 @@ +import itertools +import typing + +import numpy as np +from xarray import DataArray + +import sisl.viz.plotters.plot_actions as plot_actions +from sisl.messages import info + +#from sisl.viz.nodes.processors.grid import get_isos + +def _process_xarray_data(data, x=None, y=None, z=False, style={}): + axes = {"x": x, "y": y} + if z is not False: + axes["z"] = z + + ndim = len(axes) + + # Normalize data to a Dataset + if isinstance(data, DataArray): + if np.all([ax is None for ax in axes.values()]): + raise ValueError("You have to provide either x or y (or z if it is not False) (one needs to be the fixed variable).") + axes = {k: v or data.name for k, v in axes.items()} + data = data.to_dataset(name=data.name) + else: + if np.any([ax is None 
for ax in axes.values()]): + raise ValueError("Since you provided a Dataset, you have to provide both x and y (and z if it is not False).") + + data_axis = None + fixed_axes = {} + # Check, for each axis, if it is uni dimensional (in which case we add it to the fixed axes dictionary) + # or it contains more than one dimension, in which case we set it as the data axis + for k in axes: + if axes[k] in data.coords or (axes[k] in data and data[axes[k]].ndim == 1): + if len(fixed_axes) < ndim - 1: + fixed_axes[k] = axes[k] + else: + data_axis = k + else: + data_axis = k + + # Transpose the data so that the fixed axes are first. + last_dims = [] + for ax_key, fixed_axis in fixed_axes.items(): + if fixed_axis not in data.dims: + # This means that the fixed axis is a variable, which should contain only one dimension + last_dim = data[fixed_axis].dims[-1] + else: + last_dim = fixed_axis + last_dims.append(last_dim) + last_dims = np.unique(last_dims) + data = data.transpose(..., *last_dims) + + data_var = axes[data_axis] + + style_dims = set() + for key, value in style.items(): + if value in data: + style_dims = style_dims.union(set(data[value].dims)) + + extra_style_dims = style_dims - set(data[data_var].dims) + if extra_style_dims: + data = data.stack(extra_style_dim=extra_style_dims).transpose('extra_style_dim', ...) + + if data[data_var].shape[0] == 0: + return None, None, None, None, None + + if len(data[data_var].shape) == 1: + data = data.expand_dims(dim={"fake_dim": [0]}, axis=0) + # We have to flatten all the dimensions that will not be represented as an axis, + # since we will just iterate over them. + dims_to_stack = data[data_var].dims[:-len(last_dims)] + data = data.stack(iterate_dim=dims_to_stack).transpose("iterate_dim", ...) 
+ + styles = {} + for key, value in style.items(): + if value in data: + styles[key] = data[value] + else: + styles[key] = None + + plot_data = data[axes[data_axis]] + + fixed_coords = {} + for ax_key, fixed_axis in fixed_axes.items(): + fixed_coord = data[fixed_axis] + if "iterate_dim" in fixed_coord.dims: + # This is if fixed_coord was a variable of the dataset, which possibly has + # gotten the extra iterate_dim added. + fixed_coord = fixed_coord.isel(iterate_dim=0) + fixed_coords[ax_key] = fixed_coord + + #info(f"{self} variables: \n\t- Fixed: {fixed_axes}\n\t- Data axis: {data_axis}\n\t") + + return plot_data, fixed_coords, styles, data_axis, axes + +def draw_xarray_xy(data, x=None, y=None, z=False, color="color", width="width", dash="dash", opacity="opacity", name="", colorscale=None, + what: typing.Literal["line", "scatter", "balls", "area_line", "arrows", "none"] = "line", + dependent_axis: typing.Optional[typing.Literal["x", "y"]] = None, + set_axrange=False, set_axequal=False +): + if what == "none": + return [] + + plot_data, fixed_coords, styles, data_axis, axes = _process_xarray_data( + data, x=x, y=y, z=z, style={"color": color, "width": width, "opacity": opacity, "dash": dash} + ) + + if plot_data is None: + return [] + + to_plot = _draw_xarray_lines( + data=plot_data, style=styles, fixed_coords=fixed_coords, data_axis=data_axis, colorscale=colorscale, what=what, name=name, + dependent_axis=dependent_axis + ) + + if set_axequal: + to_plot.append(plot_actions.set_axes_equal()) + + # Set axis range + for key, coord_key in axes.items(): + if coord_key == getattr(data, "name", None): + ax = data + else: + ax = data[coord_key] + title = ax.name + units = ax.attrs.get("units") + if units: + title += f" [{units}]" + + axis = {"title": title} + + if set_axrange: + axis["range"] = (float(ax.min()), float(ax.max())) + + axis.update(ax.attrs.get("axis", {})) + + to_plot.append(plot_actions.set_axis(axis=key, **axis)) + + return to_plot + +def 
_draw_xarray_lines(data, style, fixed_coords, data_axis, colorscale, what, name="", dependent_axis=None): + # Initialize actions list + to_plot = [] + + # Get the lines styles + lines_style = {} + extra_style_dims = False + for key in ("color", "width", "opacity", "dash"): + lines_style[key] = style.get(key) + + if lines_style[key] is not None: + extra_style_dims = extra_style_dims or "extra_style_dim" in lines_style[key].dims + # If some style is constant, just repeat it. + if lines_style[key] is None or "iterate_dim" not in lines_style[key].dims: + lines_style[key] = itertools.repeat(lines_style[key]) + + # If we have to draw multicolored lines, we have to initialize a color axis and + # use a special drawing function. If we have to draw lines with multiple widths + # we also need to use a special function. + line_kwargs = {} + if isinstance(lines_style['color'], itertools.repeat): + color_value = next(lines_style['color']) + else: + color_value = lines_style['color'] + + if isinstance(lines_style['width'], itertools.repeat): + width_value = next(lines_style['width']) + else: + width_value = lines_style['width'] + + if isinstance(color_value, DataArray) and (data.dims[-1] in color_value.dims): + color = color_value + if color.dtype in (int, float): + coloraxis_name = f"{color.name}_{name}" if name else color.name + to_plot.append( + plot_actions.init_coloraxis(name=coloraxis_name, cmin=color.values.min(), cmax=color.values.max(), colorscale=colorscale) + ) + line_kwargs = {'coloraxis': coloraxis_name} + drawing_function_name = f"draw_multicolor_{what}" + elif isinstance(width_value, DataArray) and (data.dims[-1] in width_value.dims): + drawing_function_name = f"draw_multisize_{what}" + else: + drawing_function_name = f"draw_{what}" + + # Check if we have to use a 3D function + if len(fixed_coords) == 2: + to_plot.append(plot_actions.init_3D()) + drawing_function_name += "_3D" + + _drawing_function = getattr(plot_actions, drawing_function_name) + if what in 
("scatter", "balls"): + def drawing_function(*args, **kwargs): + marker = kwargs.pop("line") + marker['size'] = marker.pop("width") + + to_plot.append( + _drawing_function(*args, marker=marker, **kwargs) + ) + elif what == "area_line": + def drawing_function(*args, **kwargs): + to_plot.append( + _drawing_function(*args, dependent_axis=dependent_axis, **kwargs) + ) + else: + def drawing_function(*args, **kwargs): + to_plot.append( + _drawing_function(*args, **kwargs) + ) + + # Define the iterator over lines, containing both values and styles + iterator = zip(data, + lines_style['color'], lines_style['width'], lines_style['opacity'], lines_style['dash'] + ) + + fixed_coords_values = {k: arr.values for k, arr in fixed_coords.items()} + + single_line = len(data.iterate_dim) == 1 + if name in data.iterate_dim.coords: + name_prefix = "" + else: + name_prefix = f"{name}_" if name and not single_line else name + + # Now just iterate over each line and plot it. + for values, *styles in iterator: + + names = values.iterate_dim.values[()] + if name in values.iterate_dim.coords: + line_name = f"{name_prefix}{values.iterate_dim.coords[name].values[()]}" + elif single_line and not isinstance(names[0], str): + line_name = name_prefix + elif len(names) == 1: + line_name = f"{name_prefix}{names[0]}" + else: + line_name = f"{name_prefix}{names}" + + parsed_styles = [] + for style in styles: + if style is not None: + style = style.values + if style.ndim == 0: + style = style[()] + parsed_styles.append(style) + + line_color, line_width, line_opacity, line_dash = parsed_styles + line_style = {"color": line_color, "width": line_width, "opacity": line_opacity, "dash": line_dash} + line = {**line_style, **line_kwargs} + + coords = { + data_axis: values, + **fixed_coords_values, + } + + if not extra_style_dims: + drawing_function(**coords, line=line, name=line_name) + else: + for k, v in line_style.items(): + if v is None or v.ndim == 0: + line_style[k] = itertools.repeat(v) + + for 
l_color, l_width, l_opacity, l_dash in zip(line_style['color'], line_style['width'], line_style['opacity'], line_style['dash']): + line_style = {"color": l_color, "width": l_width, "opacity": l_opacity, "dash": l_dash} + drawing_function(**coords, line=line_style, name=line_name) + + return to_plot + + + +# class PlotterNodeGrid(PlotterXArray): + +# def draw(self, data, isos=[]): + +# ndim = data.ndim + +# if ndim == 2: +# transposed = data.transpose("y", "x") + +# self.draw_heatmap(transposed.values, x=data.x, y=data.y, name="HEAT", zsmooth="best") + +# dx = data.x[1] - data.x[0] +# dy = data.y[1] - data.y[0] + +# iso_lines = get_isos(transposed, isos) +# for iso_line in iso_lines: +# iso_line['line'] = { +# "color": iso_line.pop("color", None), +# "opacity": iso_line.pop("opacity", None), +# "width": iso_line.pop("width", None), +# **iso_line.get("line", {}) +# } +# self.draw_line(**iso_line) +# elif ndim == 3: +# isosurfaces = get_isos(data, isos) + +# for isosurface in isosurfaces: +# self.draw_mesh_3D(**isosurface) + + +# self.set_axes_equal() \ No newline at end of file diff --git a/src/sisl/viz/plotutils.py b/src/sisl/viz/plotutils.py index 7070e28168..e66bc3b0d6 100644 --- a/src/sisl/viz/plotutils.py +++ b/src/sisl/viz/plotutils.py @@ -109,16 +109,14 @@ def get_plot_classes(): list all the plot classes that the module is aware of. """ - from . import Animation, MultiplePlot, Plot, SubPlots + from . 
import Plot def get_all_subclasses(cls): all_subclasses = [] for Subclass in cls.__subclasses__(): - - if Subclass not in [MultiplePlot, Animation, SubPlots] and not getattr(Subclass, 'is_only_base', False): - all_subclasses.append(Subclass) + all_subclasses.append(Subclass) all_subclasses.extend(get_all_subclasses(Subclass)) @@ -179,122 +177,6 @@ def get_plotable_variables(variables): return plotables -def get_configurable_docstring(cls): - """ Builds the docstring for a class that inherits from Configurable - - Parameters - ----------- - cls: - the class you want the docstring for - - Returns - ----------- - str: - the docs with the settings added. - """ - import re - - if isinstance(cls, type): - params = cls._parameters - doc = cls.__doc__ - if doc is None: - doc = "" - else: - # It's really an instance, not the class - params = cls.params - doc = "" - - configurable_settings = "\n".join([param._get_docstring() for param in params]) - - html_cleaner = re.compile('<.*?>') - configurable_settings = re.sub(html_cleaner, '', configurable_settings) - - if "Parameters\n--" not in doc: - doc += f'\n\nParameters\n-----------\n{configurable_settings}' - else: - doc += f'\n{configurable_settings}' - - return doc - - -def get_configurable_kwargs(cls_or_inst, fake_default): - """ Builds a string to help you define all the kwargs coming from the settings. - - The main point is to avoid wasting time writing all the kwargs manually, and - at the same time makes it easy to keep it consistent with the defaults. - - This may be useful, for example, for the __init__ method of plots. - - Parameters - ------------ - cls_or_inst: - the class (or instance) you want the kwargs for. - fake_default: str - only floats, ints, bools and strings can be parsed safely into strings and then into values again. - For this reason, the rest of the settings will just be given a fake default that you need to handle. - - Returns - ----------- - str: - the string containing the described kwargs. 
- """ - # TODO why not just repr(val)? that seems to be the same in all cases? - def get_string(val): - if isinstance(val, (float, int, bool)) or val is None: - return val - elif isinstance(val, str): - return val.__repr__() - else: - return fake_default.__repr__() - - if isinstance(cls_or_inst, type): - params = cls_or_inst._parameters - return ", ".join([f'{param.key}={get_string(param.default)}' for param in params]) - - # It's really an instance, not the class - # In this case, the defaults for the method will be the current values. - params = cls_or_inst.params - return ", ".join([f'{param.key}={get_string(cls_or_inst.settings[param.key])}' for param in params]) - - -def get_configurable_kwargs_to_pass(cls): - """ Builds a string to help you pass kwargs that you got from the function `get_configurable_kwargs`. - - E.g.: If `get_configurable_kwargs` gives you 'param1=None, param2="nothing"' - `get_configurable_kwargs_to_pass` will give you param1=param1, param2=param2 - - Parameters - ------------ - cls: - the class you want the kwargs for - - Returns - ----------- - str: - the string containing the described kwargs. - """ - if isinstance(cls, type): - params = cls._parameters - else: - # It's really an instance, not the class - params = cls.params - - return ", ".join([f'{param.key}={param.key}' for param in params]) - - -def get_session_classes(): - """ Returns the available session classes - - Returns - -------- - dict - keys are the name of the class and values are the class itself. - """ - from .session import Session - - return {sbcls.__name__: sbcls for sbcls in Session.__subclasses__()} - - def get_avail_presets(): """ Gets the names of the currently available presets. @@ -412,96 +294,10 @@ def dictOfLists2listOfDicts(dictOfLists): return [dict(zip(dictOfLists, t)) for t in zip(*dictOfLists.values())] -def call_method_if_present(obj, method_name, *args, **kwargs): - """ Calls a method of the object if it is present. 
- - If the method is not there, it just does nothing. - - Parameters - ----------- - method_name: str - the name of the method that you want to call. - *args and **kwargs: - arguments passed to the method call. - """ - - method = getattr(obj, method_name, None) - if callable(method): - return method(*args, **kwargs) - - -def copy_params(params, only=(), exclude=()): - """ Function that returns a copy of the provided plot parameters. - - Arguments - ---------- - params: tuple - The parameters that have to be copied. This will come presumably from the "_parameters" variable of some plot class. - only: array-like - Use this if you only want a certain set of parameters. Pass the wanted keys as a list. - exclude: array-like - Use this if there are some parameters that you don't want. Pass the unwanted keys as a list. - This argument will not be used if "only" is present. - - Returns - ---------- - copiedParams: tuple - The params that the user asked for. They are not linked to the input params, so they can be modified independently. - """ - if only: - return tuple(param for param in deepcopy(params) if param.key in only) - return tuple(param for param in deepcopy(params) if param.key not in exclude) - - -def copy_dict(dictInst, only=(), exclude=()): - """ Function that returns a copy of a dict. This function is thought to be used for the settings dictionary, for example. - - Arguments - ---------- - dictInst: dict - The dictionary that needs to be copied. - only: array-like - Use this if you only want a certain set of values. Pass the wanted keys as a list. - exclude: array-like - Use this if there are some values that you don't want. Pass the unwanted keys as a list. - This argument will not be used if "only" is present. - - Returns - ---------- - copiedDict: dict - The dictionary that the user asked for. It is not linked to the input dict, so it can be modified independently. 
- """ - if only: - return {k: v for k, v in deepcopy(dictInst).iteritems() if k in only} - return {k: v for k, v in deepcopy(dictInst).iteritems() if k not in exclude} - #------------------------------------- # Filesystem #------------------------------------- - -def load(path): - """ - Loads a previously saved python object using pickle. To be used for plots, sessions, etc... - - Arguments - ---------- - path: str - The path to the saved object. - - Returns - ---------- - loadedObj: object - The object that was saved. - """ - import dill - - with open(path, 'rb') as handle: - loadedObj = dill.load(handle) - - return loadedObj - - def find_files(root_dir=Path("."), search_string = "*", depth = [0, 0], sort = True, sort_func = None, case_insensitive=False): """ Function that finds files (or directories) according to some conditions. @@ -597,338 +393,6 @@ def find_plotable_siles(dir_path=None, depth=0): return files -#------------------------------------- -# Multiprocessing -#------------------------------------- - -_MAX_NPROCS = get_environ_variable("SISL_VIZ_NUM_PROCS") - - -def _apply_method(args_tuple): - """ Apply a method to an object. This function is meant for multiprocessing """ - - method, obj, args, kwargs = args_tuple - - if args is None: - args = [] - - method(obj, *args, **kwargs) - - return obj - - -def _init_single_plot(args_tuple): - """ Initialize a single plot. This function is meant to be used in multiprocessing, when multiple plots need to be initialized """ - - PlotClass, args, kwargs = args_tuple - - return PlotClass(**kwargs) - - -def run_multiple(func, *args, argsList = None, kwargsList = None, messageFn = None, serial = False): - """ - Makes use of the pathos.multiprocessing module to run a function simultanously multiple times. - This is meant mainly to update multiple plots at the same time, which can accelerate significantly the process of visualizing data. 
- - All arguments passed to the function, except func, can be passed as specified in the arguments section of this documentation - or as a list containing multiple instances of them. - If a list is passed, each time the function needs to be run it will take the next item of the list. - If a single item is passed instead, this item will be repeated for each function run. - However, at least one argument must be a list, so that the number of times that the function has to be ran is defined. - - Arguments - ---------- - func: function - The function to be executed. It has to be prepared to recieve the arguments as they are provided to it (zipped). - - See the applyMethod() function as an example. - *args: - Contains all the arguments that are specific to the individual function that we want to run. - See each function separately to understand what you need to pass (you may not need this parameter). - argsList: array-like - An array of arguments that have to be passed to the executed function. - - Can also be a list of arrays (see this function's description). - - WARNING: Currently it only works properly for a list of arrays. Didn't fix this because the lack of interest - of argsList on Plot's methods (everything is passed as keyword arguments). - kwargsList: dict - A dictionary with the keyword arguments that have to be passed to the executed function. - - If the executed function is a Plot's method, these can be the settings, for example. - - Can also be a list of dicts (see this function's description). - - messageFn: function - Function that recieves the number of tasks and nodes and needs to return a string to display as a description of the progress bar. - serial: bool - If set to true, multiprocessing is not used. - - This seems to have little sense, but it is useful to switch easily between multiprocessing and serial with the same code. - - Returns - ---------- - results: list - A list with all the returned values or objects from each function execution. 
- This list is ordered, so results[0] is the result of executing the function with argsList[0] and kwargsList[0]. - """ - #Prepare the arguments to be passed to the initSinglePlot function - toZip = [*args, argsList, kwargsList] - for i, arg in enumerate(toZip): - if not isinstance(arg, (list, tuple, np.ndarray)): - toZip[i] = itertools.repeat(arg) - else: - nTasks = len(arg) - - # Run things in serial mode in case it is demanded or pathos is not available - serial = not pathos_avail or serial or _MAX_NPROCS == 1 or nTasks == 1 - if serial: - return [func(argsTuple) for argsTuple in zip(*toZip)] - - #Create a pool with the appropiate number of processes - pool = Pool(min(nTasks, _MAX_NPROCS)) - #Define the plots array to store all the plots that we initialize - results = [None]*nTasks - - #Initialize the pool iterator and the progress bar that controls it - imap = pool.imap(func, zip(*toZip)) - if tqdm_avail: - imap = tqdm.tqdm(imap, total = nTasks) - - #Set a description for the progress bar - if not callable(messageFn): - message = "Updating {} plots in {} processes".format(nTasks, pool.nodes) - else: - message = messageFn(nTasks, pool.nodes) - - imap.set_description(message) - - #Run the processes and store each result in the plots array - for i, res in enumerate(imap): - results[i] = res - - pool.close() - pool.join() - pool.clear() - - return results - - -def init_multiple_plots(PlotClass, argsList = None, kwargsList = None, **kwargs): - """ Initializes a set of plots in multiple processes simultanously making use of runMultiple() - - All arguments passed to the function, can be passed as specified in the arguments section of this documentation - or as a list containing multiple instances of them. - If a list is passed, each time the function needs to be run it will take the next item of the list. - If a single item is passed instead, this item will be repeated for each function run. 
- However, at least one argument must be a list, so that the number of times that the function has to be ran is defined. - - Arguments - ---------- - PlotClass: child class of sisl.viz.plotly.Plot - The plot class that must be initialized - - Can also be a list of classes (see this function's description). - argsList: array-like - An array of arguments that have to be passed to the executed function. - - Can also be a list of arrays (see this function's description). - - WARNING: Currently it only works properly for a list of arrays. Didn't fix this because the lack of interest - of argsList on Plot's methods (everything is passed as keyword arguments). - kwargsList: dict - A dictionary with the keyword arguments that have to be passed to the executed function. - - If the executed function is a Plot's method, these can be the settings, for example. - - Can also be a list of dicts (see this function's description). - - Returns - ---------- - plots: list - A list with all the initialized plots. - This list is ordered, so plots[0] is the plot initialized with argsList[0] and kwargsList[0]. - """ - - return run_multiple(_init_single_plot, PlotClass, argsList = argsList, kwargsList = kwargsList, **kwargs) - - -def apply_method_on_multiple_objs(method, objs, argsList = None, kwargsList = None, **kwargs): - """ Applies a given method to the objects provided on multiple processes simultanously making use of the runMultiple() function. - - This is useful in principle for any kind of object and any method, but has been tested only on plots. - - All arguments passed to the function, except method, can be passed as specified in the arguments section of this documentation - or as a list containing multiple instances of them. - If a list is passed, each time the function needs to be run it will take the next item of the list. - If a single item is passed instead, this item will be repeated for each function run. 
- However, at least one argument must be a list, so that the number of times that the function has to be ran is defined. - - Arguments - ---------- - method: func - The method to be executed. - objs: object - The object to which we need to apply the method (e.g. a plot) - - Can also be a list of objects (see this function's description). - argsList: array-like - An array of arguments that have to be passed to the executed function. - - Can also be a list of arrays (see this function's description). - - WARNING: Currently it only works properly for a list of arrays. Didn't fix this because the lack of interest - of argsList on Plot's methods (everything is passed as keyword arguments). - kwargsList: dict - A dictionary with the keyword arguments that have to be passed to the executed function. - - If the executed function is a Plot's method, these can be the settings, for example. - - Can also be a list of dicts (see this function's description). - - Returns - ---------- - plots: list - A list with all the initialized plots. - This list is ordered, so plots[0] is the plot initialized with argsList[0] and kwargsList[0]. - """ - - return run_multiple(_apply_method, method, objs, argsList = argsList, kwargsList = kwargsList, **kwargs) - - -def repeat_if_children(method): - """ Decorator that will force a method to be run on all the plot's children in case there are any """ - - def apply_to_all_plots(obj, *args, children_sel=None, **kwargs): - - if hasattr(obj, "children"): - - kwargs_list = kwargs.get("kwargs_list", kwargs) - - if isinstance(children_sel, int): - children_sel = [children_sel] - - # Get all the child plots that we are going to modify - children = obj.children - if children_sel is not None: - children = np.array(children)[children_sel].tolist() - else: - children_sel = range(len(children)) - - new_children = apply_method_on_multiple_objs(method, children, kwargsList=kwargs_list, serial=True) - - # Set the new plots. 
We need to do this because apply_method_on_multiple_objs - # can use multiprocessing, and therefore will not modify the plot in place. - for i, new_child in zip(children_sel, new_children): - obj.children[i] = new_child - - obj.get_figure() - - else: - - return method(obj, *args, **kwargs) - - return apply_to_all_plots - -#------------------------------------- -# Fun stuff -#------------------------------------- - -# TODO these would be ideal to put in the sisl configdir so users can -# alter the commands used ;) -# However, not really needed now. - - -def trigger_notification(title, message, sound="Submarine"): - """ Triggers a notification. - - Will not do anything in Windows (oops!) - - Parameters - ----------- - title: str - message: str - sound: str - """ - - if sys.platform == 'linux': - os.system(f"""notify-send "{title}" "{message}" """) - elif sys.platform == 'darwin': - sound_string = f'sound name "{sound}"' if sound else '' - os.system(f"""osascript -e 'display notification "{message}" with title "{title}" {sound_string}' """) - else: - info(f"sisl cannot issue notifications through the operating system ({sys.platform})") - - -def spoken_message(message): - """ Trigger a spoken message. - - In linux espeak must be installed (sudo apt-get install espeak) - - Will not do anything in Windows (oops!) - - Parameters - ----------- - title: str - message: str - sound: str - """ - - if sys.platform == 'linux': - os.system(f"""espeak -s 150 "{message}" 2>/dev/null""") - elif sys.platform == 'darwin': - os.system(f"""osascript -e 'say "{message}"' """) - else: - info(f"sisl cannot issue notifications through the operating system ({sys.platform})") - -#------------------------------------- -# Plot manipulation -#------------------------------------- - - -def shift_trace(trace, shift, axis="y"): - """ Shifts a trace by a given value in the given axis. - - Parameters - ----------- - shift: float or array-like - If it's a float, it will be a solid shift (i.e. 
all points moved equally). - If it's an array, an element-wise sum will be performed - axis: {"x","y","z"}, optional - The axis along which we want to shift the traces. - """ - trace[axis] = np.array(trace[axis]) + shift - - -def normalize_trace(trace, min_val=0, max_val=1, axis='y'): - """ Normalizes a trace to a given range along an axis. - - Parameters - ----------- - min_val: float, optional - The lower bound of the range. - max_val: float, optional - The upper part of the range - axis: {"x", "y", "z"}, optional - The axis along which we want to normalize. - """ - t = np.array(trace[axis]) - tmin = t.min() - trace[axis] = (t - tmin) / (t.max() - tmin) * (max_val - min_val) + min_val - - -def swap_trace_axes(trace, ax1='x', ax2='y'): - """ Swaps two axes of a trace. - - Parameters - ----------- - ax1, ax2: str, {'x', 'x*', 'y', 'y*', 'z', 'z*'} - The names of the axes that you want to swap. - """ - ax1_data = trace[ax1] - trace[ax1] = trace[ax2] - trace[ax2] = ax1_data - #------------------------------------- # Colors diff --git a/src/sisl/viz/processors/__init__.py b/src/sisl/viz/processors/__init__.py new file mode 100644 index 0000000000..94016c4463 --- /dev/null +++ b/src/sisl/viz/processors/__init__.py @@ -0,0 +1,5 @@ +# from .bands import * +# from .fatbands import * +# from .pdos import * +# from .geometry import * +# from .grid import * \ No newline at end of file diff --git a/src/sisl/viz/processors/atom.py b/src/sisl/viz/processors/atom.py new file mode 100644 index 0000000000..568a8446c8 --- /dev/null +++ b/src/sisl/viz/processors/atom.py @@ -0,0 +1,97 @@ +from collections import defaultdict +from typing import Any, Callable, Optional, Sequence, TypedDict, Union + +import numpy as np +import xarray as xr +from xarray import DataArray, Dataset + +from sisl import Geometry +from sisl.messages import SislError + +from .xarray import Group, group_reduce + + +class AtomsGroup(Group, total=False): + name: str + atoms: Any + reduce_func: Optional[Callable] + 
+def reduce_atom_data(atom_data: Union[DataArray, Dataset], groups: Sequence[AtomsGroup], geometry: Optional[Geometry] = None, + reduce_func: Callable = np.mean, atom_dim: str = "atom", groups_dim: str = "group", + sanitize_group: Callable = lambda x: x, group_vars: Optional[Sequence[str]] = None, + drop_empty: bool = False, fill_empty: Any = 0. +) -> Union[DataArray, Dataset]: + """Groups contributions of atoms into a new dimension. + + Given an xarray object containing atom information and the specification of groups of atoms, this function + computes the total contribution for each group of atoms. It therefore removes the atoms dimension and + creates a new one to account for the groups. + + Parameters + ---------- + atom_data : DataArray or Dataset + The xarray object to reduce. + groups : Sequence[AtomsGroup] + A sequence containing the specifications for each group of atoms. See ``AtomsGroup``. + geometry : Geometry, optional + The geometry object that will be used to parse atom specifications into actual atom indices. Knowing the + geometry therefore allows you to specify more complex selections. + If not provided, it will be searched in the ``geometry`` attribute of the ``atom_data`` object. + reduce_func : Callable, optional + The function that will compute the reduction along the atoms dimension once the selection is done. + This could be for example ``numpy.mean`` or ``numpy.sum``. + Notice that this will only be used in case the group specification doesn't specify a particular function + in its "reduce_func" field, which will take preference. + spin_reduce: Callable, optional + The function that will compute the reduction along the spin dimension once the selection is done. + orb_dim: str, optional + Name of the dimension that contains the atom indices in ``atom_data``. + groups_dim: str, optional + Name of the new dimension that will be created for the groups. 
+ sanitize_group: Callable, optional + A function that will be used to sanitize the group specification before it is used. + group_vars: Sequence[str], optional + If set, this argument specifies extra variables that depend on the group and the user would like to + introduce in the new xarray object. These variables will be searched as fields for each group specification. + A data variable will be created for each group_var and they will be added to the final xarray object. + Note that this forces the returned object to be a Dataset, even if the input data is a DataArray. + drop_empty: bool, optional + If set to `True`, group specifications that do not correspond to any atom will not appear in the final + returned object. + fill_empty: Any, optional + If ``drop_empty`` is set to ``False``, this argument specifies the value to use for group specifications + that do not correspond to any atom. + """ + # If no geometry was provided, then get it from the attrs of the xarray object. + if geometry is None: + geometry = atom_data.attrs.get("geometry") + + if geometry is None: + def _sanitize_group(group): + group = group.copy() + group = sanitize_group(group) + atoms = group['atoms'] + try: + group['atoms'] = np.array(atoms, dtype=int) + assert atoms.ndim == 1 + except: + raise SislError("A geometry was neither provided nor found in the xarray object. 
Therefore we can't" + f" convert the provided atom selection ({atoms}) to an array of integers.") + + group['selector'] = group['atoms'] + + return group + else: + def _sanitize_group(group): + group = group.copy() + group = sanitize_group(group) + group["atoms"] = geometry._sanitize_atoms(group["atoms"]) + group['selector'] = group['atoms'] + return group + + + return group_reduce( + data=atom_data, groups=groups, reduce_dim=atom_dim, reduce_func=reduce_func, + groups_dim=groups_dim, sanitize_group=_sanitize_group, group_vars=group_vars, + drop_empty=drop_empty, fill_empty=fill_empty + ) \ No newline at end of file diff --git a/src/sisl/viz/processors/axes.py b/src/sisl/viz/processors/axes.py new file mode 100644 index 0000000000..e37e6797d0 --- /dev/null +++ b/src/sisl/viz/processors/axes.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import re +from typing import Callable, List, Optional, Sequence, Union + +import numpy as np + +from sisl.utils import direction + + +def sanitize_axis(ax) -> Union[str, int, np.ndarray]: + if isinstance(ax, str): + if re.match("[+-]?[012]", ax): + ax = ax.replace("0", "a").replace("1", "b").replace("2", "c") + ax = ax.lower().replace("+", "") + elif isinstance(ax, int): + ax = 'abc'[ax] + elif isinstance(ax, (list, tuple)): + ax = np.array(ax) + + # Now perform some checks + invalid = True + if isinstance(ax, str): + invalid = not re.match("-?[xyzabc]", ax) + elif isinstance(ax, np.ndarray): + invalid = ax.shape != (3,) + + if invalid: + raise ValueError(f"Incorrect axis passed. 
Axes must be one of [+-]('x', 'y', 'z', 'a', 'b', 'c', '0', '1', '2', 0, 1, 2)" + + " or a numpy array/list/tuple of shape (3, )") + + return ax + +def sanitize_axes(val: Union[str, Sequence[Union[str, int, np.ndarray]]]) -> List[Union[str, int, np.ndarray]]: + if isinstance(val, str): + val = re.findall("[+-]?[xyzabc012]", val) + return [sanitize_axis(ax) for ax in val] + +def get_ax_title(ax: Union[Axis, Callable], cartesian_units: str = "Ang") -> str: + """Generates the title for a given axis""" + if hasattr(ax, "__name__"): + title = ax.__name__ + elif isinstance(ax, np.ndarray) and ax.shape == (3,): + title = str(ax) + elif not isinstance(ax, str): + title = "" + elif re.match("[+-]?[xXyYzZ]", ax): + title = f'{ax.upper()} axis [{cartesian_units}]' + elif re.match("[+-]?[aAbBcC]", ax): + title = f'{ax.upper()} lattice vector' + else: + title = ax + + return title + +def axis_direction(ax: Axis, cell: Optional[Union[npt.ArrayLike, Lattice]] = None) -> npt.NDArray[np.float64]: + """Returns the vector direction of a given axis. + + Parameters + ---------- + ax: Axis + Axis specification for which you want the direction. It supports + negative signs (e.g. "-x"), which will invert the direction. + cell: array-like of shape (3, 3) or Lattice, optional + The cell of the structure, only needed if lattice vectors {"a", "b", "c"} + are provided for `ax`. + + Returns + ---------- + np.ndarray of shape (3, ) + The direction of the axis. + """ + if isinstance(ax, (int, str)): + sign = 1 + # If the axis contains a -, we need to mirror the direction. + if isinstance(ax, str) and ax[0] == "-": + sign = -1 + ax = ax[1] + ax = sign * direction(ax, abc=cell, xyz=np.diag([1., 1., 1.])) + + return ax + +def axes_cross_product(v1: Axis, v2: Axis, cell: Optional[Union[npt.ArrayLike, Lattice]] = None): + """An enhanced version of the cross product. 
+ + It is an enhanced version because both vectors accept strings that represent + the cartesian axes or the lattice vectors (see `v1`, `v2` below). It has been built + so that cross product between lattice vectors (-){"a", "b", "c"} follows the same rules + as (-){"x", "y", "z"} + + Parameters + ---------- + v1, v2: array-like of shape (3,) or (-){"x", "y", "z", "a", "b", "c"} + The vectors to take the cross product of. + cell: array-like of shape (3, 3) + The cell of the structure, only needed if lattice vectors {"a", "b", "c"} + are passed for `v1` and `v2`. + """ + # Make abc follow the same rules as xyz to find the orthogonal direction + # That is, a X b = c; -a X b = -c and so on. + if isinstance(v1, str) and isinstance(v2, str): + if re.match("([+-]?[abc]){2}", v1 + v2): + v1 = v1.replace("a", "x").replace("b", "y").replace("c", "z") + v2 = v2.replace("a", "x").replace("b", "y").replace("c", "z") + ort = axes_cross_product(v1, v2) + ort_ax = "abc"[np.where(ort != 0)[0][0]] + if ort.sum() == -1: + ort_ax = "-" + ort_ax + return axis_direction(ort_ax, cell) + + # If the vectors are not abc, we just need to take the cross product. + return np.cross(axis_direction(v1, cell), axis_direction(v2, cell)) + diff --git a/src/sisl/viz/processors/bands.py b/src/sisl/viz/processors/bands.py new file mode 100644 index 0000000000..f06a3cb935 --- /dev/null +++ b/src/sisl/viz/processors/bands.py @@ -0,0 +1,367 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
+import itertools +from typing import List, Literal, Optional, Sequence, Tuple, Union + +import numpy as np +import xarray as xr + +from ..plotters import plot_actions + + +def filter_bands( + bands_data: xr.Dataset, + Erange: Optional[Tuple[float, float]] = None, + E0: float = 0, + bands_range: Optional[Tuple[int, int]] = None, + spin: Optional[int] = None +) -> xr.Dataset: + filtered_bands = bands_data.copy() + # Shift the energies according to the reference energy, while keeping the + # attributes (which contain the units, amongst other things) + filtered_bands['E'] = bands_data.E - E0 + continous_bands = filtered_bands.dropna("k", how="all") + + # Get the bands that matter for the plot + if Erange is None: + if bands_range is None: + # If neither E range or bands_range was provided, we will just plot the 15 bands below and above the fermi level + CB = int(continous_bands.E.where(continous_bands.E <= 0).argmax('band').max()) + bands_range = [int(max(continous_bands["band"].min(), CB - 15)), int(min(continous_bands["band"].max() + 1, CB + 16))] + + filtered_bands = filtered_bands.sel(band=slice(*bands_range)) + continous_bands = filtered_bands.dropna("k", how="all") + + # This is the new Erange + # Erange = np.array([float(f'{val:.3f}') for val in [float(continous_bands.E.min() - 0.01), float(continous_bands.E.max() + 0.01)]]) + else: + filtered_bands = filtered_bands.where((filtered_bands <= Erange[1]) & (filtered_bands >= Erange[0])).dropna("band", "all") + continous_bands = filtered_bands.dropna("k", how="all") + + # This is the new bands range + #bands_range = [int(continous_bands['band'].min()), int(continous_bands['band'].max())] + + # Give the filtered bands the same attributes as the full bands + filtered_bands.attrs = bands_data.attrs + + filtered_bands.E.attrs = bands_data.E.attrs + filtered_bands.E.attrs['E0'] = filtered_bands.E.attrs.get('E0', 0) + E0 + + # Let's treat the spin if the user requested it + if not isinstance(spin, (int, type(None))): + 
if len(spin) > 0: + spin = spin[0] + else: + spin = None + + if spin is not None: + # Only use the spin setting if there is a spin index + if "spin" in filtered_bands.coords: + filtered_bands = filtered_bands.sel(spin=spin) + + return filtered_bands + +def style_bands( + bands_data: xr.Dataset, + bands_style: dict = {"color": "black", "width": 1}, + spindown_style: dict = {"color": "blue", "width": 1} +) -> xr.Dataset: + """Returns the bands dataset, with the style information added to it. + + Parameters + ------------ + bands_data: xr.Dataset + The dataset containing bands energy information. + bands_style: dict + Dictionary containing the style information for the bands. + spindown_style: dict + Dictionary containing the style information for the spindown bands. + Any style that is not present in this dictionary will be taken from + the "bands_style" dictionary. + """ + # If the user provided a styler function, apply it. + if bands_style.get("styler") is not None: + if callable(bands_style['styler']): + bands_data = bands_style['styler'](data=bands_data) + + # Include default styles in bands_style, only if they are not already + # present in the bands dataset (e.g. 
because the styler included them) + default_styles = {'color': 'black', 'width': 1, 'opacity': 1} + for key in default_styles: + if key not in bands_data.data_vars and key not in bands_style: + bands_style[key] = default_styles[key] + + # If some key in bands_style is a callable, apply it + for key in bands_style: + if callable(bands_style[key]): + bands_style[key] = bands_style[key](data=bands_data) + + # Build the style dataarrays + if 'spin' in bands_data.dims: + spindown_style = {**bands_style, **spindown_style} + style_arrays = {} + for key in ['color', 'width', 'opacity']: + if isinstance(bands_style[key], xr.DataArray): + if not isinstance(spindown_style[key], xr.DataArray): + down_style = bands_style[key].copy(deep=True) + down_style.values[:] = spindown_style[key] + spindown_style[key] = down_style + + style_arrays[key] = xr.concat([bands_style[key], spindown_style[key]], dim='spin') + else: + style_arrays[key] = xr.DataArray([bands_style[key], spindown_style[key]], dims=['spin']) + else: + style_arrays = {} + for key in ['color', 'width', 'opacity']: + style_arrays[key] = xr.DataArray(bands_style[key]) + + # Merge the style arrays with the bands dataset and return the styled dataset + return bands_data.assign(style_arrays) + +def calculate_gap(bands_data: xr.Dataset) -> dict: + bands_E = bands_data.E + # Calculate the band gap to store it + shifted_bands = bands_E + above_fermi = bands_E.where(shifted_bands > 0) + below_fermi = bands_E.where(shifted_bands < 0) + CBbot = above_fermi.min() + VBtop = below_fermi.max() + + CB = above_fermi.where(above_fermi == CBbot, drop=True).squeeze() + VB = below_fermi.where(below_fermi == VBtop, drop=True).squeeze() + + gap = float(CBbot - VBtop) + + return { + 'gap': gap, + 'k': (VB["k"].values, CB['k'].values), + 'bands': (VB["band"].values, CB["band"].values), + 'spin': (VB["spin"].values, CB["spin"].values) if bands_data.attrs['spin'].is_polarized else (0, 0), + 'Es': (float(VBtop), float(CBbot)) + } + +def 
sanitize_k(bands_data: xr.Dataset, k: Union[float, str]) -> Optional[float]: + """Returns the float value of a k point in the plot. + + Parameters + ------------ + bands_data: xr.Dataset + The dataset containing bands energy information. + k: float or str + The k point that you want to sanitize. + If it can be parsed into a float, the result of `float(k)` will be returned. + If it is a string and it is a label of a k point, the corresponding k value for that + label will be returned + + Returns + ------------ + float + The sanitized k value. + """ + san_k = None + + try: + san_k = float(k) + except ValueError: + if 'axis' in bands_data.k.attrs and bands_data.k.attrs['axis'].get('ticktext') is not None: + ticktext = bands_data.k.attrs['axis']['ticktext'] + tickvals = bands_data.k.attrs['axis']['tickvals'] + if k in ticktext: + i_tick = ticktext.index(k) + san_k = tickvals[i_tick] + else: + pass + # raise ValueError(f"We can not interpret {k} as a k-location in the current bands plot") + # This should be logged instead of raising the error + + return san_k + +def get_gap_coords( + bands_data: xr.Dataset, + bands: Tuple[int, int], + from_k: Union[float, str], + to_k: Optional[Union[float, str]] = None, + spin: int = 0 +) -> Tuple[Tuple[float, float], Tuple[float, float]]: + """Calculates the coordinates of a gap given some k values. + + Parameters + ----------- + bands_data: xr.Dataset + The dataset containing bands energy information. + bands: array-like of int + Length 2 array containing the band indices of the gap. + from_k: float or str + The k value where you want the gap to start (bottom limit). + If "to_k" is not provided, it will be interpreted also as the top limit. + If a k-value is a float, it will be directly interpreted + as the position in the graph's k axis. + If a k-value is a string, it will be attempted to be parsed + into a float. If not possible, it will be interpreted as a label + (e.g. "Gamma"). 
+ to_k: float or str, optional + same as "from_k" but in this case represents the top limit. + If not provided, "from_k" will be used. + spin: int, optional + the spin component where you want to draw the gap. Has no effect + if the bands are not spin-polarized. + + Returns + ----------- + tuple + A tuple containing (k_values, E_values) + """ + if to_k is None: + to_k = from_k + + ks = [None, None] + # Parse the names of the kpoints into their numeric values + # if a string was provided. + for i, val in enumerate((from_k, to_k)): + ks[i] = sanitize_k(bands_data, val) + + VB, CB = bands + spin_bands = bands_data.E.sel(spin=spin) if "spin" in bands_data.coords else bands_data.E + Es = [spin_bands.dropna("k", "all").sel(k=k, band=band, method="nearest") for k, band in zip(ks, (VB, CB))] + # Get the real values of ks that have been obtained + # because we might not have exactly the ks requested + ks = tuple(np.ravel(E.k)[0] for E in Es) + Es = tuple(np.ravel(E)[0] for E in Es) + + return ks, Es + +def draw_gaps( + bands_data: xr.Dataset, + gap: bool, gap_info: dict, gap_tol: float, + gap_color: Optional[str], gap_marker: Optional[dict], + direct_gaps_only: bool, + custom_gaps: Sequence[dict], + E_axis: Literal["x", "y"] +) -> List[dict]: + """Returns the drawing actions to draw gaps. + + Parameters + ------------ + bands_data: xr.Dataset + The dataset containing bands energy information. + gap: bool + Whether to draw the minimum gap passed as gap_info or not. + gap_info: dict + Dictionary containing the information of the minimum gap, + as returned by `calculate_gap`. + gap_tol: float + Tolerance in k to consider that two gaps are the same. + gap_color: str or None + Color of the line that draws the gap. + gap_marker: str or None + Marker specification of the limits of the gap. + direct_gaps_only: bool + Whether to draw the minimum gap only if it is a direct gap. + custom_gaps: list of dict + List of custom gaps to draw. 
Each dict can contain the keys: + - "from": the k value where the gap starts. + - "to": the k value where the gap ends. If not present, equal to "from". + - "spin": For which spin component do you want to draw the gap + (has effect only if spin is polarized). Optional. If None and the bands + are polarized, the gap will be drawn for both spin components. + - "color": Color of the line that draws the gap. Optional. + - "marker": Marker specification for the limits of the gap. Optional. + E_axis: Literal["x", "y"] + Axis where the energy is plotted. + """ + draw_actions = [] + + # Draw gaps + if gap: + + gapKs = [np.atleast_1d(k) for k in gap_info['k']] + + # Remove "equivalent" gaps + def clear_equivalent(ks): + if len(ks) == 1: + return ks + + uniq = [ks[0]] + for k in ks[1:]: + if abs(min(np.array(uniq) - k)) > gap_tol: + uniq.append(k) + return uniq + + all_gapKs = itertools.product(*[clear_equivalent(ks) for ks in gapKs]) + + for gap_ks in all_gapKs: + + if direct_gaps_only and abs(gap_ks[1] - gap_ks[0]) > gap_tol: + continue + + ks, Es = get_gap_coords(bands_data, gap_info['bands'], *gap_ks, spin=gap_info.get('spin', [0])[0]) + name = "Gap" + + draw_actions.append( + draw_gap(ks, Es, color=gap_color, name=name, marker=gap_marker, E_axis=E_axis) + ) + + # Draw the custom gaps. These are gaps that do not necessarily represent + # the maximum and the minimum of the VB and CB. 
+ for custom_gap in custom_gaps: + + requested_spin = custom_gap.get("spin", None) + if requested_spin is None: + requested_spin = [0, 1] + + avail_spins = bands_data.get("spin", [0]) + + for spin in avail_spins: + if spin in requested_spin: + from_k = custom_gap["from"] + to_k = custom_gap.get("to", from_k) + color = custom_gap.get("color", None) + name = f"Gap ({from_k}-{to_k})" + ks, Es = get_gap_coords(bands_data, gap_info['bands'], from_k, to_k, spin=spin) + + draw_actions.append( + draw_gap(ks, Es, color=color, name=name, marker=custom_gap.get("marker", {}), E_axis=E_axis) + ) + + return draw_actions + +def draw_gap( + ks: Tuple[float, float], + Es: Tuple[float, float], + color: Optional[str] = None, marker: dict = {}, + name: str = "Gap", + E_axis: Literal["x", "y"] = "y" +) -> dict: + """Returns the drawing action to draw a gap. + + Parameters + ------------ + ks: tuple of float + The k values where the gap starts and ends. + Es: tuple of float + The energy values where the gap starts and ends. + color: str or None + Color of the line that draws the gap. + marker: dict + Marker specification for the limits of the gap. + name: str + Name to give to the line that draws the gap. + E_axis: Literal["x", "y"] + Axis where the energy is plotted. 
+ """ + if E_axis == "x": + coords = {"x": Es, "y": ks} + elif E_axis == "y": + coords = {"y": Es, "x": ks} + else: + raise ValueError(f"E_axis must be either 'x' or 'y', but was {E_axis}") + + return plot_actions.draw_line(**{ + **coords, + 'text': [f'Gap: {Es[1] - Es[0]:.3f} eV', ''], + 'name': name, + 'textposition': 'top right', + 'marker': {"size": 7, 'color': color, **marker}, + 'line': {'color': color}, + }) \ No newline at end of file diff --git a/src/sisl/viz/processors/cell.py b/src/sisl/viz/processors/cell.py new file mode 100644 index 0000000000..ba17decd57 --- /dev/null +++ b/src/sisl/viz/processors/cell.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import itertools +from typing import Any, List, Literal, TypedDict, Union + +import numpy as np +import numpy.typing as npt +from xarray import Dataset + +from sisl.lattice import Lattice, LatticeChild + +#from ...types import CellLike +#from .coords import project_to_axes, CoordsDataset + +#CellDataset = CoordsDataset + +def is_cartesian_unordered(cell: CellLike, tol: float = 1e-3) -> bool: + """Whether a cell has cartesian axes as lattice vectors, regardless of their order. + + Parameters + ----------- + cell: np.array of shape (3, 3) + The cell that you want to check. + tol: float, optional + Threshold value to consider a component of the cell nonzero. + """ + if isinstance(cell, (Lattice, LatticeChild)): + cell = cell.cell + + bigger_than_tol = abs(cell) > tol + return bigger_than_tol.sum() == 3 and bigger_than_tol.any(axis=0).all() and bigger_than_tol.any(axis=1).all() + +def is_1D_cartesian(cell: CellLike, coord_ax: Literal["x", "y", "z"], tol: float = 1e-3) -> bool: + """Whether a cell contains only one vector that contributes only to a given coordinate. + + That is, one vector follows the direction of the cartesian axis and the other vectors don't + have any component in that direction. + + Parameters + ----------- + cell: np.array of shape (3, 3) + The cell that you want to check. 
+ coord_ax: {"x", "y", "z"} + The cartesian axis that you are looking for in the cell. + tol: float, optional + Threshold value to consider a component of the cell nonzero. + """ + if isinstance(cell, (Lattice, LatticeChild)): + cell = cell.cell + + coord_index = "xyz".index(coord_ax) + lattice_vecs = np.where(cell[:, coord_index] > tol)[0] + + is_1D_cartesian = lattice_vecs.shape[0] == 1 + return is_1D_cartesian and (cell[lattice_vecs[0]] > tol).sum() == 1 + +def infer_cell_axes(cell: CellLike, axes: List[str], tol: float = 1e-3) -> List[int]: + """Returns the indices of the lattice vectors that correspond to the given axes.""" + if isinstance(cell, (Lattice, LatticeChild)): + cell = cell.cell + + grid_axes = [] + for ax in axes: + if ax in ("x", "y", "z"): + coord_index = "xyz".index(ax) + lattice_vecs = np.where(cell[:, coord_index] > tol)[0] + if lattice_vecs.shape[0] != 1: + raise ValueError(f"There are {lattice_vecs.shape[0]} lattice vectors that contribute to the {'xyz'[coord_index]} coordinate.") + grid_axes.append(lattice_vecs[0]) + else: + grid_axes.append("abc".index(ax)) + + return grid_axes + +def gen_cell_dataset(lattice: Union[Lattice, LatticeChild]) -> CellDataset: + """Generates a dataset with the vertices of the cell.""" + if isinstance(lattice, LatticeChild): + lattice = lattice.lattice + + return Dataset( + {"xyz": (("a", "b", "c", "axis"), lattice.vertices())}, + coords={"a": [0,1], "b": [0, 1], "c": [0, 1], "axis": [0,1,2]}, + attrs={ + "lattice": lattice + } + ) + +class CellStyleSpec(TypedDict): + color: Any + width: Any + opacity: Any + +class PartialCellStyleSpec(TypedDict, total=False): + color: Any + width: Any + opacity: Any + +def cell_to_lines(cell_data: CellDataset, how: Literal["box", "axes"], cell_style: PartialCellStyleSpec = {}) -> CellDataset: + """Converts a cell dataset to lines that should be plotted. + + Parameters + ----------- + cell_data: xr.Dataset + The cell dataset, containing the vertices of the cell. 
+ how: {"box", "axes"} + Whether to draw the cell as a box or as axes. + This determines how many points are needed to draw the cell + using lines, and where those points are located. + cell_style: dict, optional + Style of the cell lines. A dictionary optionally containing + the keys "color", "width" and "opacity". + """ + cell_data = cell_data.reindex(a=[0,1,2], b=[0,1,2], c=[0,1,2]) + + if how == "box": + verts = np.array([ + (0, 0, 0), (0, 1, 0), (1, 1, 0), (1, 1, 1), (0, 1, 1), (0, 1, 0), + (2, 2, 2), + (0, 1, 1), (0, 0, 1), (0, 0, 0), (1, 0, 0), (1, 0, 1), (0, 0, 1), + (2, 2, 2), + (1, 1, 0), (1, 0, 0), + (2, 2, 2), + (1, 1, 1), (1, 0, 1) + ]) + + elif how == "axes": + verts = np.array([ + (0, 0, 0), (1, 0, 0), (2, 2, 2), + (0, 0, 0), (0, 1, 0), (2, 2, 2), + (0, 0, 0), (0, 0, 1), (2, 2, 2), + ]) + else: + raise ValueError(f"'how' argument must be either 'box' or 'axes', but got {how}") + + xyz = cell_data.xyz.values[verts[:,0],verts[:,1],verts[:,2]] + + cell_data = cell_data.assign({ + "xyz": (("point_index", "axis"), xyz), + "color": cell_style.get("color"), + "width": cell_style.get("width"), + "opacity": cell_style.get("opacity"), + }) + + return cell_data diff --git a/src/sisl/viz/processors/coords.py b/src/sisl/viz/processors/coords.py new file mode 100644 index 0000000000..3c77e1f206 --- /dev/null +++ b/src/sisl/viz/processors/coords.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +import re +from typing import Callable, Dict, Optional, Union + +import numpy as np +import numpy.typing as npt +from xarray import Dataset + +from sisl._lattice import cell_invert +from sisl.lattice import Lattice, LatticeChild +from sisl.utils.mathematics import fnorm + +from .axes import axes_cross_product, axis_direction, get_ax_title + +#from ...types import Axes, CellLike, Axis + +CoordsDataset = Dataset + +def projected_2Dcoords(cell: CellLike, xyz: npt.NDArray[np.float64], xaxis: Axis = "x", yaxis: Axis = "y") -> npt.NDArray[np.float64]: + """Moves the 3D 
positions of the atoms to a 2D subspace. + + In this way, we can plot the structure from the "point of view" that we want. + + NOTE: If xaxis/yaxis is one of {"a", "b", "c", "1", "2", "3"} the function doesn't + project the coordinates in the direction of the lattice vector. The fractional + coordinates, taking into consideration the three lattice vectors, are returned + instead. + + Parameters + ------------ + geometry: sisl.Geometry + the geometry for which you want the projected coords + xyz: array-like of shape (natoms, 3), optional + the 3D coordinates that we want to project. + otherwise they are taken from the geometry. + xaxis: {"x", "y", "z", "a", "b", "c"} or array-like of shape 3, optional + the direction to be displayed along the X axis. + yaxis: {"x", "y", "z", "a", "b", "c"} or array-like of shape 3, optional + the direction to be displayed along the Y axis. + + Returns + ---------- + np.ndarray of shape (2, natoms) + the 2D coordinates of the geometry, with all positions projected into the plane + defined by xaxis and yaxis. + """ + if isinstance(cell, (Lattice, LatticeChild)): + cell = cell.cell + + try: + all_lattice_vecs = len(set([xaxis, yaxis]).intersection(["a", "b", "c"])) == 2 + except: + # If set fails it is because xaxis/yaxis is unhashable, which means it + # is a numpy array + all_lattice_vecs = False + + if all_lattice_vecs: + coord_indices = ["abc".index(ax) for ax in (xaxis, yaxis)] + + icell = cell_invert(cell.astype(float)) + else: + # Get the directions that these axes represent + xaxis = axis_direction(xaxis, cell) + yaxis = axis_direction(yaxis, cell) + + fake_cell = np.array([xaxis, yaxis, np.cross(xaxis, yaxis)], dtype=np.float64) + icell = cell_invert(fake_cell) + coord_indices = [0, 1] + + return np.dot(xyz, icell.T)[..., coord_indices] + +def projected_1Dcoords(cell: CellLike, xyz: npt.NDArray[np.float64], axis: Axis = "x"): + """ + Moves the 3D positions of the atoms to a 1D subspace. 
+ + In this way, we can plot the structure from the "point of view" that we want. + + NOTE: If axis is one of {"a", "b", "c", "1", "2", "3"} the function doesn't + project the coordinates in the direction of the lattice vector. The fractional + coordinates, taking in consideration the three lattice vectors, are returned + instead. + + Parameters + ------------ + geometry: sisl.Geometry + the geometry for which you want the projected coords + xyz: array-like of shape (natoms, 3), optional + the 3D coordinates that we want to project. + otherwise they are taken from the geometry. + axis: {"x", "y", "z", "a", "b", "c", "1", "2", "3"} or array-like of shape 3, optional + the direction to be displayed along the X axis. + nsc: array-like of shape (3, ), optional + only used if `axis` is a lattice vector. It is used to rescale everything to the unit + cell lattice vectors, otherwise `GeometryPlot` doesn't play well with `GridPlot`. + + Returns + ---------- + np.ndarray of shape (natoms, ) + the 1D coordinates of the geometry, with all positions projected into the line + defined by axis. + """ + if isinstance(cell, (Lattice, LatticeChild)): + cell = cell.cell + + if isinstance(axis, str) and axis in ("a", "b", "c", "0", "1", "2"): + return projected_2Dcoords(cell, xyz, xaxis=axis, yaxis="a" if axis == "c" else "c")[..., 0] + + # Get the direction that the axis represents + axis = axis_direction(axis, cell) + + return xyz.dot(axis/fnorm(axis)) / fnorm(axis) + +def coords_depth( + coords_data: CoordsDataset, + axes: Axes +) -> npt.NDArray[np.float64]: + """Computes the depth of 3D points as projected in a 2D plane + + Parameters + ---------- + coords_data: CoordsDataset + The coordinates for which the depth is to be computed. + axes: Axes + The axes that define the plane where the coordinates are projected. 
+ """ + cell = _get_cell_from_dataset(coords_data=coords_data) + + depth_vector = axes_cross_product(axes[0], axes[1], cell) + depth = project_to_axes(coords_data, axes=[depth_vector]).x.values + + return depth + +def sphere( + center: npt.ArrayLike = [0, 0, 0], + r: float = 1, + vertices: int = 10 +) -> Dict[str, np.ndarray]: + """Computes a mesh defining a sphere.""" + phi, theta = np.mgrid[0.0:np.pi: 1j*vertices, 0.0:2.0*np.pi: 1j*vertices] + center = np.array(center) + + phi = np.ravel(phi) + theta = np.ravel(theta) + + x = center[0] + r*np.sin(phi)*np.cos(theta) + y = center[1] + r*np.sin(phi)*np.sin(theta) + z = center[2] + r*np.cos(phi) + + return {'x': x, 'y': y, 'z': z} + +def _get_cell_from_dataset(coords_data: CoordsDataset) -> npt.NDArray[np.float64]: + cell = coords_data.attrs.get("cell") + if cell is None: + if "lattice" in coords_data.attrs: + cell = coords_data.lattice.cell + else: + cell = coords_data.geometry.cell + + return cell + +def projected_1D_data(coords_data: CoordsDataset, axis: Axis = "x", dataaxis_1d: Union[Callable, npt.NDArray, None] = None) -> CoordsDataset: + cell = _get_cell_from_dataset(coords_data=coords_data) + + xyz = coords_data.xyz.values + + x = projected_1Dcoords(cell, xyz=xyz, axis=axis) + + dims = coords_data.xyz.dims[:-1] + + if dataaxis_1d is None: + y = np.zeros_like(x) + else: + if callable(dataaxis_1d): + y = dataaxis_1d(x) + elif isinstance(dataaxis_1d, (int, float)): + y = np.full_like(x, dataaxis_1d) + else: + y = dataaxis_1d + + coords_data = coords_data.assign(x=(dims, x), y=(dims, y)) + + return coords_data + +def projected_2D_data(coords_data: CoordsDataset, xaxis: Axis = "x", yaxis: Axis = "y", sort_by_depth: bool = False) -> CoordsDataset: + cell = _get_cell_from_dataset(coords_data=coords_data) + + xyz = coords_data.xyz.values + + xy = projected_2Dcoords(cell, xyz, xaxis=xaxis, yaxis=yaxis) + + x, y = xy[..., 0], xy[..., 1] + dims = coords_data.xyz.dims[:-1] + + coords_data = coords_data.assign(x=(dims, x), 
y=(dims, y)) + + coords_data = coords_data.assign( + {"depth": (dims, coords_depth(coords_data, [xaxis, yaxis]).data)} + ) + if sort_by_depth: + coords_data = coords_data.sortby("depth") + + return coords_data + +def projected_3D_data(coords_data: CoordsDataset) -> CoordsDataset: + x, y, z = np.moveaxis(coords_data.xyz.values, -1, 0) + dims = coords_data.xyz.dims[:-1] + + coords_data = coords_data.assign(x=(dims, x), y=(dims, y), z=(dims, z)) + + return coords_data + +def project_to_axes( + coords_data: CoordsDataset, axes: Axes, + dataaxis_1d: Optional[Union[npt.ArrayLike, Callable]] = None, + sort_by_depth: bool = False, + cartesian_units: str = "Ang" +) -> CoordsDataset: + ndim = len(axes) + if ndim == 3: + xaxis, yaxis, zaxis = axes + coords_data = projected_3D_data(coords_data) + elif ndim == 2: + xaxis, yaxis = axes + coords_data = projected_2D_data(coords_data, xaxis=xaxis, yaxis=yaxis, sort_by_depth=sort_by_depth) + elif ndim == 1: + xaxis = axes[0] + yaxis = dataaxis_1d + coords_data = projected_1D_data(coords_data, axis=xaxis, dataaxis_1d=dataaxis_1d) + + plot_axes = ["x", "y", "z"][:ndim] + + for ax, plot_ax in zip(axes, plot_axes): + coords_data[plot_ax].attrs["axis"] = { + "title": get_ax_title(ax, cartesian_units=cartesian_units), + } + + coords_data.attrs['ndim'] = ndim + + return coords_data diff --git a/src/sisl/viz/processors/data.py b/src/sisl/viz/processors/data.py new file mode 100644 index 0000000000..ef33150285 --- /dev/null +++ b/src/sisl/viz/processors/data.py @@ -0,0 +1,25 @@ +from typing import Type, TypeVar + +from ..data import Data + +DataInstance = TypeVar("DataInstance", bound=Data) + +def accept_data(data: DataInstance, cls: Type[Data], check: bool = True) -> DataInstance: + + if not isinstance(data, cls): + raise TypeError(f"Data must be of type {cls.__name__} and was {type(data).__name__}") + + if check: + data.sanity_check() + + return data + +def extract_data(data: Data, cls: Type[Data], check: bool = True): + + if not 
isinstance(data, cls): + raise TypeError(f"Data must be of type {cls.__name__} and was {type(data).__name__}") + + if check: + data.sanity_check() + + return data._data \ No newline at end of file diff --git a/src/sisl/viz/processors/eigenstate.py b/src/sisl/viz/processors/eigenstate.py new file mode 100644 index 0000000000..942d4b16da --- /dev/null +++ b/src/sisl/viz/processors/eigenstate.py @@ -0,0 +1,151 @@ +from typing import Optional, Tuple, Union + +import numpy as np + +import sisl + + +def get_eigenstate(eigenstate: sisl.EigenstateElectron, i: int) -> sisl.EigenstateElectron: + """Gets the i-th wavefunction from the eigenstate. + + It takes into account if the info dictionary has an "index" key, which + might be present for example if the eigenstate object does not contain + the full set of wavefunctions, to indicate which wavefunctions are + present. + + Parameters + ---------- + eigenstate : sisl.EigenstateElectron + The eigenstate from which to extract the wavefunction. + i : int + The index of the wavefunction to extract. + """ + + if "index" in eigenstate.info: + wf_i = np.nonzero(eigenstate.info["index"] == i)[0] + if len(wf_i) == 0: + raise ValueError(f"Wavefunction with index {i} is not present in the eigenstate. Available indices: {eigenstate.info['index']}.") + wf_i = wf_i[0] + else: + max_index = eigenstate.shape[0] + if i > max_index: + raise ValueError(f"Wavefunction with index {i} is not present in the eigenstate. Available range: [0, {max_index}].") + wf_i = i + + return eigenstate[wf_i] + +def eigenstate_geometry(eigenstate: sisl.EigenstateElectron, geometry: Optional[sisl.Geometry] = None) -> Union[sisl.Geometry, None]: + """Returns the geometry associated with the eigenstate. + + Parameters + ---------- + eigenstate : sisl.EigenstateElectron + The eigenstate from which to extract the geometry. + geometry : sisl.Geometry, optional + If provided, this geometry is returned instead of the one associated. 
This is + a way to force a given geometry when using this function. + """ + if geometry is None: + geometry = getattr(eigenstate, "parent", None) + if geometry is not None and not isinstance(geometry, sisl.Geometry): + geometry = getattr(geometry, "geometry", None) + + return geometry + +def tile_if_k(geometry: sisl.Geometry, nsc: Tuple[int, int, int], eigenstate: sisl.EigenstateElectron) -> sisl.Geometry: + """Tiles the geometry if the eigenstate does not correspond to gamma. + + If we are calculating the wavefunction for any point other than gamma, + the periodicity of the WF will be bigger than the cell. Therefore, if + the user wants to see more than the unit cell, we need to generate the + wavefunction for all the supercell. + + Parameters + ---------- + geometry : sisl.Geometry + The geometry for which the wavefunction was calculated. + nsc : Tuple[int, int, int] + The number of supercells that are to be displayed in each direction. + eigenstate : sisl.EigenstateElectron + The eigenstate for which the wavefunction was calculated. + """ + + tiled_geometry = geometry + + k = eigenstate.info.get("k", (1, 1, 1) if np.iscomplexobj(eigenstate.state) else (0, 0, 0)) + + for ax, sc_i in enumerate(nsc): + if k[ax] != 0: + tiled_geometry = tiled_geometry.tile(sc_i, ax) + + return tiled_geometry + +def get_grid_nsc(nsc: Tuple[int, int, int], eigenstate: sisl.EigenstateElectron) -> Tuple[int, int, int]: + """Returns the supercell to display once the geometry is tiled. + + The geometry must be tiled if the eigenstate is not calculated at gamma, + as done by `tile_if_k`. This function returns the number of supercells + to display after that tiling. + + Parameters + ---------- + nsc : Tuple[int, int, int] + The number of supercells to be display in each direction. + eigenstate : sisl.EigenstateElectron + The eigenstate for which the wavefunction was calculated. 
+ """ + k = eigenstate.info.get("k", (1, 1, 1) if np.iscomplexobj(eigenstate.state) else (0, 0, 0)) + + return tuple(nx if kx == 0 else 1 for nx, kx in zip(nsc, k)) + +def create_wf_grid(eigenstate: sisl.EigenstateElectron, grid_prec: float = 0.2, grid: Optional[sisl.Grid] = None, geometry: Optional[sisl.Geometry] = None) -> sisl.Grid: + """Creates a grid to display the wavefunction. + + Parameters + ---------- + eigenstate : sisl.EigenstateElectron + The eigenstate for which the wavefunction was calculated. The function uses + + grid_prec : float, optional + The precision of the grid. The grid will be created with a spacing of + `grid_prec` Angstroms. + grid : sisl.Grid, optional + If provided, this grid is returned instead of the one created. This is + a way to force a given grid when using this function. + geometry : sisl.Geometry, optional + Geometry that will be associated to the grid. Required unless the grid + is provided. + """ + if grid is None: + grid = sisl.Grid(grid_prec, geometry=geometry, dtype=eigenstate.state.dtype) + + return grid + +def project_wavefunction(eigenstate: sisl.EigenstateElectron, grid_prec: float = 0.2, grid: Optional[sisl.Grid] = None, geometry: Optional[sisl.Geometry] = None) -> sisl.Grid: + """Projects the wavefunction from an eigenstate into a grid. + + Parameters + ---------- + eigenstate : sisl.EigenstateElectron + The eigenstate for which the wavefunction was calculated. + grid_prec : float, optional + The precision of the grid. The grid will be created with a spacing of + `grid_prec` Angstroms. + grid : sisl.Grid, optional + If provided, the wavefunction is inserted into this grid instead of creating + a new one. + geometry : sisl.Geometry, optional + Geometry that will be associated to the grid. Required unless the grid + is provided. 
+ """ + grid = create_wf_grid(eigenstate, grid_prec=grid_prec, grid=grid, geometry=geometry) + + # Ensure we are dealing with the R gauge + eigenstate.change_gauge('R') + + # Finally, insert the wavefunction values into the grid. + sisl.physics.electron.wavefunction( + eigenstate.state, grid, geometry=geometry, spinor=0, + ) + + return grid \ No newline at end of file diff --git a/src/sisl/viz/processors/geometry.py b/src/sisl/viz/processors/geometry.py new file mode 100644 index 0000000000..89431c01ac --- /dev/null +++ b/src/sisl/viz/processors/geometry.py @@ -0,0 +1,501 @@ +from __future__ import annotations + +import itertools +from dataclasses import asdict +from typing import Any, List, Optional, Sequence, Tuple, TypedDict, Union + +import numpy as np +import numpy.typing as npt +from xarray import Dataset + +from sisl import BrillouinZone, Geometry, PeriodicTable +from sisl.messages import warn +from sisl.typing import AtomsArgument +from sisl.utils.mathematics import fnorm +from sisl.viz.types import AtomArrowSpec + +from ..data_sources.atom_data import AtomDefaultColors, AtomIsGhost, AtomPeriodicTable +from .coords import CoordsDataset, projected_1Dcoords, projected_2Dcoords + +#from ...types import AtomsArgument, GeometryLike, PathLike + +GeometryDataset = CoordsDataset +AtomsDataset = GeometryDataset +BondsDataset = GeometryDataset + +# class GeometryData(DataSource): +# pass + +# @GeometryData.from_func +# def geometry_from_file(file: PathLike) -> Geometry: +# return Geometry.new(file) + +# @GeometryData.from_func +# def geometry_from_obj(obj: GeometryLike) -> Geometry: +# return Geometry.new(obj) + +def tile_geometry(geometry: Geometry, nsc: Tuple[int, int, int]) -> Geometry: + """Tiles a geometry along the three lattice vectors. + + Parameters + ----------- + geometry: sisl.Geometry + the geometry to be tiled. + nsc: Tuple[int, int, int] + the number of repetitions along each lattice vector. 
+ """ + + tiled_geometry = geometry.copy() + for ax, reps in enumerate(nsc): + tiled_geometry = tiled_geometry.tile(reps, ax) + + return tiled_geometry + +def find_all_bonds(geometry: Geometry, tol: float = 0.2) -> BondsDataset: + """ + Finds all bonds present in a geometry. + + Parameters + ----------- + geometry: sisl.Geometry + the structure where the bonds should be found. + tol: float + the fraction that the distance between atoms is allowed to differ from + the "standard" in order to be considered a bond. + + Return + --------- + np.ndarray of shape (nbonds, 2) + each item of the array contains the 2 indices of the atoms that participate in the + bond. + """ + pt = PeriodicTable() + + bonds = [] + for at in geometry: + neighs: npt.NDArray[np.int32] = geometry.close(at, R=[0.1, 3])[-1] + + for neigh in neighs[neighs > at]: + summed_radius = pt.radius([abs(geometry.atoms[at].Z), abs(geometry.atoms[neigh % geometry.na].Z)]).sum() + bond_thresh = (1+tol) * summed_radius + if bond_thresh > fnorm(geometry[neigh] - geometry[at]): + bonds.append([at, neigh]) + + if len(bonds) == 0: + bonds = np.empty((0, 2), dtype=np.int64) + + return Dataset({ + "bonds": (("bond_index", "bond_atom"), np.array(bonds, dtype=np.int64)) + }, + coords={"bond_index": np.arange(len(bonds)), "bond_atom": [0, 1]}, + attrs={"geometry": geometry} + ) + +def get_atoms_bonds(bonds: npt.NDArray[np.int32], atoms: npt.ArrayLike, ret_mask: bool = False) -> npt.NDArray[Union[np.float64, np.bool8]]: + """Gets the bonds where the given atoms are involved. + + Parameters + ----------- + bonds: np.ndarray of shape (nbonds, 2) + Pairs of indices of atoms that are bonded. + atoms: np.ndarray of shape (natoms,) + Indices of the atoms for which we want to keep the bonds. 
+ """ + # For each bond, we check if one of the desired atoms is involved + mask = np.isin(bonds, atoms).any(axis=-1) + if ret_mask: + return mask + + return bonds[mask] + +def sanitize_atoms(geometry: Geometry, atoms: AtomsArgument = None) -> npt.NDArray[np.int32]: + """Sanitizes the atoms argument to a np.ndarray of shape (natoms,). + + This is the same as `geometry._sanitize_atoms` but ensuring that the + result is a numpy array of 1 dimension. + + Parameters + ----------- + geometry: sisl.Geometry + geometry that will sanitize the atoms + atoms: AtomsArgument + anything that `Geometry` can sanitize. + """ + atoms = geometry._sanitize_atoms(atoms) + return np.atleast_1d(atoms) + +def tile_data_sc(geometry_data: GeometryDataset, nsc: Tuple[int, int, int] = (1, 1, 1)) -> GeometryDataset: + """Tiles coordinates from unit cell to a supercell. + + Parameters + ----------- + geometry_data: GeometryDataset + the dataset containing the coordinates to be tiled. + nsc: np.ndarray of shape (3,) + the number of repetitions along each lattice vector. + """ + # Get the total number of supercells + total_sc = np.prod(nsc) + + xyz_shape = geometry_data.xyz.shape + + # Create a fake geometry + fake_geom = Geometry(xyz=geometry_data.xyz.values.reshape(-1, 3), + lattice=geometry_data.geometry.lattice.copy(), + atoms=1 + ) + + sc_offs = np.array(list(itertools.product(*[range(n) for n in nsc]))) + + sc_xyz = np.array([ + fake_geom.axyz(isc=sc_off) for sc_off in sc_offs + ]).reshape((total_sc, *xyz_shape)) + + # Build the new dataset + sc_atoms = geometry_data.assign({"xyz": (("isc", *geometry_data.xyz.dims), sc_xyz)}) + sc_atoms = sc_atoms.assign_coords(isc=range(total_sc)) + + return sc_atoms + +def stack_sc_data(geometry_data: GeometryDataset, newname: str, dims: Sequence[str]) -> GeometryDataset: + """Stacks the supercell coordinate with others. + + Parameters + ----------- + geometry_data: GeometryDataset + the dataset for which we want to stack the supercell coordinates. 
+ newname: str + """ + + return geometry_data.stack(**{newname: ["isc", *dims]}).transpose(newname, ...) + +class AtomsStyleSpec(TypedDict): + color: Any + size: Any + opacity: Any + vertices: Any + +def parse_atoms_style(geometry: Geometry, atoms_style: Sequence[AtomsStyleSpec], scale: float = 1.) -> AtomsDataset: + """Parses atom style specifications to a dataset of styles. + + Parameters + ----------- + geometry: sisl.Geometry + the geometry for which the styles are parsed. + atoms_style: Sequence[AtomsStyleSpec] + the styles to be parsed. + scale: float + the scale to be applied to the size of the atoms. + """ + if isinstance(atoms_style, dict): + atoms_style = [atoms_style] + + # Add the default styles first + atoms_style = [ + { + "color": AtomDefaultColors(), + "size": AtomPeriodicTable(what="radius"), + "opacity": AtomIsGhost(fill_true=0.4, fill_false=1.), + "vertices": 15, + }, + *atoms_style + ] + + def _tile_if_needed(atoms, spec): + """Function that tiles an array style specification. + + It does so if the specification needs to be applied to more atoms + than items are in the array.""" + if isinstance(spec, (tuple, list, np.ndarray)): + n_ats = len(atoms) + n_spec = len(spec) + if n_ats != n_spec and n_ats % n_spec == 0: + spec = np.tile(spec, n_ats // n_spec) + return spec + + # Initialize the styles. + parsed_atoms_style = { + "color": np.empty((geometry.na, ), dtype=object), + "size": np.empty((geometry.na, ), dtype=float), + "vertices": np.empty((geometry.na, ), dtype=int), + "opacity": np.empty((geometry.na), dtype=float), + } + + # Go specification by specification and apply the styles + # to the corresponding atoms. 
+ for style_spec in atoms_style: + atoms = geometry._sanitize_atoms(style_spec.get("atoms")) + for key in parsed_atoms_style: + if style_spec.get(key) is not None: + style = style_spec[key] + + if callable(style): + style = style(geometry=geometry, atoms=atoms) + + parsed_atoms_style[key][atoms] = _tile_if_needed(atoms, style) + + # Apply the scale + parsed_atoms_style['size'] = parsed_atoms_style['size'] * scale + # Convert colors to numbers if possible + try: + parsed_atoms_style['color'] = parsed_atoms_style['color'].astype(float) + except: + pass + + # Add coordinates to the values according to their unique dimensionality. + data_vars = {} + for k, value in parsed_atoms_style.items(): + if (k != "color" or value.dtype not in (float, int)): + unique = np.unique(value) + if len(unique) == 1: + data_vars[k] = unique[0] + continue + + data_vars[k] = ("atom", value) + + return Dataset( + data_vars, + coords={"atom": range(geometry.na)}, + attrs={"geometry": geometry}, + ) + +def sanitize_arrows(geometry: Geometry, arrows: Sequence[AtomArrowSpec], atoms: AtomsArgument, ndim: int, axes: Sequence[str]) -> List[dict]: + """Sanitizes a list of arrow specifications. + + Each arrow specification in the output has the atoms sanitized and + the data with the shape (natoms, ndim). + + Parameters + ---------- + geometry: sisl.Geometry + the geometry for which the arrows are sanitized. + arrows: Sequence[AtomArrowSpec] + unsanitized arrow specifications. + atoms: AtomsArgument + atoms for which we want the data. This means that data + will be filtered to only contain the atoms in this argument. + ndim: int + dimensionality of the space into which arrows must be projected. + axes: Sequence[str] + Axes onto which the arrows must be projected. 
+ """ + atoms: np.ndarray = geometry._sanitize_atoms(atoms) + + def _sanitize_spec(arrow_spec): + arrow_spec = AtomArrowSpec(**arrow_spec) + arrow_spec = asdict(arrow_spec) + + arrow_spec["atoms"] = np.atleast_1d(geometry._sanitize_atoms(arrow_spec["atoms"])) + arrow_atoms = arrow_spec["atoms"] + + not_displayed = set(arrow_atoms) - set(atoms) + if not_displayed: + warn(f"Arrow data for atoms {not_displayed} will not be displayed because these atoms are not displayed.") + if set(atoms) == set(atoms) - set(arrow_atoms): + # Then it makes no sense to store arrows, as nothing will be drawn + return None + + arrow_data = np.full((geometry.na, ndim), np.nan, dtype=np.float64) + provided_data = np.array(arrow_spec["data"]) + + # Get the projected directions if we are not in 3D. + if ndim == 1: + provided_data = projected_1Dcoords(geometry, provided_data, axis=axes[0]) + provided_data = np.expand_dims(provided_data, axis=-1) + elif ndim == 2: + provided_data = projected_2Dcoords(geometry, provided_data, xaxis=axes[0], yaxis=axes[1]) + + arrow_data[arrow_atoms] = provided_data + arrow_spec["data"] = arrow_data[atoms] + + #arrow_spec["data"] = self._tile_atomic_data(arrow_spec["data"]) + + return arrow_spec + + if isinstance(arrows, dict): + if arrows == {}: + arrows = [] + else: + arrows = [arrows] + + san_arrows = [_sanitize_spec(arrow_spec) for arrow_spec in arrows] + + return [arrow_spec for arrow_spec in san_arrows if arrow_spec is not None] + +def add_xyz_to_dataset(dataset: AtomsDataset) -> AtomsDataset: + """Adds the xyz data variable to a dataset with associated geometry. + + The new xyz data variable contains the coordinates of the atoms. + + Parameters + ----------- + dataset: AtomsDataset + the dataset to be augmented with xyz data. 
+ """ + geometry = dataset.attrs['geometry'] + + xyz_ds = Dataset({"xyz": (("atom", "axis"), geometry.xyz)}, coords={"axis": [0,1,2]}, attrs={"geometry": geometry}) + + return xyz_ds.merge(dataset, combine_attrs="no_conflicts") + +class BondsStyleSpec(TypedDict): + color: Any + width: Any + opacity: Any + +def style_bonds(bonds_data: BondsDataset, bonds_style: BondsStyleSpec, scale: float = 1.) -> BondsDataset: + """Adds styles to a bonds dataset. + + Parameters + ----------- + bonds_data: BondsDataset + the bonds that need to be styled. + This can come from the `find_all_bonds` function. + bonds_style: Sequence[BondsStyleSpec] + the styles to be parsed. + scale: float + the scale to be applied to the width of the bonds. + """ + geometry = bonds_data.geometry + + nbonds = bonds_data.bonds.shape[0] + + # Add the default styles first + bonds_styles: Sequence[BondsStyleSpec] = [ + { + "color": "gray", + "width": 1, + "opacity": 1, + }, + bonds_style + ] + + # Initialize the styles. + # Potentially bond styles could have two styles, one for each halve. + parsed_bonds_style = { + "color": np.empty((nbonds, ), dtype=object), + "width": np.empty((nbonds, ), dtype=float), + "opacity": np.empty((nbonds, ), dtype=float), + } + + # Go specification by specification and apply the styles + # to the corresponding bonds. Note that we still have no way of + # selecting bonds, so for now we just apply the styles to all bonds. 
+ for style_spec in bonds_styles: + for key in parsed_bonds_style: + style = style_spec.get(key) + if style is None: + continue + if callable(style): + style = style(geometry=geometry, bonds=bonds_data.bonds) + + parsed_bonds_style[key][:] = style + + + # Apply the scale + parsed_bonds_style['width'] = parsed_bonds_style['width'] * scale + # Convert colors to float datatype if possible + try: + parsed_bonds_style['color'] = parsed_bonds_style['color'].astype(float) + except ValueError: + pass + + # Add coordinates to the values according to their unique dimensionality. + data_vars = {} + for k, value in parsed_bonds_style.items(): + if (k != "color" or value.dtype not in (float, int)): + unique = np.unique(value) + if len(unique) == 1: + data_vars[k] = unique[0] + continue + + data_vars[k] = ("bond_index", value) + + return bonds_data.assign(data_vars) + +def add_xyz_to_bonds_dataset(bonds_data: BondsDataset) -> BondsDataset: + """Adds the coordinates of the bonds endpoints to a bonds dataset. + + Parameters + ----------- + bonds_data: BondsDataset + the bonds dataset to be augmented with xyz data. + """ + geometry = bonds_data.attrs['geometry'] + + def _bonds_xyz(ds): + bonds_shape = ds.bonds.shape + bonds_xyz = geometry[ds.bonds.values.reshape(-1)].reshape((*bonds_shape, 3)) + return (("bond_index", "bond_atom", "axis"), bonds_xyz) + + return bonds_data.assign({"xyz": _bonds_xyz}) + +def sanitize_bonds_selection(bonds_data: BondsDataset, atoms: Optional[npt.NDArray[np.int32]] = None, bind_bonds_to_ats: bool = False, show_bonds: bool = True) -> Union[np.ndarray, None]: + """Sanitizes bonds selection, unifying multiple parameters into a single value + + Parameters + ----------- + bonds_data: BondsDataset + the bonds dataset containing the already computed bonds. + atoms: np.ndarray of shape (natoms,) + the atoms for which we want to keep the bonds. 
+ bind_bonds_to_ats: bool + if True, the bonds will be bound to the atoms, + so that if an atom is not displayed, its bonds + will not be displayed either. + show_bonds: bool + if False, no bonds will be displayed. + """ + if not show_bonds: + return np.array([], dtype=np.int64) + elif bind_bonds_to_ats and atoms is not None: + return get_atoms_bonds(bonds_data.bonds, atoms, ret_mask=True) + else: + return None + +def bonds_to_lines(bonds_data: BondsDataset, points_per_bond: int = 2) -> BondsDataset: + """Computes intermediate points between the endpoints of the bonds by interpolation. + + Bonds are concatenated into a single dimension "point index", and NaNs + are added between bonds. + + Parameters + ----------- + bonds_data: BondsDataset + the bonds dataset containing the endpoints of the bonds. + points_per_bond: int + the number of points to be computed between the endpoints, + including the endpoints. + """ + if points_per_bond > 2: + bonds_data = bonds_data.interp(bond_atom=np.linspace(0, 1, points_per_bond)) + + bonds_data = bonds_data.reindex({"bond_atom": [*bonds_data.bond_atom.values, 2]}).stack(point_index=bonds_data.xyz.dims[:-1]) + + return bonds_data + +def sites_obj_to_geometry(sites_obj: BrillouinZone): + """Converts anything that contains sites into a geometry. + + Possible conversions: + - BrillouinZone object to geometry, kpoints to atoms. + + Parameters + ----------- + sites_obj + the object to be converted. 
+ """ + + if isinstance(sites_obj, BrillouinZone): + return Geometry(sites_obj.k.dot(sites_obj.rcell), lattice=sites_obj.rcell) + else: + raise ValueError(f"Cannot convert {sites_obj.__class__.__name__} to a geometry.") + +def get_sites_units(sites_obj: BrillouinZone): + """Units of space for an object that is to be converted into a geometry""" + if isinstance(sites_obj, BrillouinZone): + return "1/Ang" + else: + return "" + + diff --git a/src/sisl/viz/processors/grid.py b/src/sisl/viz/processors/grid.py new file mode 100644 index 0000000000..f69a31b117 --- /dev/null +++ b/src/sisl/viz/processors/grid.py @@ -0,0 +1,583 @@ +from __future__ import annotations + +from typing import Callable, List, Literal, Optional, Sequence, Tuple, Union + +import numpy as np +import numpy.typing as npt +from scipy.ndimage import affine_transform +from xarray import DataArray + +import sisl +from sisl import Geometry, Grid +from sisl import _array as _a +from sisl._lattice import cell_invert + +from .cell import infer_cell_axes, is_1D_cartesian, is_cartesian_unordered + +#from ...types import Axis, PathLike +#from ..data_sources import DataSource + +def get_grid_representation(grid: Grid, represent: Literal['real', 'imag', 'mod', 'phase', 'rad_phase', 'deg_phase']) -> Grid: + """Returns a representation of the grid + + Parameters + ------------ + grid: sisl.Grid + the grid for which we want return + represent: {"real", "imag", "mod", "phase", "deg_phase", "rad_phase"} + the type of representation. 
"phase" is equivalent to "rad_phase" + + Returns + ------------ + sisl.Grid + """ + def _func(values: npt.NDArray[Union[np.int_, np.float_, np.complex_]]) -> npt.NDArray: + if represent == 'real': + new_values = values.real + elif represent == 'imag': + new_values = values.imag + elif represent == 'mod': + new_values = np.absolute(values) + elif represent in ['phase', 'rad_phase', 'deg_phase']: + new_values = np.angle(values, deg=represent.startswith("deg")) + else: + raise ValueError(f"'{represent}' is not a valid value for the `represent` argument") + + return new_values + + return grid.apply(_func) + +def tile_grid(grid: Grid, nsc: Tuple[int, int, int] = (1, 1, 1)) -> Grid: + """Tiles the grid""" + for ax, reps in enumerate(nsc): + grid = grid.tile(reps, ax) + return grid + +def transform_grid_cell(grid: Grid, cell: npt.NDArray[np.float_] = np.eye(3), output_shape: Optional[Tuple[int, int, int]] = None, mode: str = "constant", order: int = 1, **kwargs) -> Grid: + """Applies a linear transformation to the grid to get it relative to an arbitrary cell. + + This method can be used, for example to get the values of the grid with respect to + the standard basis, so that you can easily visualize it or overlap it with other grids + (e.g. to perform integrals). + + Parameters + ----------- + cell: array-like of shape (3,3) + these cell represent the directions that you want to use as references for + the new grid. + + The length of the axes does not have any effect! They will be rescaled to create + the minimum bounding box necessary to accomodate the unit cell. + output_shape: array-like of int of shape (3,), optional + the shape of the final output. If not provided, the current shape of the grid + will be used. + + Notice however that if the transformation applies a big shear to the image (grid) + you will probably need to have a bigger output_shape. + mode: str, optional + determines how to handle borders. See scipy docs for more info on the possible values. 
+ order : int 0-5, optional + the order of the spline interpolation to calculate the values (since we are applying + a transformation, we don't actually have values for the new locations and we need to + interpolate them) + 1 means linear, 2 quadratic, etc... + **kwargs: + the rest of keyword arguments are passed directly to `scipy.ndimage.affine_transform` + + See also + ---------- + scipy.ndimage.affine_transform : method used to apply the linear transformation. + """ + # Take the current shape of the grid if no output shape was provided + if output_shape is None: + output_shape = grid.shape + + # Make sure the cell has type float + cell = np.asarray(cell, dtype=float) + + # Get the current cell in coordinates of the destination axes + inv_cell = cell_invert(cell).T + projected_cell = grid.cell.dot(inv_cell) + + # From that, infer how long will the bounding box of the cell be + lengths = abs(projected_cell).sum(axis=0) + + # Create the transformation matrix. Since we want to control the shape + # of the output, we can not use grid.dcell directly, we need to modify it. + scales = output_shape / lengths + forward_t = (grid.dcell.dot(inv_cell)*scales).T + + # Scipy's affine transform asks for the inverse transformation matrix, to + # map from output pixels to input pixels. By taking the inverse of our + # transformation matrix, we get exactly that. + tr = cell_invert(forward_t).T + + # Calculate the offset of the image so that all points of the grid "fall" inside + # the output array. 
+ # For this we just calculate the centers of the input and output images + center_input = 0.5 * (_a.asarrayd(grid.shape) - 1) + center_output = 0.5 * (_a.asarrayd(output_shape) - 1) + + # And then make sure that the input center that is interpolated from the output + # falls in the actual input's center + offset = center_input - tr.dot(center_output) + + # We pass all the parameters to scipy's affine_transform. Note that we forward + # the user-requested spline `order` (previously hard-coded to 1, which silently + # ignored the documented `order` argument). + transformed_image = affine_transform(grid.grid, tr, order=order, offset=offset, + output_shape=output_shape, mode=mode, **kwargs) + + # Create a new grid with the new shape and the new cell (notice how the cell + # is rescaled from the input cell to fit the actual coordinates of the system) + new_grid = grid.__class__((1, 1, 1), lattice=cell*lengths.reshape(3, 1)) + new_grid.grid = transformed_image + new_grid.geometry = grid.geometry + #new_grid.origin = grid.origin + new_grid.dcell.dot(forward_t.dot(offset)) + + # Find the offset between the origin before and after the transformation + return new_grid #,new_grid.dcell.dot(forward_t.dot(offset)) + +def orthogonalize_grid(grid:Grid, interp: Tuple[int, int, int] = (1, 1, 1), mode: str = "constant", **kwargs) -> Grid: + """Transform grid cell to be orthogonal. + + Uses `transform_grid_cell`. + + Parameters + ----------- + grid: sisl.Grid + The grid to transform. + interp: array-like of int of shape (3,), optional + Number of times that the grid should be augmented for each + lattice vector. + mode: str, optional + determines how to handle borders. + See `transform_grid_cell` for more info on the possible values. 
+ **kwargs: + the rest of keyword arguments are passed directly to `transform_grid_cell` + """ + return transform_grid_cell( + grid, mode=mode, output_shape=tuple(interp[i] * grid.shape[i] for i in range(3)), cval=np.nan, **kwargs + ) + +def orthogonalize_grid_if_needed(grid: Grid, axes: Sequence[str], tol: float = 1e-3, + interp: Tuple[int, int, int] = (1, 1, 1), mode: str = "constant", **kwargs) -> Grid: + """Same as `orthogonalize_grid`, but first checks if it is really needed. + + Parameters + ----------- + grid: sisl.Grid + The grid to transform. + axes: list of str + axes that will be plotted. + tol: float, optional + tolerance to determine whether the grid should be transformed. + interp: array-like of int of shape (3,), optional + Number of times that the grid should be augmented for each + lattice vector. + mode: str, optional + determines how to handle borders. + See `transform_grid_cell` for more info on the possible values. + **kwargs: + the rest of keyword arguments are passed directly to `transform_grid_cell` + """ + + should_ortogonalize = should_transform_grid_cell_plotting(grid=grid, axes=axes, tol=tol) + + if should_ortogonalize: + grid = orthogonalize_grid(grid, interp=interp, mode=mode, **kwargs) + + return grid + +def apply_transform(grid: Grid, transform: Union[Callable, str]) -> Grid: + """applies a transformation to the grid. + + Parameters + ----------- + grid: sisl.Grid + The grid to transform. + transform: callable or str + The transformation to apply. If it is a string, it will be + interpreted as a numpy function unless it contains a dot, in + which case it will be interpreted as a path to a function. 
+ """ + if isinstance(transform, str): + # Since this may come from the GUI, there might be extra spaces + transform = transform.strip() + + # If is a string with no dots, we will assume it is a numpy function + if len(transform.split(".")) == 1: + transform = f"numpy.{transform}" + + return grid.apply(transform) + +def apply_transforms(grid: Grid, transforms: Sequence[Union[Callable, str]]) -> Grid: + """Applies multiple transformations sequentially + + Parameters + ----------- + grid: sisl.Grid + The grid to transform. + transforms: list of callable or str + The transformations to apply. If a transformation it is a string, it will be + interpreted as a numpy function unless it contains a dot, in which case it will + be interpreted as a path to a function. + """ + for transform in transforms: + grid = apply_transform(grid, transform) + return grid + +def reduce_grid(grid: Grid, reduce_method: Literal["average", "sum"], keep_axes: Sequence[int]) -> Grid: + """Reduces the grid along multiple axes + + Parameters + ----------- + grid: sisl.Grid + The grid to reduce. + reduce_method: {"average", "sum"} + The method to use to reduce the grid. + keep_axes: list of int + Lattice vectors to maintain (not reduce). + """ + # Reduce the dimensions that are not going to be displayed + for ax in [0, 1, 2]: + if ax not in keep_axes: + grid = getattr(grid, reduce_method)(ax) + + return grid + +def sub_grid( + grid: Grid, + x_range: Optional[Tuple[float, float]] = None, + y_range: Optional[Tuple[float, float]] = None, + z_range: Optional[Tuple[float, float]] = None, + cart_tol: float = 1e-3 +) -> Grid: + """Returns only the part of the grid that is within the specified ranges. + + Only works for cartesian dimensions that correspond to some lattice vector. For + example, if the grid is skewed in XY but not in Z, this function can sub along Z + but not along X or Y. + + If there's no point that coincides with the limits, the closest point will be + taken. 
This means that the returned grid might not be limited exactly by the bounds + provided. + + Parameters + ----------- + grid: sisl.Grid + The grid to sub. + x_range: tuple of float, optional + The range of the x coordinate. + y_range: tuple of float, optional + The range of the y coordinate. + z_range: tuple of float, optional + The range of the z coordinate. + cart_tol: float, optional + Tolerance to determine whether a dimension is cartesian or not. + """ + + cell = grid.lattice.cell + + # Get only the part of the grid that we need + ax_ranges = [x_range, y_range, z_range] + directions = ["x", "y", "z"] + for ax, (ax_range, direction) in enumerate(zip(ax_ranges, directions)): + if ax_range is not None: + + # Cartesian check + if not is_1D_cartesian(cell, direction, tol=cart_tol): + raise ValueError(f"Cannot sub grid along '{direction}', since there is no unique lattice vector that represents this direction. Cell: {cell}") + + # Find out which lattice vector represents the direction + lattice_ax = np.where(cell[:, ax] > cart_tol)[0][0] + + # Build an array with the limits + lims = np.zeros((2, 3)) + # If the cell was transformed, then we need to modify + # the range to get what the user wants. + lims[:, ax] = ax_range #+ self.offsets["cell_transform"][ax] - self.offsets["origin"][ax] + + # Get the indices of those points + indices = np.array([grid.index(lim) for lim in lims], dtype=int) + + # And finally get the subpart of the grid + grid = grid.sub(np.arange(indices[0, lattice_ax], indices[1, lattice_ax] + 1), lattice_ax) + + return grid + +def interpolate_grid(grid: Grid, interp: Tuple[int, int, int] = (1, 1, 1), force: bool = False) -> Grid: + """Interpolates the grid. + + It also makes sure that the grid is not interpolated over dimensions that only + contain one value, unless `force` is True. + + If the interpolation factors are all 1, the grid is returned unchanged. + + Parameters + ----------- + grid: sisl.Grid + The grid to interpolate. 
+ interp: array-like of int of shape (3,), optional + Number of times that the grid should be augmented for each + lattice vector. + force: bool, optional + Whether to force the interpolation over dimensions that only + contain one value. + """ + + grid_shape = np.array(grid.shape) + + interp_factors = np.array(interp) + if not force: + # No need to interpolate over dimensions that only contain one value. + interp_factors[grid_shape == 1] = 1 + + interp_factors = interp_factors * grid_shape + if (interp_factors != 1).any(): + grid = grid.interp(interp_factors.astype(int)) + + return grid + +def grid_geometry(grid: Grid, geometry: Optional[Geometry] = None) -> Union[Geometry, None]: + """Returns the geometry associated with the grid. + + Parameters + ----------- + grid: sisl.Grid + The grid for which we want to get the geometry. + geometry: sisl.Geometry, optional + If provided, this geometry will be returned instead of the one + associated with the grid. + """ + if geometry is None: + geometry = getattr(grid, "geometry", None) + + return geometry + +def should_transform_grid_cell_plotting(grid: Grid, axes: Sequence[str], tol: float = 1e-3) -> bool: + """Determines whether the grid should be transformed for plotting. + + It takes into account the axes that will be plotted and checks if the grid + is skewed in any of those directions. If it is, it will return True, meaning + that the grid should be transformed before plotting. + + Parameters + ----------- + grid: sisl.Grid + grid to check. + axes: list of str + axes that will be plotted. + """ + ndim = len(axes) + + # Determine whether we should transform the grid to cartesian axes. This will be needed + # if the grid is skewed. However, it is never needed for the 3D representation, since we + # compute the coordinates of each point in the isosurface, and we don't need to reduce the + # grid. 
+ should_orthogonalize = not is_cartesian_unordered(grid, tol=tol) and len(axes) < 3 + # We also don't need to orthogonalize if cartesian coordinates are not requested + # (this would mean that axes is a combination of "a", "b" and "c") + should_orthogonalize = should_orthogonalize and bool(set(axes).intersection(["x", "y", "z"])) + + if should_orthogonalize and ndim == 1: + # In 1D representations, even if the cell is skewed, we might not need to transform. + # An example of a cell that we don't need to transform is: + # a = [1, 1, 0], b = [1, -1, 0], c = [0, 0, 1] + # If the user wants to display the values on the z coordinate, we can safely reduce the + # first two axes, as they don't contribute in the Z direction. Also, it is required that + # "c" doesn't contribute to any of the other two directions. + should_orthogonalize &= not is_1D_cartesian(grid, axes[0], tol=tol) + + return should_orthogonalize + +def get_grid_axes(grid: Grid, axes: Sequence[str]) -> List[int]: + """Returns the indices of the lattice vectors that correspond to the axes. + + If axes is of length 3 (i.e. a 3D view), this function always returns [0, 1, 2] + regardless of what the axes are. + + Parameters + ----------- + grid: sisl.Grid + The grid for which we want to get the axes. + axes: list of str + axes that will be plotted. Either cartesian or "a", "b", "c". + """ + + ndim = len(axes) + + if ndim < 3: + grid_axes = infer_cell_axes(grid, axes) + elif ndim == 3: + grid_axes = [0, 1, 2] + else: + raise ValueError(f"Invalid number of axes: {ndim}") + + return grid_axes + +def get_ax_vals(grid: Grid, ax: Literal[0,1,2,"a","b","c","x","y","z"], nsc: Tuple[int, int, int]) -> npt.NDArray[np.float_]: + """Returns the values of a given axis on all grid points. + + These can be used for example as axes ticks on a plot. + + Parameters + ---------- + grid: sisl.Grid + The grid for which we want to get the axes values. 
+ ax: {"x", "y", "z", "a", "b", "c", 0, 1, 2} + The axis for which we want the values. + nsc: array-like of int of shape (3,) + Number of times that the grid has been tiled in each direction, so that + if a fractional axis is requested, the values are correct. + """ + if isinstance(ax, int) or ax in ("a", "b", "c"): + ax = {"a": 0, "b": 1, "c": 2}.get(ax, ax) + ax_vals = np.linspace(0, nsc[ax], grid.shape[ax]) + else: + offset = grid.origin + + ax = {"x": 0, "y": 1, "z": 2}[ax] + + ax_vals = np.arange(0, grid.cell[ax, ax], grid.dcell[ax, ax]) + get_offset(grid, ax) + + if len(ax_vals) == grid.shape[ax] + 1: + ax_vals = ax_vals[:-1] + + return ax_vals + +def get_offset(grid: Grid, ax: Literal[0,1,2,"a","b","c","x","y","z"]) -> float: + """Returns the offset of the grid along a certain axis. + + Parameters + ----------- + grid: sisl.Grid + The grid for which we want to get the offset. + ax: {"x", "y", "z", "a", "b", "c", 0, 1, 2} + The axis for which we want the offset. + """ + + if isinstance(ax, int) or ax in ("a", "b", "c"): + return 0 + else: + coord_index = "xyz".index(ax) + return grid.origin[coord_index] + +GridDataArray = DataArray + +def grid_to_dataarray(grid: Grid, axes: Sequence[str], grid_axes: Sequence[int], nsc: Tuple[int, int, int]) -> GridDataArray: + + transpose_grid_axes = [*grid_axes] + for ax in (0, 1, 2): + if ax not in transpose_grid_axes: + transpose_grid_axes.append(ax) + + values = np.squeeze(grid.grid.transpose(*transpose_grid_axes)) + + arr = DataArray( + values, + coords=[ + (k, get_ax_vals(grid, ax, nsc=nsc)) + for k, ax in zip(["x", "y", "z"], axes) + ] + ) + + arr.attrs['grid'] = grid + + return arr + +def get_isos(data: GridDataArray, isos: Sequence[dict]) -> List[dict]: + """Gets the iso surfaces or isocontours of an array of data. + + Parameters + ----------- + data: DataArray + The data for which we want to get the iso surfaces. + isos: list of dict + List of isosurface specifications. 
+ """ + from skimage.measure import find_contours + + #values = data['values'].values + values = data.values + isos_to_draw = [] + + # Get the dimensionality of the data + ndim = values.ndim + + if len(isos) > 0 or ndim == 3: + minval = np.nanmin(values) + maxval = np.nanmax(values) + + # Prepare things for each possible dimensionality + if ndim == 1: + # For now, we don't calculate 1D "isopoints" + return [] + elif ndim == 2: + # Get the partition size + dx = data.x[1] - data.x[0] + dy = data.y[1] - data.y[0] + + # Function to get the coordinates from indices + def _indices_to_2Dspace(contour_coords): + return contour_coords.dot([[dx, 0, 0], [0, dy, 0]]) + + def _calc_iso(isoval): + contours = find_contours(values, isoval) + + contour_xs = [] + contour_ys = [] + for contour in contours: + # Swap the first and second columns so that we have [x,y] for each + # contour point (instead of [row, col], which means [y, x]) + contour_coords = contour[:, [1, 0]] + # Then convert from indices to coordinates in the 2D space + contour_coords = _indices_to_2Dspace(contour_coords) + contour_xs = [*contour_xs, None, *contour_coords[:, 0]] + contour_ys = [*contour_ys, None, *contour_coords[:, 1]] + + # Add the information about this isoline to the list of isolines + return { + "x": contour_xs, "y": contour_ys, "width": iso.get("width"), + } + + elif ndim == 3: + # In 3D, use default isosurfaces if none were provided. 
+ if len(isos) == 0 and maxval != minval: + default_iso_frac = 0.3 #isos_param["frac"].default + + # If the default frac is 0.3, they will be displayed at 0.3 and 0.7 + isos = [ + {"frac": default_iso_frac}, + {"frac": 1-default_iso_frac} + ] + + # Define the function that will calculate each isosurface + def _calc_iso(isoval): + vertices, faces, normals, intensities = data.grid.isosurface(isoval, iso.get("step_size", 1)) + + #vertices = vertices + self._get_offsets(grid) + self.offsets["origin"] + + return {"vertices": vertices, "faces": faces} + else: + raise ValueError(f"Dimensionality must be lower than 3, but is {ndim}") + + # Now loop through all the isos + for iso in isos: + if not iso.get("active", True): + continue + + # Infer the iso value either from val or from frac + isoval = iso.get("val") + if isoval is None: + frac = iso.get("frac") + if frac is None: + raise ValueError(f"You are providing an iso query without 'val' and 'frac'. There's no way to know the isovalue!\nquery: {iso}") + isoval = minval + (maxval-minval)*frac + + isos_to_draw.append({ + "color": iso.get("color"), "opacity": iso.get("opacity"), + "name": iso.get("name", "Iso: $isoval$").replace("$isoval$", f"{isoval:.4f}"), + **_calc_iso(isoval), + }) + + return isos_to_draw + diff --git a/src/sisl/viz/processors/logic.py b/src/sisl/viz/processors/logic.py new file mode 100644 index 0000000000..8253b4cd77 --- /dev/null +++ b/src/sisl/viz/processors/logic.py @@ -0,0 +1,21 @@ +from typing import Any, Tuple, TypeVar, Union + +T1 = TypeVar("T1") +T2 = TypeVar("T2") + +def swap(val: Union[T1, T2], vals: Tuple[T1, T2]) -> Union[T1, T2]: + """Given two values, returns the one that is not the input value.""" + if val == vals[0]: + return vals[1] + elif val == vals[1]: + return vals[0] + else: + raise ValueError(f"Value {val} not in {vals}") + +def matches(first: Any, second: Any, ret_true: T1 = True, ret_false: T2 = False) -> Union[T1, T2]: + """If first matches second, return ret_true, else 
return ret_false.""" + return ret_true if first == second else ret_false + +def switch(obj: Any, ret_true: T1, ret_false: T2) -> Union[T1, T2]: + """If obj is True, return ret_true, else return ret_false.""" + return ret_true if obj else ret_false \ No newline at end of file diff --git a/src/sisl/viz/processors/orbital.py b/src/sisl/viz/processors/orbital.py new file mode 100644 index 0000000000..18495856c0 --- /dev/null +++ b/src/sisl/viz/processors/orbital.py @@ -0,0 +1,645 @@ +from collections import ChainMap, defaultdict +from dataclasses import asdict, is_dataclass +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Sequence, TypedDict, Union + +import numpy as np +import xarray +from xarray import DataArray, Dataset + +import sisl +from sisl import Geometry, Spin +from sisl.messages import SislError +from sisl.typing import AtomsArgument +from sisl.viz.types import OrbitalStyleQuery + +from .._single_dispatch import singledispatchmethod +from ..data import Data +from ..processors.xarray import group_reduce +from .spin import get_spin_options + + +class OrbitalGroup(TypedDict): + name: str + orbitals: Any + spin: Any + reduce_func: Optional[Callable] + spin_reduce: Optional[Callable] + +class OrbitalQueriesManager: + """ + This class implements an input field that allows you to select orbitals by atom, species, etc... 
+ """ + _item_input_type = OrbitalStyleQuery + + _keys_to_cols = { + "atoms": "atom", + "orbitals": "orbital_name", + } + + geometry: Geometry + spin: Spin + + key_gens: Dict[str, Callable] = {} + + @singledispatchmethod + @classmethod + def new(cls, geometry: Geometry, spin: Union[str, Spin] = "", key_gens: Dict[str, Callable] = {}): + return cls(geometry=geometry, spin=spin or "", key_gens=key_gens) + + @new.register + @classmethod + def from_geometry(cls, geometry: Geometry, spin: Union[str, Spin] = "", key_gens: Dict[str, Callable] = {}): + return cls(geometry=geometry, spin=spin or "", key_gens=key_gens) + + @new.register + @classmethod + def from_string(cls, + string: str, spin: Union[str, Spin] = "", key_gens: Dict[str, Callable] = {} + ): + """Initializes an OrbitalQueriesManager from a string, assuming it is a path.""" + return cls.new(Path(string), spin=spin, key_gens=key_gens) + + @new.register + @classmethod + def from_path(cls, + path: Path, spin: Union[str, Spin] = "", key_gens: Dict[str, Callable] = {} + ): + """Initializes an OrbitalQueriesManager from a path, converting it to a sile.""" + return cls.new(sisl.get_sile(path), spin=spin, key_gens=key_gens) + + @new.register + @classmethod + def from_sile(cls, + sile: sisl.io.BaseSile, spin: Union[str, Spin] = "", key_gens: Dict[str, Callable] = {}, + ): + """Initializes an OrbitalQueriesManager from a sile.""" + return cls.new(sile.read_geometry(), spin=spin, key_gens=key_gens) + + @new.register + @classmethod + def from_xarray(cls, + array: xarray.core.common.AttrAccessMixin, spin: Optional[Union[str, Spin]] = None, key_gens: Dict[str, Callable] = {}, + ): + """Initializes an OrbitalQueriesManager from an xarray object.""" + if spin is None: + spin = array.attrs.get("spin", "") + + return cls.new(array.attrs.get("geometry"), spin=spin, key_gens=key_gens) + + @new.register + @classmethod + def from_data(cls, + data: Data, spin: Optional[Union[str, Spin]] = None, key_gens: Dict[str, Callable] = {} + ): 
+ """Initializes an OrbitalQueriesManager from a sisl Data object.""" + return cls.new(data._data, spin=spin, key_gens=key_gens) + + def __init__(self, geometry: Optional[Geometry] = None, spin: Union[str, Spin] = "", key_gens: Dict[str, Callable] = {}): + + self.geometry = geometry + self.spin = Spin(spin) + + self.key_gens = key_gens + + self._build_orb_filtering_df(geometry) + + def complete_query(self, query={}, **kwargs): + """ + Completes a partially build query with the default values + + Parameters + ----------- + query: dict + the query to be completed. + **kwargs: + other keys that need to be added to the query IN CASE THEY DON'T ALREADY EXIST + """ + kwargs.update(query) + + # If it's a non-colinear or spin orbit spin class, the default spin will be total, + # since averaging/summing over "x","y","z" does not make sense. + if "spin" not in kwargs and not self.spin.is_diagonal: + kwargs["spin"] = ["total"] + + return self._item_input_type(**kwargs) + + def filter_df(self, df, query, key_to_cols, raise_not_active=False): + """ + Filters a dataframe according to a query + + Parameters + ----------- + df: pd.DataFrame + the dataframe to filter. + query: dict + the query to be used as a filter. Can be incomplete, it will be completed using + `self.complete_query()` + keys_to_cols: array-like of tuples + An array of tuples that look like (key, col) + where key is the key of the parameter in the query and col the corresponding + column in the dataframe. 
+ """ + query = asdict(self.complete_query(query)) + + if raise_not_active: + if not query["active"]: + raise ValueError(f"Query {query} is not active and you are trying to use it") + + query_str = [] + for key, val in query.items(): + if key == "orbitals" and val is not None and len(val) > 0 and isinstance(val[0], int): + df = df.iloc[val] + continue + + key = key_to_cols.get(key, key) + if key in df and val is not None: + if isinstance(val, (np.ndarray, tuple)): + val = np.ravel(val).tolist() + query_str.append(f'{key}=={repr(val)}') + + if len(query_str) == 0: + return df + else: + return df.query(" & ".join(query_str)) + + def _build_orb_filtering_df(self, geom): + import pandas as pd + + orb_props = defaultdict(list) + del_key = set() + #Loop over all orbitals of the basis + for at, iorb in geom.iter_orbitals(): + + atom = geom.atoms[at] + orb = atom[iorb] + + orb_props["atom"].append(at) + orb_props["Z"].append(atom.Z) + orb_props["species"].append(atom.symbol) + orb_props["orbital_name"].append(orb.name()) + + for key in ("n", "l", "m", "zeta"): + val = getattr(orb, key, None) + if val is None: + del_key.add(key) + orb_props[key].append(val) + + for key in del_key: + del orb_props[key] + + self.orb_filtering_df = pd.DataFrame(orb_props) + + def get_options(self, key, **kwargs): + """ + Gets the options for a given key or combination of keys. + + Parameters + ------------ + key: str, {"species", "atoms", "Z", "orbitals", "n", "l", "m", "zeta", "spin"} + the parameter that you want the options for. + + Note that you can combine them with a "+" to get all the possible combinations. + You can get the same effect also by passing a list. + See examples. + **kwargs: + keyword arguments that add additional conditions to the query. The values of this + keyword arguments can be lists, in which case it indicates that you want a value + that is in the list. See examples. + + Returns + ---------- + np.ndarray of shape (n_options, [n_keys]) + all the possible options. 
+ + If only one key was provided, it is a one dimensional array. + + Examples + ----------- + + >>> orb_manager = OrbitalQueriesManager(geometry) + >>> orb_manager.get_options("l", species="Au") + >>> orb_manager.get_options("n+l", atoms=[0,1]) + """ + # Get the tadatframe + df = self.orb_filtering_df + + # Filter the dataframe according to the constraints imposed by the kwargs, + # if there are any. + if kwargs: + if "atoms" in kwargs: + kwargs["atoms"] = self.geometry._sanitize_atoms(kwargs["atoms"]) + def _repr(v): + if isinstance(v, np.ndarray): + v = list(v.ravel()) + if isinstance(v, dict): + raise Exception(str(v)) + return repr(v) + query = ' & '.join([f'{self._keys_to_cols.get(k, k)}=={_repr(v)}' for k, v in kwargs.items( + ) if self._keys_to_cols.get(k, k) in df]) + if query: + df = df.query(query) + + # If + is in key, it is a composite key. In that case we are going to + # split it into all the keys that are present and get the options for all + # of them. At the end we are going to return a list of tuples that will be all + # the possible combinations of the keys. + keys = [self._keys_to_cols.get(k, k) for k in key.split("+")] + + # Spin values are not stored in the orbital filtering dataframe. If the options + # for spin are requested, we need to pop the key out and get the current options + # for spin from the input field + spin_in_keys = "spin" in keys + if spin_in_keys: + spin_key_i = keys.index("spin") + keys.remove("spin") + spin_options = get_spin_options(self.spin) + + # We might have some constraints on what the spin value can be + if "spin" in kwargs: + spin_options = set(spin_options).intersection(kwargs["spin"]) + + # Now get the unique options from the dataframe + if keys: + options = df.drop_duplicates(subset=keys)[ + keys].values.astype(object) + else: + # It might be the only key was "spin", then we are going to fake it + # to get an options array that can be treated in the same way. 
+ options = np.array([[]], dtype=object) + + # If "spin" was one of the keys, we are going to incorporate the spin options, taking into + # account the position (column index) where they are expected to be returned. + if spin_in_keys and len(spin_options) > 0: + options = np.concatenate( + [np.insert(options, spin_key_i, spin, axis=1) for spin in spin_options]) + + # Squeeze the options array, just in case there is only one key + # There's a special case: if there is only one option for that key, + # squeeze converts it to a number, so we need to make sure there is at least 1d + if options.shape[1] == 1: + options = options.squeeze() + options = np.atleast_1d(options) + + return options + + def get_orbitals(self, query): + + if "atoms" in query: + query["atoms"] = self.geometry._sanitize_atoms(query["atoms"]) + + filtered_df = self.filter_df( + self.orb_filtering_df, query, self._keys_to_cols + ) + + return filtered_df.index.values + + def get_atoms(self, query): + + if "atoms" in query: + query["atoms"] = self.geometry._sanitize_atoms(query["atoms"]) + + filtered_df = self.filter_df( + self.orb_filtering_df, query, self._keys_to_cols + ) + + return np.unique(filtered_df['atom'].values) + + def _split_query(self, query, on, only=None, exclude=None, query_gen=None, ignore_constraints=False, **kwargs): + """ + Splits a query into multiple queries based on one of its parameters. + + Parameters + -------- + query: dict + the query that we want to split + on: str, {"species", "atoms", "Z", "orbitals", "n", "l", "m", "zeta", "spin"}, or list of str + the parameter to split along. + Note that you can combine parameters with a "+" to split along multiple parameters + at the same time. You can get the same effect also by passing a list. + only: array-like, optional + if desired, the only values that should be plotted out of + all of the values that come from the splitting. + exclude: array-like, optional + values of the splitting that should not be plotted. 
+ query_gen: function, optional + the request generator. It is a function that takes all the parameters for each + request that this method has come up with and gets a chance to do some modifications. + + This may be useful, for example, to give each request a color, or a custom name. + ignore_constraints: boolean or array-like, optional + determines whether constraints (imposed by the query that you want to split) + on the parameters that we want to split along should be taken into consideration. + + If `False`: all constraints considered. + If `True`: no constraints considered. + If array-like: parameters contained in the list ignore their constraints. + **kwargs: + keyword arguments that go directly to each new request. + + This is useful to add extra filters. For example: + + `self._split_query(request, on="orbitals", spin=[0])` + will split the request on the different orbitals but will take + only the contributions from spin up. + """ + if exclude is None: + exclude = [] + + # Divide the splitting request into all the parameters + if isinstance(on, str): + on = on.split("+") + + # Get the current values of the parameters that we want to split the request on + # because these will be our constraints. If a parameter is set to None or not + # provided, we have no constraints for that parameter. + if ignore_constraints is True: + constraints = {} + else: + constraints = ChainMap(kwargs, query) + + if ignore_constraints is False: + ignore_constraints = () + + constraints = {key: val for key, val in constraints.items() if key not in ignore_constraints and val is not None} + + # Knowing what are our constraints (which may be none), get the available options + values = self.get_options("+".join(on), **constraints) + + # We are going to make sure that, even if there was only one parameter to split on, + # the values are two dimensional. In this way, we can take the same actions for the + # case when there is only one parameter and the case when there are multiple. 
+ if values.ndim == 1: + values = values.reshape(-1, 1) + + # If no function to modify queries was provided we are going to use the default one + # associated to this class. + if query_gen is None: + query_gen = self.complete_query + + # We ensure that on is a list even if there is only one parameter, for the same + # reason we ensured values was 2 dimensional + if isinstance(on, str): + on = on.split("+") + + # Define the name that we will give to the new queries, using templating + # If a splitting parameter is not used by the name, we are going to + # append it, in order to make names unique and self-explanatory. + base_name = kwargs.pop("name", query.get("name", "")) or "" + first_added = True + for key in on: + kwargs.pop(key, None) + + if f"${key}" not in base_name: + base_name += f"{' | ' if first_added else ', '}{key}=${key}" + first_added = False + + # Now build all the queries + queries = [] + for i, value in enumerate(values): + if value not in exclude and (only is None or value in only): + + # Use the name template to generate the name for this query + name = base_name + for key, val in zip(on, value): + name = name.replace(f"${key}", str(val)) + + # Build the query + query = query_gen(**{ + **query, + **{key: [val] for key, val in zip(on, value)}, + "name": name, **kwargs + }) + + # Make sure it is a dict + if is_dataclass(query): + query = asdict(query) + + # And append the new query to the queries + queries.append(query) + + return queries + + def generate_queries(self, + split: str, + only: Optional[Sequence] = None, + exclude: Optional[Sequence] = None, + query_gen: Optional[Callable[[dict], dict]] = None, + **kwargs + ): + """ + Automatically generates queries based on the current options. + + Parameters + -------- + split: str, {"species", "atoms", "Z", "orbitals", "n", "l", "m", "zeta", "spin"} or list of str + the parameter to split on. + Note that you can combine parameters with a "+" to split along multiple parameters + at the same time. 
You can get the same effect also by passing a list. + only: array-like, optional + if desired, the only values that should be plotted out of + all of the values that come from the splitting. + exclude: array-like, optional + values that should not be plotted + query_gen: function, optional + the request generator. It is a function that takes all the parameters for each + request that this method has come up with and gets a chance to do some modifications. + + This may be useful, for example, to give each request a color, or a custom name. + **kwargs: + keyword arguments that go directly to each request. + + This is useful to add extra filters. For example: + `generate_queries(split="orbitals", species=["C"])` + will split the PDOS on the different orbitals but will take + only those that belong to carbon atoms. + """ + return self._split_query({}, on=split, only=only, exclude=exclude, query_gen=query_gen, **kwargs) + + def sanitize_query(self, query): + # Get the complete request and make sure it is a dict. + query = self.complete_query(query) + if is_dataclass(query): + query = asdict(query) + + # Determine the reduce function from the "reduce" passed and the scale factor. + def _reduce_func(arr, **kwargs): + reduce_ = query['reduce'] + if isinstance(reduce_, str): + reduce_ = getattr(np, reduce_) + + if kwargs['axis'] == (): + return arr + return reduce_(arr, **kwargs) * query.get("scale", 1) + + # Finally, return the sanitized request, converting the request (contains "species", "n", "l", etc...) + # into a list of orbitals. 
+ return { + **query, + "orbitals": self.get_orbitals(query), + "reduce_func": _reduce_func, + **{k: gen(query) for k, gen in self.key_gens.items()} + } + +def generate_orbital_queries( + orb_manager: OrbitalQueriesManager, + split: str, + only: Optional[Sequence] = None, + exclude: Optional[Sequence] = None, + query_gen: Optional[Callable[[dict], dict]] = None, +): + return orb_manager.generate_queries(split, only=only, exclude=exclude, query_gen=query_gen) + +def reduce_orbital_data(orbital_data: Union[DataArray, Dataset], groups: Sequence[OrbitalGroup], geometry: Optional[Geometry] = None, + reduce_func: Callable = np.mean, spin_reduce: Optional[Callable] = None, orb_dim: str = "orb", spin_dim: str = "spin", + groups_dim: str = "group", sanitize_group: Union[Callable, OrbitalQueriesManager, None] = None, group_vars: Optional[Sequence[str]] = None, + drop_empty: bool = False, fill_empty: Any = 0. +) -> Union[DataArray, Dataset]: + """Groups contributions of orbitals into a new dimension. + + Given an xarray object containing orbital information and the specification of groups of orbitals, this function + computes the total contribution for each group of orbitals. It therefore removes the orbitals dimension and + creates a new one to account for the groups. + + It can also reduce spin in the same go if requested. In that case, groups can also specify particular spin components. + + Parameters + ---------- + orbital_data : DataArray or Dataset + The xarray object to reduce. + groups : Sequence[OrbitalGroup] + A sequence containing the specifications for each group of orbitals. See ``OrbitalGroup``. + geometry : Geometry, optional + The geometry object that will be used to parse orbital specifications into actual orbital indices. Knowing the + geometry therefore allows you to specify more complex selections. 
+ If not provided, it will be searched in the ``geometry`` attribute of the ``orbital_data`` object and + afterwards in the ``parent`` attribute, under ``parent.geometry``. + reduce_func : Callable, optional + The function that will compute the reduction along the orbitals dimension once the selection is done. + This could be for example ``numpy.mean`` or ``numpy.sum``. + Notice that this will only be used in case the group specification doesn't specify a particular function + in its "reduce_func" field, which will take preference. + spin_reduce: Callable, optional + The function that will compute the reduction along the spin dimension once the selection is done. + orb_dim: str, optional + Name of the dimension that contains the orbital indices in ``orbital_data``. + spin_dim: str, optional + Name of the dimension that contains the spin components in ``orbital_data``. + groups_dim: str, optional + Name of the new dimension that will be created for the groups. + sanitize_group: Union[Callable, OrbitalQueriesManager], optional + A function that will be used to sanitize the group specification before it is used. + If a ``OrbitalQueriesManager`` is passed, its `sanitize_query` method will be used. + If not provided and a geometry is found in the attributes of the ``orbital_data`` object, + an `OrbitalQueriesManager` will be automatically created from it. + group_vars: Sequence[str], optional + If set, this argument specifies extra variables that depend on the group and the user would like to + introduce in the new xarray object. These variables will be searched as fields for each group specification. + A data variable will be created for each group_var and they will be added to the final xarray object. + Note that this forces the returned object to be a Dataset, even if the input data is a DataArray. + drop_empty: bool, optional + If set to `True`, group specifications that do not correspond to any orbital will not appear in the final + returned object. 
+ fill_empty: Any, optional + If ``drop_empty`` is set to ``False``, this argument specifies the value to use for group specifications + that do not correspond to any orbital. + """ + # If no geometry was provided, then get it from the attrs of the xarray object. + if geometry is None: + geometry = orbital_data.attrs.get("geometry") + if geometry is None: + parent = orbital_data.attrs.get('parent') + if parent is not None: + getattr(parent, "geometry") + + if sanitize_group is None: + if geometry is not None: + sanitize_group = OrbitalQueriesManager(geometry=geometry, spin=orbital_data.attrs.get("spin", "")) + else: + sanitize_group = lambda x: x + if isinstance(sanitize_group, OrbitalQueriesManager): + sanitize_group = sanitize_group.sanitize_query + + if geometry is None: + def _sanitize_group(group): + group = group.copy() + group = sanitize_group(group) + orbitals = group.get('orbitals') + try: + group['orbitals'] = np.array(orbitals, dtype=int) + assert orbitals.ndim == 1 + except: + raise SislError("A geometry was neither provided nor found in the xarray object. Therefore we can't" + f" convert the provided atom selection ({orbitals}) to an array of integers.") + + group['selector'] = group['orbitals'] + if spin_reduce is not None and spin_dim in orbital_data.dims: + group['selector'] = (group['selector'], group.get('spin')) + group['reduce_func'] = (group.get('reduce_func', reduce_func), spin_reduce) + + return group + else: + def _sanitize_group(group): + group = group.copy() + group = sanitize_group(group) + group["orbitals"] = geometry._sanitize_orbs(group["orbitals"]) + group['selector'] = group['orbitals'] + if spin_reduce is not None and spin_dim in orbital_data.dims: + group['selector'] = (group['selector'], group.get('spin')) + group['reduce_func'] = (group.get('reduce_func', reduce_func), spin_reduce) + return group + + # If a reduction for spin was requested, then pass the two different functions to reduce + # each coordinate. 
+ reduce_funcs = reduce_func + reduce_dims = orb_dim + if spin_reduce is not None and spin_dim in orbital_data.dims: + reduce_funcs = (reduce_func, spin_reduce) + reduce_dims = (orb_dim, spin_dim) + + return group_reduce( + data=orbital_data, groups=groups, reduce_dim=reduce_dims, reduce_func=reduce_funcs, + groups_dim=groups_dim, sanitize_group=_sanitize_group, group_vars=group_vars, + drop_empty=drop_empty, fill_empty=fill_empty + ) + +def get_orbital_queries_manager(obj, spin: Optional[str] = None, key_gens: Dict[str, Callable] = {}) -> OrbitalQueriesManager: + return OrbitalQueriesManager.new(obj, spin=spin, key_gens=key_gens) + +def split_orbitals(orbital_data, on="species", only=None, exclude=None, geometry: Optional[Geometry] = None, + reduce_func: Callable = np.mean, spin_reduce: Optional[Callable] = None, orb_dim: str = "orb", spin_dim: str = "spin", + groups_dim: str = "group", group_vars: Optional[Sequence[str]] = None, + drop_empty: bool = False, fill_empty: Any = 0., **kwargs): + + if geometry is not None: + orbital_data = orbital_data.copy() + orbital_data.attrs['geometry'] = geometry + + orbital_data = orbital_data.copy() + + orb_manager = get_orbital_queries_manager(orbital_data, key_gens=kwargs.pop("key_gens", {})) + + groups = orb_manager.generate_queries(split=on, only=only, exclude=exclude, **kwargs) + + return reduce_orbital_data( + orbital_data, groups=groups, sanitize_group=orb_manager, reduce_func=reduce_func, spin_reduce=spin_reduce, + orb_dim=orb_dim, spin_dim=spin_dim, groups_dim=groups_dim, group_vars=group_vars, drop_empty=drop_empty, + fill_empty=fill_empty + ) + +def atom_data_from_orbital_data(orbital_data, atoms: AtomsArgument = None, request_kwargs: Dict = {}, geometry: Optional[Geometry] = None, + reduce_func: Callable = np.mean, spin_reduce: Optional[Callable] = None, orb_dim: str = "orb", spin_dim: str = "spin", + groups_dim: str = "atom", group_vars: Optional[Sequence[str]] = None, + drop_empty: bool = False, fill_empty: Any = 
0., +): + request_kwargs["name"] = "$atoms" + + atom_data = split_orbitals( + orbital_data, on="atoms", only=atoms, reduce_func=reduce_func, spin_reduce=spin_reduce, + orb_dim=orb_dim, spin_dim=spin_dim, groups_dim=groups_dim, group_vars=group_vars, drop_empty=drop_empty, + fill_empty=fill_empty, **request_kwargs + ) + + atom_data = atom_data.assign_coords(atom=atom_data.atom.astype(int)) + + return atom_data \ No newline at end of file diff --git a/src/sisl/viz/processors/spin.py b/src/sisl/viz/processors/spin.py new file mode 100644 index 0000000000..1b151d35ac --- /dev/null +++ b/src/sisl/viz/processors/spin.py @@ -0,0 +1,30 @@ +from typing import List, Literal, Union + +from sisl import Spin + +_options = { + Spin.UNPOLARIZED: [], + Spin.POLARIZED: [{"label": "↑", "value": 0}, {"label": "↓", "value": 1}, + {"label": "Total", "value": "total"}, {"label": "Net z", "value": "z"}], + Spin.NONCOLINEAR: [{"label": val, "value": val} for val in ("total", "x", "y", "z")], + Spin.SPINORBIT: [{"label": val, "value": val} for val in ("total", "x", "y", "z")] +} + +def get_spin_options(spin: Union[Spin, str], only_if_polarized: bool = False) -> List[Literal[0, 1, "total", "x", "y", "z"]]: + """Returns the options for a given spin class. + + Parameters + ---------- + spin: sisl.Spin or str + The spin class to get the options for. + only_if_polarized: bool, optional + If set to `True`, non colinear spins will not have multiple options. 
+ """ + spin = Spin(spin) + + if only_if_polarized and not spin.is_polarized: + options_spin = Spin("") + else: + options_spin = spin + + return [option['value'] for option in _options[options_spin.kind]] \ No newline at end of file diff --git a/src/sisl/viz/processors/tests/__init__.py b/src/sisl/viz/processors/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sisl/viz/processors/tests/test_axes.py b/src/sisl/viz/processors/tests/test_axes.py new file mode 100644 index 0000000000..07aaf87887 --- /dev/null +++ b/src/sisl/viz/processors/tests/test_axes.py @@ -0,0 +1,87 @@ +import numpy as np +import pytest + +from sisl.viz.processors.axes import ( + axes_cross_product, + axis_direction, + get_ax_title, + sanitize_axes, +) + + +def test_sanitize_axes(): + + assert sanitize_axes(["x", "y", "z"]) == ["x", "y", "z"] + assert sanitize_axes("xyz") == ["x", "y", "z"] + assert sanitize_axes("abc") == ["a", "b", "c"] + assert sanitize_axes([0, 1, 2]) == ["a", "b", "c"] + assert sanitize_axes("-xy") == ["-x", "y"] + assert sanitize_axes("x-y") == ["x", "-y"] + assert sanitize_axes("-x-y") == ["-x", "-y"] + assert sanitize_axes("a-b") == ["a", "-b"] + + axes = sanitize_axes([[0,1,2]]) + assert isinstance(axes[0], np.ndarray) + assert axes[0].shape == (3,) + assert np.all(axes[0] == [0,1,2]) + + with pytest.raises(ValueError): + sanitize_axes([None]) + +def test_axis_direction(): + + assert np.allclose(axis_direction("x"), [1, 0, 0]) + assert np.allclose(axis_direction("y"), [0, 1, 0]) + assert np.allclose(axis_direction("z"), [0, 0, 1]) + + assert np.allclose(axis_direction("-x"), [-1, 0, 0]) + assert np.allclose(axis_direction("-y"), [0, -1, 0]) + assert np.allclose(axis_direction("-z"), [0, 0, -1]) + + assert np.allclose(axis_direction([1, 0, 0]), [1, 0, 0]) + + cell = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) + + assert np.allclose(axis_direction("a", cell), [0, 0, 1]) + assert np.allclose(axis_direction("b", cell), [1, 0, 0]) + assert 
np.allclose(axis_direction("c", cell), [0, 1, 0]) + + assert np.allclose(axis_direction("-a", cell), [0, 0, -1]) + assert np.allclose(axis_direction("-b", cell), [-1, 0, 0]) + assert np.allclose(axis_direction("-c", cell), [0, -1, 0]) + +def test_axes_cross_product(): + + assert np.allclose(axes_cross_product("x", "y"), [0, 0, 1]) + assert np.allclose(axes_cross_product("y", "x"), [0, 0, -1]) + assert np.allclose(axes_cross_product("-x", "y"), [0, 0, -1]) + + assert np.allclose(axes_cross_product([1,0,0], [0,1,0]), [0, 0, 1]) + assert np.allclose(axes_cross_product([0,1,0], [1,0,0]), [0, 0, -1]) + + cell = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) + + assert np.allclose(axes_cross_product("b", "c", cell), [0, 0, 1]) + assert np.allclose(axes_cross_product("c", "b", cell), [0, 0, -1]) + assert np.allclose(axes_cross_product("-b", "c", cell), [0, 0, -1]) + +def test_axis_title(): + + assert get_ax_title("title") == "title" + + assert get_ax_title("x") == "X axis [Ang]" + assert get_ax_title("y") == "Y axis [Ang]" + assert get_ax_title("-z") == "-Z axis [Ang]" + + assert get_ax_title("a") == "A lattice vector" + assert get_ax_title("b") == "B lattice vector" + assert get_ax_title("-c") == "-C lattice vector" + + assert get_ax_title(None) == "" + + assert get_ax_title(np.array([1,2,3])) == "[1 2 3]" + + def some_axis(): pass + + assert get_ax_title(some_axis) == "some_axis" + diff --git a/src/sisl/viz/processors/tests/test_bands.py b/src/sisl/viz/processors/tests/test_bands.py new file mode 100644 index 0000000000..ccf7f60b94 --- /dev/null +++ b/src/sisl/viz/processors/tests/test_bands.py @@ -0,0 +1,250 @@ +import numpy as np +import pytest +import xarray as xr + +import sisl +from sisl import Spin +from sisl.viz.data import BandsData +from sisl.viz.processors.bands import ( + calculate_gap, + draw_gap, + draw_gaps, + filter_bands, + get_gap_coords, + sanitize_k, + style_bands, +) + + +@pytest.fixture(scope="module", params=["unpolarized", "polarized", "noncolinear", 
"spinorbit"]) +def spin(request): + return Spin(request.param) + +@pytest.fixture(scope="module") +def gap(): + return 2.5 + +@pytest.fixture(scope="module") +def bands_data(spin, gap): + return BandsData.toy_example(spin=spin, gap=gap) + +@pytest.fixture(scope="module", params=["x", "y"]) +def E_axis(request): + return request.param + +def test_filter_bands(bands_data): + spin = bands_data.attrs['spin'] + + # Check that it works without any arguments + filtered_bands = filter_bands(bands_data) + + # Test filtering by band index. + filtered_bands = filter_bands(bands_data, bands_range=[0, 5]) + assert np.all(filtered_bands.band == np.arange(0, 6)) + + # Test filtering by energy. First check that we actually + # have bands beyond the energy range that we want to test. + assert bands_data.E.min() <= -5 + assert bands_data.E.max() >= 5 + + filtered_bands = filter_bands(bands_data, Erange=[-5, 5]) + + assert filtered_bands.E.min() >= -5 + assert filtered_bands.E.max() <= 5 + + if spin.is_polarized: + filtered_bands = filter_bands(bands_data, Erange=[-5, 5], spin=1) + + assert filtered_bands.E.min() >= -5 + assert filtered_bands.E.max() <= 5 + + assert filtered_bands.spin == 1 + + +def test_calculate_gap(bands_data, gap): + + spin = bands_data.attrs["spin"] + + gap_info = calculate_gap(bands_data) + + # Check that the gap value is correct + assert gap_info['gap'] == gap + + # Check also that the position of the gap is in the information + assert isinstance(gap_info['k'], tuple) and len(gap_info['k']) == 2 + + VB = len(bands_data.band) // 2 - 1 + assert isinstance(gap_info['bands'], tuple) and len(gap_info['bands']) == 2 + assert gap_info['bands'][0] < gap_info['bands'][1] + + assert isinstance(gap_info['spin'], tuple) and len(gap_info['spin']) == 2 + if not spin.is_polarized: + assert gap_info['spin'] == (0, 0) + + assert isinstance(gap_info['Es'], tuple) and len(gap_info['Es']) == 2 + assert np.allclose(gap_info['Es'], (- gap / 2, gap / 2)) + +def 
test_sanitize_k(bands_data): + + assert sanitize_k(bands_data, "Gamma") == 0 + assert sanitize_k(bands_data, "X") == 1 + +def test_get_gap_coords(bands_data): + + spin = bands_data.attrs["spin"] + + vb = len(bands_data.band) // 2 - 1 + + # We can get the gamma gap by specifying both origin and destination or + # just origin. Check also Gamma to X. + for to_k in ["Gamma", None, "X"]: + + k, E = get_gap_coords(bands_data, (vb, vb + 1), from_k="Gamma", to_k=to_k, spin=1) + + kval = 1 if to_k == "X" else 0 + + # Check that the K is correct + assert k[0] == 0 + assert k[1] == kval + + # Check that E is correct. + if spin.is_polarized: + bands_E = bands_data.E.sel(spin=1) + else: + bands_E = bands_data.E + + assert E[0] == bands_E.sel(band=vb, k=0) + assert E[1] == bands_E.sel(band=vb + 1, k=kval) + +def test_draw_gap(E_axis): + + ks = (0, 0.5) + Es = (0, 1) + + if E_axis == "x": + x, y = Es, ks + else: + x, y = ks, Es + + gap_action = draw_gap(ks, Es, color="red", name="test", E_axis=E_axis) + + assert isinstance(gap_action, dict) + assert gap_action["method"] == "draw_line" + + action_kwargs = gap_action["kwargs"] + + assert action_kwargs["name"] == "test" + assert action_kwargs["line"]["color"] == "red" + assert action_kwargs["marker"]["color"] == "red" + assert action_kwargs['x'] == x + assert action_kwargs['y'] == y + +@pytest.mark.parametrize("display_gap", [True, False]) +def test_draw_gaps(bands_data, E_axis, display_gap): + + spin = bands_data.attrs["spin"] + + gap_info = calculate_gap(bands_data) + + # Run the function only to draw the minimum gap. 
+ gap_actions = draw_gaps( + bands_data, gap=display_gap, gap_info=gap_info, + gap_tol=0.3, gap_color="red", gap_marker={}, + direct_gaps_only=False, custom_gaps=[], E_axis=E_axis + ) + + assert isinstance(gap_actions, list) + assert len(gap_actions) == (1 if display_gap else 0) + + if display_gap: + assert isinstance(gap_actions[0], dict) + assert gap_actions[0]["method"] == "draw_line" + + action_kwargs = gap_actions[0]["kwargs"] + assert action_kwargs["line"]["color"] == "red" + assert action_kwargs["marker"]["color"] == "red" + + # Now run the function with a custom gap. + gap_actions = draw_gaps( + bands_data, gap=display_gap, gap_info=gap_info, + gap_tol=0.3, gap_color="red", gap_marker={}, + direct_gaps_only=False, + custom_gaps=[{"from": "Gamma", "to": "X", "color": "blue"}], + E_axis=E_axis + ) + + assert isinstance(gap_actions, list) + assert len(gap_actions) == (2 if display_gap else 1) + (1 if spin.is_polarized else 0) + + # Check the minimum gap + if display_gap: + assert isinstance(gap_actions[0], dict) + assert gap_actions[0]["method"] == "draw_line" + + action_kwargs = gap_actions[0]["kwargs"] + assert action_kwargs["line"]["color"] == "red" + assert action_kwargs["marker"]["color"] == "red" + + # Check the custom gap + assert isinstance(gap_actions[-1], dict) + assert gap_actions[-1]["method"] == "draw_line" + + action_kwargs = gap_actions[-1]["kwargs"] + assert action_kwargs["line"]["color"] == "blue" + assert action_kwargs["marker"]["color"] == "blue" + assert action_kwargs['x' if E_axis == "y" else "y"] == (0, 1) + +def test_style_bands(bands_data): + + spin = bands_data.attrs["spin"] + + # Check basic styles + styled_bands = style_bands( + bands_data, {"color": "red", "width": 3}, + spindown_style={"opacity": 0.5, "color": "blue"} + ) + + assert isinstance(styled_bands, xr.Dataset) + + for k in ("color", "width", "opacity"): + assert k in styled_bands.data_vars + + if not spin.is_polarized: + assert styled_bands.color == "red" + assert 
styled_bands.width == 3 + assert styled_bands.opacity == 1 + else: + assert np.all(styled_bands.color == ["red", "blue"]) + assert np.all(styled_bands.width == [3, 3]) + assert np.all(styled_bands.opacity == [1, 0.5]) + + # Check function as style + def color(data): + return xr.DataArray( + np.where(data.band < 5, "red", "blue"), + coords=[("band", data.band.values)] + ) + + styled_bands = style_bands( + bands_data, {"color": color, "width": 3}, + spindown_style={"opacity": 0.5, "color": "blue"} + ) + + assert isinstance(styled_bands, xr.Dataset) + + for k in ("color", "width", "opacity"): + assert k in styled_bands.data_vars + + assert "band" in styled_bands.color.coords + if spin.is_polarized: + bands_color = styled_bands.color.sel(spin=0) + assert np.all(styled_bands.color.sel(spin=1) == "blue") + else: + bands_color = styled_bands.color + assert np.all((styled_bands.band < 5) == (bands_color == "red")) + + + + + + diff --git a/src/sisl/viz/processors/tests/test_cell.py b/src/sisl/viz/processors/tests/test_cell.py new file mode 100644 index 0000000000..0f20c3dd64 --- /dev/null +++ b/src/sisl/viz/processors/tests/test_cell.py @@ -0,0 +1,101 @@ +import numpy as np +import pytest +import xarray as xr + +from sisl import Lattice +from sisl.viz.processors.cell import ( + cell_to_lines, + gen_cell_dataset, + infer_cell_axes, + is_1D_cartesian, + is_cartesian_unordered, +) + + +@pytest.fixture(scope="module", params=["numpy", "lattice"]) +def Cell(request): + + if request.param == "numpy": + return np.array + elif request.param == "lattice": + return Lattice + +def test_cartesian_unordered(Cell): + + assert is_cartesian_unordered( + Cell([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + ) + + assert is_cartesian_unordered( + Cell([[0, 0, 1], [1, 0, 0], [0, 1, 0]]), + ) + + assert not is_cartesian_unordered( + Cell([[0, 2, 1], [1, 0, 0], [0, 1, 0]]), + ) + +def test_1D_cartesian(Cell): + + assert is_1D_cartesian( + Cell([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + "x" + ) + + assert 
is_1D_cartesian( + Cell([[1, 0, 0], [0, 1, 1], [0, 0, 1]]), + "x" + ) + + assert not is_1D_cartesian( + Cell([[1, 0, 0], [1, 1, 0], [0, 0, 1]]), + "x" + ) + +def test_infer_cell_axes(Cell): + + assert infer_cell_axes( + Cell([[0, 1, 0], [1, 0, 0], [0, 0, 1]]), + axes=["x", "y", "z"] + ) == [1, 0, 2] + + assert infer_cell_axes( + Cell([[0, 1, 0], [1, 0, 0], [0, 0, 1]]), + axes=["b", "y"] + ) == [1, 0] + +def test_gen_cell_dataset(): + + lattice = Lattice([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + cell_dataset = gen_cell_dataset(lattice) + + assert isinstance(cell_dataset, xr.Dataset) + + assert "lattice" in cell_dataset.attrs + assert cell_dataset.attrs["lattice"] is lattice + + assert "xyz" in cell_dataset.data_vars + assert cell_dataset.xyz.shape == (2, 2, 2, 3) + assert np.all(cell_dataset.xyz.values == lattice.vertices()) + +@pytest.mark.parametrize("mode", ["box", "axes", "other"]) +def test_cell_to_lines(mode): + + lattice = Lattice([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + cell_dataset = gen_cell_dataset(lattice) + + if mode == "other": + with pytest.raises(ValueError): + cell_to_lines(cell_dataset, mode) + else: + lines = cell_to_lines(cell_dataset, mode, cell_style={"color": "red"}) + + assert isinstance(lines, xr.Dataset) + + if mode == "box": + # 19 points are required to draw a box + assert lines.xyz.shape == (19, 3) + elif mode == "axes": + # 9 points are required to draw the axes + assert lines.xyz.shape == (9, 3) \ No newline at end of file diff --git a/src/sisl/viz/processors/tests/test_coords.py b/src/sisl/viz/processors/tests/test_coords.py new file mode 100644 index 0000000000..7044525c26 --- /dev/null +++ b/src/sisl/viz/processors/tests/test_coords.py @@ -0,0 +1,219 @@ +import numpy as np +import pytest +import xarray as xr + +import sisl +from sisl import Lattice +from sisl.viz.processors.coords import ( + coords_depth, + project_to_axes, + projected_1D_data, + projected_1Dcoords, + projected_2D_data, + projected_2Dcoords, + projected_3D_data, + 
sphere, +) + + +@pytest.fixture(scope="module", params=["numpy", "lattice"]) +def Cell(request): + + if request.param == "numpy": + return np.array + elif request.param == "lattice": + return Lattice + +@pytest.fixture(scope="module") +def coords_dataset(): + geometry = sisl.geom.bcc(2.93, "Au", False) + + return xr.Dataset( + {"xyz": (("atom", "axis"), geometry.xyz)}, + coords={"axis": [0,1,2]}, + attrs={"geometry": geometry} + ) + + +def test_projected_1D_coords(Cell): + + cell = Cell([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + + x, y, z = 3, -4, 2 + + coords = np.array([[x, y, z]]) + + # Project to cartesian + projected = projected_1Dcoords(cell, coords, "x") + assert np.allclose(projected, [[x]]) + + projected = projected_1Dcoords(cell, coords, "-y") + assert np.allclose(projected, [[-y]]) + + # Project to lattice + projected = projected_1Dcoords(cell, coords, "b") + assert np.allclose(projected, [[x]]) + + # Project to vector + projected = projected_1Dcoords(cell, coords, [x, 0, z]) + assert np.allclose(projected, [[1]]) + + +def test_projected_2D_coords(Cell): + + cell = Cell([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + + x, y, z = 3, -4, 2 + + coords = np.array([[x, y, z]]) + + # Project to cartesian + projected = projected_2Dcoords(cell, coords, "x", "y") + assert np.allclose(projected, [[x, y]]) + + projected = projected_2Dcoords(cell, coords, "-x", "y") + assert np.allclose(projected, [[-x, y]]) + + projected = projected_2Dcoords(cell, coords, "z", "x") + assert np.allclose(projected, [[z, x]]) + + # Project to lattice + projected = projected_2Dcoords(cell, coords, "a", "b") + assert np.allclose(projected, [[y, x]]) + + projected = projected_2Dcoords(cell, coords, "-b", "a") + assert np.allclose(projected, [[-x, y]]) + + # Project to vectors + projected = projected_2Dcoords(cell, coords, [x, y, 0], [0, 0, z]) + assert np.allclose(projected, [[1, 1]]) + +def test_coords_depth(coords_dataset): + + depth = coords_depth(coords_dataset, ["x", "y"]) + assert isinstance(depth, 
np.ndarray) + assert np.allclose(depth, coords_dataset.xyz.sel(axis=2).values) + + depth = coords_depth(coords_dataset, ["y", "x"]) + assert np.allclose(depth, - coords_dataset.xyz.sel(axis=2).values) + + depth = coords_depth(coords_dataset, [[1, 0, 0], [0, 0, 1]]) + assert np.allclose(depth, - coords_dataset.xyz.sel(axis=1).values) + +@pytest.mark.parametrize("center", [[0, 0, 0], [1, 1, 0]]) +def test_sphere(center): + + coords = sphere(center=center, r=3.5, vertices=15) + + assert isinstance(coords, dict) + + assert "x" in coords + assert "y" in coords + assert "z" in coords + + assert coords["x"].shape == coords["y"].shape == coords["z"].shape == (15 ** 2,) + + R = np.linalg.norm(np.array([coords["x"], coords["y"], coords["z"]]).T - center, axis=1) + + assert np.allclose(R, 3.5) + +def test_projected_1D_data(coords_dataset): + + # No data + projected = projected_1D_data(coords_dataset, "y") + assert isinstance(projected, xr.Dataset) + assert "x" in projected.data_vars + assert np.allclose(projected.x, coords_dataset.xyz.sel(axis=1)) + assert "y" in projected.data_vars + assert np.allclose(projected.y, 0) + + # Data from function + projected = projected_1D_data(coords_dataset, "-y", dataaxis_1d=np.sin) + assert isinstance(projected, xr.Dataset) + assert "x" in projected.data_vars + assert np.allclose(projected.x, - coords_dataset.xyz.sel(axis=1)) + assert "y" in projected.data_vars + assert np.allclose(projected.y, np.sin(- coords_dataset.xyz.sel(axis=1))) + + # Data from array + projected = projected_1D_data(coords_dataset, "-y", dataaxis_1d=coords_dataset.xyz.sel(axis=2).values) + assert isinstance(projected, xr.Dataset) + assert "x" in projected.data_vars + assert np.allclose(projected.x, - coords_dataset.xyz.sel(axis=1)) + assert "y" in projected.data_vars + assert np.allclose(projected.y, coords_dataset.xyz.sel(axis=2)) + +def test_projected_2D_data(coords_dataset): + + projected = projected_2D_data(coords_dataset, "-y", "x") + assert isinstance(projected, 
xr.Dataset) + assert "x" in projected.data_vars + assert np.allclose(projected.x, - coords_dataset.xyz.sel(axis=1)) + assert "y" in projected.data_vars + assert np.allclose(projected.y, coords_dataset.xyz.sel(axis=0)) + + assert "depth" in projected.data_vars + + projected = projected_2D_data(coords_dataset, "x", "y", sort_by_depth=True) + assert isinstance(projected, xr.Dataset) + assert "x" in projected.data_vars + assert np.allclose(projected.x, coords_dataset.xyz.sel(axis=0)) + assert "y" in projected.data_vars + assert np.allclose(projected.y, coords_dataset.xyz.sel(axis=1)) + + assert "depth" in projected.data_vars + assert np.allclose(projected.depth.values, coords_dataset.xyz.sel(axis=2)) + # Check that points are sorted by depth. + assert np.all(np.diff(projected.depth) > 0) + +def test_projected_3D_data(coords_dataset): + + projected = projected_3D_data(coords_dataset) + assert isinstance(projected, xr.Dataset) + assert "x" in projected.data_vars + assert np.allclose(projected.x, coords_dataset.xyz.sel(axis=0)) + assert "y" in projected.data_vars + assert np.allclose(projected.y, coords_dataset.xyz.sel(axis=1)) + assert "z" in projected.data_vars + assert np.allclose(projected.z, coords_dataset.xyz.sel(axis=2)) + +def test_project_to_axes(coords_dataset): + + projected = project_to_axes(coords_dataset, ["z"], dataaxis_1d=4) + assert isinstance(projected, xr.Dataset) + assert "x" in projected.data_vars + assert np.allclose(projected.x, coords_dataset.xyz.sel(axis=2)) + assert "y" in projected.data_vars + assert np.allclose(projected.y, 4) + assert "z" not in projected.data_vars + + projected = project_to_axes(coords_dataset, ["-y", "x"]) + assert isinstance(projected, xr.Dataset) + assert "x" in projected.data_vars + assert np.allclose(projected.x, - coords_dataset.xyz.sel(axis=1)) + assert "y" in projected.data_vars + assert np.allclose(projected.y, coords_dataset.xyz.sel(axis=0)) + assert "z" not in projected.data_vars + + projected = 
project_to_axes(coords_dataset, ["x", "y", "z"]) + assert isinstance(projected, xr.Dataset) + assert "x" in projected.data_vars + assert np.allclose(projected.x, coords_dataset.xyz.sel(axis=0)) + assert "y" in projected.data_vars + assert np.allclose(projected.y, coords_dataset.xyz.sel(axis=1)) + assert "z" in projected.data_vars + assert np.allclose(projected.z, coords_dataset.xyz.sel(axis=2)) + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/sisl/viz/processors/tests/test_data.py b/src/sisl/viz/processors/tests/test_data.py new file mode 100644 index 0000000000..baec458458 --- /dev/null +++ b/src/sisl/viz/processors/tests/test_data.py @@ -0,0 +1,54 @@ +import pytest + +from sisl.viz.data import Data +from sisl.viz.processors.data import accept_data, extract_data + + +class FakeData(Data): + def __init__(self, valid: bool = True): + self._data = valid + + def sanity_check(self): + assert self._data == True + +class OtherData(Data): + pass + +@pytest.mark.parametrize("valid", [True, False]) +def test_accept_data(valid): + + data = FakeData(valid) + + # If the input is an instance of an invalid class + with pytest.raises(TypeError): + accept_data(data, OtherData) + + # Perform a sanity check on data + if valid: + assert accept_data(data, FakeData) is data + else: + with pytest.raises(AssertionError): + accept_data(data, FakeData) + + # Don't perform a sanity check on data + assert accept_data(data, FakeData, check=False) is data + +@pytest.mark.parametrize("valid", [True, False]) +def test_extract_data(valid): + + data = FakeData(valid) + + # If the input is an instance of an invalid class + with pytest.raises(TypeError): + extract_data(data, OtherData) + + # Perform a sanity check on data + if valid: + assert extract_data(data, FakeData) is data._data + else: + with pytest.raises(AssertionError): + extract_data(data, FakeData) + + # Don't perform a sanity check on data + assert extract_data(data, FakeData, check=False) is data._data + \ No 
newline at end of file diff --git a/src/sisl/viz/processors/tests/test_eigenstate.py b/src/sisl/viz/processors/tests/test_eigenstate.py new file mode 100644 index 0000000000..8c7405e613 --- /dev/null +++ b/src/sisl/viz/processors/tests/test_eigenstate.py @@ -0,0 +1,125 @@ +import numpy as np +import pytest + +import sisl +from sisl.viz.processors.eigenstate import ( + create_wf_grid, + eigenstate_geometry, + get_eigenstate, + get_grid_nsc, + project_wavefunction, + tile_if_k, +) + + +@pytest.fixture(scope="module", params=["Gamma", "X"]) +def k(request): + if request.param == "Gamma": + return (0, 0, 0) + elif request.param == "X": + return (0.5, 0, 0) + +@pytest.fixture(scope="module") +def graphene(): + + r = np.linspace(0, 3.5, 50) + f = np.exp(-r) + + orb = sisl.AtomicOrbital('2pzZ', (r, f)) + return sisl.geom.graphene(orthogonal=True, atoms=sisl.Atom(6, orb)) + +@pytest.fixture(scope="module") +def eigenstate(k, graphene): + # Create a simple graphene tight binding Hamiltonian + H = sisl.Hamiltonian(graphene) + H.construct([(0.1, 1.44), (0, -2.7)]) + + return H.eigenstate(k=k) + +def test_get_eigenstate(eigenstate, graphene): + + sel_eigenstate = get_eigenstate(eigenstate, 2) + + assert sel_eigenstate.state.shape == (1, graphene.no) + assert np.allclose(sel_eigenstate.state, eigenstate.state[2]) + + eigenstate = eigenstate.copy() + + eigenstate.info["index"] = np.array([0, 3, 1, 2]) + + sel_eigenstate = get_eigenstate(eigenstate, 2) + + assert sel_eigenstate.state.shape == (1, graphene.no) + assert np.allclose(sel_eigenstate.state, eigenstate.state[3]) + +def test_eigenstate_geometry(eigenstate, graphene): + + # It should give us the geometry associated with the eigenstate + assert eigenstate_geometry(eigenstate) is graphene + + # Unless we provide a geometry + graphene_copy = graphene.copy() + assert eigenstate_geometry(eigenstate, graphene_copy) is graphene_copy + +def test_tile_if_k(eigenstate, graphene): + + # If the eigenstate is calculated at gamma, we 
don't need to tile + tiled_geometry = tile_if_k(graphene, (2, 2, 2), eigenstate) + + if eigenstate.info["k"] == (0,0,0): + # If the eigenstate is calculated at gamma, we don't need to tile + assert tiled_geometry is graphene + elif eigenstate.info["k"] == (0.5, 0, 0): + # If the eigenstate is calculated at X, we need to tile + # but only the first lattice vector. + assert tiled_geometry is not graphene + assert np.allclose(tiled_geometry.cell, graphene.cell * (2, 1, 1)) + +def test_get_grid_nsc(eigenstate): + + grid_nsc = get_grid_nsc((2, 2, 2), eigenstate) + + if eigenstate.info["k"] == (0,0,0): + assert grid_nsc == (2, 2, 2) + elif eigenstate.info["k"] == (0.5, 0, 0): + assert grid_nsc == (1, 2, 2) + +def test_create_wf_grid(eigenstate, graphene): + + new_graphene = graphene.copy() + grid = create_wf_grid(eigenstate, grid_prec=0.2, geometry=new_graphene) + + assert isinstance(grid, sisl.Grid) + assert grid.geometry is new_graphene + + # Check that the datatype is correct + if eigenstate.info["k"] == (0,0,0): + assert grid.grid.dtype == np.float64 + else: + assert grid.grid.dtype == np.complex128 + + # Check that the grid precision is right. 
+ assert np.allclose(np.linalg.norm(grid.dcell, axis=1), 0.2, atol=0.01) + + provided_grid = sisl.Grid(0.2, geometry=new_graphene, dtype=np.float64) + + grid = create_wf_grid(eigenstate, grid=provided_grid) + + assert grid is provided_grid + +def test_project_wavefunction(eigenstate, graphene): + + k = eigenstate.info["k"] + + grid = project_wavefunction(eigenstate[2], geometry=graphene) + + assert isinstance(grid, sisl.Grid) + + # Check that the datatype is correct + if k == (0,0,0): + assert grid.grid.dtype == np.float64 + else: + assert grid.grid.dtype == np.complex128 + + # Check that the grid is not empty + assert not np.allclose(grid.grid, 0) \ No newline at end of file diff --git a/src/sisl/viz/processors/tests/test_geometry.py b/src/sisl/viz/processors/tests/test_geometry.py new file mode 100644 index 0000000000..4f596de25d --- /dev/null +++ b/src/sisl/viz/processors/tests/test_geometry.py @@ -0,0 +1,306 @@ +import numpy as np +import pytest +import xarray as xr + +import sisl +from sisl.viz.processors.geometry import ( + add_xyz_to_bonds_dataset, + add_xyz_to_dataset, + bonds_to_lines, + find_all_bonds, + get_atoms_bonds, + parse_atoms_style, + sanitize_arrows, + sanitize_atoms, + sanitize_bonds_selection, + stack_sc_data, + style_bonds, + tile_data_sc, + tile_geometry, +) + + +@pytest.fixture(scope="module") +def geometry(): + return sisl.geom.bcc(2.93, "Au", True) + +@pytest.fixture(scope="module") +def coords_dataset(geometry): + + return xr.Dataset( + {"xyz": (("atom", "axis"), geometry.xyz)}, + coords={"axis": [0,1,2]}, + attrs={"geometry": geometry} + ) + +def test_tile_geometry(): + geom = sisl.geom.graphene() + + tiled_geometry = tile_geometry(geom, (2, 3, 5)) + + assert np.allclose(tiled_geometry.cell.T, geom.cell.T * (2, 3, 5)) + +def test_find_all_bonds(): + geom = sisl.geom.graphene() + + bonds = find_all_bonds(geom, 1.5) + + assert isinstance(bonds, xr.Dataset) + + assert "geometry" in bonds.attrs + assert bonds.attrs["geometry"] is geom + + 
assert "bonds" in bonds.data_vars + + assert bonds.bonds.shape == (23, 2) + + # Now get bonds only for the unit cell + geom.set_nsc([1,1,1]) + bonds = find_all_bonds(geom, 1.5) + + assert bonds.bonds.shape == (1, 2) + assert np.all(bonds.bonds == (0,1)) + + # Run function with just one atom + bonds = find_all_bonds(geom.sub(0), 1.5) + +def test_get_atom_bonds(): + + bonds = np.array([[0,1], [0,2], [1,2]]) + + mask = get_atoms_bonds(bonds, [0], ret_mask=True) + + assert isinstance(mask, np.ndarray) + assert mask.dtype == bool + assert np.all(mask == [True, True, False]) + + atom_bonds = get_atoms_bonds(bonds, [0]) + + assert isinstance(atom_bonds, np.ndarray) + assert atom_bonds.shape == (2, 2) + assert np.all(atom_bonds == [[0,1], [0,2]]) + +def test_sanitize_atoms(): + + geom = sisl.geom.graphene() + + sanitized = sanitize_atoms(geom, 3) + + assert len(sanitized) == 1 + assert sanitized[0] == 3 + +def test_data_sc(coords_dataset): + + assert "isc" not in coords_dataset.dims + + # First, check that not tiling works as expected + tiled = tile_data_sc(coords_dataset, nsc=(1, 1, 1)) + assert "isc" in tiled.dims + assert len(tiled.isc) == 1 + assert np.all(tiled.sel(isc=0).xyz == coords_dataset.xyz) + + # Now, check that tiling works as expected + tiled = tile_data_sc(coords_dataset, nsc=(2, 1, 1)) + + assert "isc" in tiled.dims + assert len(tiled.isc) == 2 + assert np.allclose(tiled.sel(isc=0).xyz, coords_dataset.xyz) + assert np.allclose(tiled.sel(isc=1).xyz, coords_dataset.xyz + coords_dataset.attrs["geometry"].cell[0]) + +def test_stack_sc_data(coords_dataset): + + tiled = tile_data_sc(coords_dataset, nsc=(3, 3, 1)) + + assert "isc" in tiled.dims + + stacked = stack_sc_data(tiled, newname="sc_atom", dims=["atom"]) + + assert "isc" not in stacked.dims + assert "sc_atom" in stacked.dims + assert len(stacked.sc_atom) == 9 * len(coords_dataset.atom) + +@pytest.mark.parametrize("data_type", [list, dict]) +def test_parse_atoms_style_empty(data_type): + g = 
sisl.geom.graphene() + styles = parse_atoms_style(g, data_type()) + + assert isinstance(styles, xr.Dataset) + + assert "atom" in styles.coords + assert len(styles.coords["atom"]) == 2 + + for data_var in styles.data_vars: + assert len(styles[data_var].shape) == 0 + +@pytest.mark.parametrize("data_type", [list, dict]) +def test_parse_atoms_style_single_values(data_type): + g = sisl.geom.graphene() + + unparsed = {"color": "green", "size": 14} + if data_type == list: + unparsed = [unparsed] + + styles = parse_atoms_style(g, unparsed) + + assert isinstance(styles, xr.Dataset) + + assert "atom" in styles.coords + assert len(styles.coords["atom"]) == 2 + + for data_var in styles.data_vars: + assert len(styles[data_var].shape) == 0 + + if data_var == "color": + assert styles[data_var].values == "green" + elif data_var == "size": + assert styles[data_var].values == 14 + +def test_add_xyz_to_dataset(geometry): + + parsed_atoms_style = parse_atoms_style(geometry, {"color": "green", "size": 14}) + + atoms_dataset = add_xyz_to_dataset(parsed_atoms_style) + + assert isinstance(atoms_dataset, xr.Dataset) + + assert "xyz" in atoms_dataset.data_vars + assert atoms_dataset.xyz.shape == (geometry.na, 3) + assert np.allclose(atoms_dataset.xyz, geometry.xyz) + +@pytest.mark.parametrize("data_type", [list, dict]) +def test_sanitize_arrows_empty(data_type): + g = sisl.geom.graphene() + arrows = sanitize_arrows(g, data_type(), atoms=None, ndim=3, axes="xyz" ) + + assert isinstance(arrows, list) + + assert len(arrows) == 0 + +def test_sanitize_arrows(): + + data = np.array([[0,0,0],[1,1,1]]) + + g = sisl.geom.graphene() + + unparsed = [{"data": data}] + arrows = sanitize_arrows(g, unparsed, atoms=None, ndim=3, axes="xyz" ) + + assert isinstance(arrows, list) + assert np.allclose(arrows[0]['data'], data) + + arrows_from_dict = sanitize_arrows(g, unparsed[0], atoms=None, ndim=3, axes="xyz" ) + assert isinstance(arrows_from_dict, list) + + for k, v in arrows[0].items(): + if not 
isinstance(v, np.ndarray): + assert arrows[0][k] == arrows_from_dict[0][k] + +def test_style_bonds(geometry): + + bonds = find_all_bonds(geometry, 1.5) + + # Test no styles + styled_bonds = style_bonds(bonds, {}) + + assert isinstance(styled_bonds, xr.Dataset) + assert "bonds" in styled_bonds.data_vars + for k in ("color", "width", "opacity"): + assert k in styled_bonds.data_vars, f"Missing {k}" + assert styled_bonds[k].shape == (), f"Wrong shape for {k}" + + # Test single values + styles = {"color": "green", "width": 14, "opacity": 0.2} + styled_bonds = style_bonds(bonds, styles) + assert isinstance(styled_bonds, xr.Dataset) + assert "bonds" in styled_bonds.data_vars + for k in ("color", "width", "opacity"): + assert k in styled_bonds.data_vars, f"Missing {k}" + assert styled_bonds[k].shape == (), f"Wrong shape for {k}" + assert styled_bonds[k].values == styles[k], f"Wrong value for {k}" + + # Test callable + def some_property(geometry, bonds): + return np.arange(len(bonds)) + + styles = {"color": some_property, "width": some_property, "opacity": some_property} + styled_bonds = style_bonds(bonds, styles) + assert isinstance(styled_bonds, xr.Dataset) + assert "bonds" in styled_bonds.data_vars + for k in ("color", "width", "opacity"): + assert k in styled_bonds.data_vars, f"Missing {k}" + assert styled_bonds[k].shape == (len(bonds.bonds),), f"Wrong shape for {k}" + assert np.all(styled_bonds[k].values == np.arange(len(bonds.bonds))), f"Wrong value for {k}" + + # Test scale + styles = {"color": some_property, "width": some_property, "opacity": some_property} + styled_bonds = style_bonds(bonds, styles, scale=2) + assert isinstance(styled_bonds, xr.Dataset) + assert "bonds" in styled_bonds.data_vars + for k in ("color", "width", "opacity"): + assert k in styled_bonds.data_vars, f"Missing {k}" + assert styled_bonds[k].shape == (len(bonds.bonds),), f"Wrong shape for {k}" + if k == "width": + assert np.all(styled_bonds[k].values == 2 * np.arange(len(bonds.bonds))), 
f"Wrong value for {k}" + else: + assert np.all(styled_bonds[k].values == np.arange(len(bonds.bonds))), f"Wrong value for {k}" + +def test_add_xyz_to_bonds_dataset(geometry): + + bonds = find_all_bonds(geometry, 1.5) + + xyz_bonds = add_xyz_to_bonds_dataset(bonds) + + assert isinstance(xyz_bonds, xr.Dataset) + assert "xyz" in xyz_bonds.data_vars + assert xyz_bonds.xyz.shape == (len(bonds.bonds), 2, 3) + assert np.allclose(xyz_bonds.xyz[:, 0], geometry.xyz[bonds.bonds[:, 0]]) + +def test_sanitize_bonds_selection(geometry): + + bonds = find_all_bonds(geometry, 1.5) + + # No selection + assert sanitize_bonds_selection(bonds) is None + + # No bonds + bonds_sel = sanitize_bonds_selection(bonds, show_bonds=False) + assert isinstance(bonds_sel, np.ndarray) + assert len(bonds_sel) == 0 + + # Assert not bound to atoms + assert sanitize_bonds_selection(bonds, atoms=[0], bind_bonds_to_ats=False) is None + + # Assert bind to atoms. We check that all selected bonds have the only + # requested atom. + bonds_sel = sanitize_bonds_selection(bonds, atoms=[0], bind_bonds_to_ats=True) + + assert isinstance(bonds_sel, np.ndarray) + assert (bonds.sel(bond_index=bonds_sel) == 0).any("bond_atom").all("bond_index") + +def test_bonds_to_lines(geometry): + + bonds = find_all_bonds(geometry, 1.5) + xyz_bonds = add_xyz_to_bonds_dataset(bonds) + + assert isinstance(xyz_bonds, xr.Dataset) + assert "bond_atom" in xyz_bonds.dims + assert len(xyz_bonds.bond_atom) == 2 + + # No interpolation. Nan is added between bonds. + bond_lines = bonds_to_lines(xyz_bonds) + assert isinstance(bond_lines, xr.Dataset) + assert "point_index" in bond_lines.dims + assert len(bond_lines.point_index) == len(xyz_bonds.bond_index) * 3 + + # Interpolation. 
+ bond_lines = bonds_to_lines(xyz_bonds, points_per_bond=10) + assert isinstance(bond_lines, xr.Dataset) + assert "point_index" in bond_lines.dims + assert len(bond_lines.point_index) == len(xyz_bonds.bond_index) * 11 + + + + + + + diff --git a/src/sisl/viz/processors/tests/test_grid.py b/src/sisl/viz/processors/tests/test_grid.py new file mode 100644 index 0000000000..3b1cad6e46 --- /dev/null +++ b/src/sisl/viz/processors/tests/test_grid.py @@ -0,0 +1,387 @@ +import numpy as np +import pytest +import xarray as xr + +from sisl import Geometry, Grid, Lattice +from sisl.viz.processors.grid import ( + apply_transforms, + get_ax_vals, + get_grid_axes, + get_grid_representation, + get_isos, + get_offset, + grid_geometry, + grid_to_dataarray, + interpolate_grid, + orthogonalize_grid, + orthogonalize_grid_if_needed, + reduce_grid, + should_transform_grid_cell_plotting, + sub_grid, + tile_grid, + transform_grid_cell, +) + + +@pytest.fixture(scope="module", params=["orthogonal", "skewed"]) +def skewed(request) -> bool: + return request.param == "skewed" + +real_part = np.arange(10*10*10).reshape(10,10,10) +imag_part = np.arange(10*10*10).reshape(10,10,10) + 1 + +@pytest.fixture(scope="module") +def grid(skewed) -> Grid: + + if skewed: + lattice = Lattice([[3, 0, 0], [1, -1, 0], [0, 0, 3]]) + else: + lattice = Lattice([[3, 0, 0], [0, 2, 0], [0, 0, 6]]) + + geometry = Geometry([[0, 0, 0]], lattice=lattice) + grid = Grid([10, 10, 10], geometry=geometry, dtype=np.complex128) + + grid.grid[:] = ( real_part + imag_part * 1j).reshape(10, 10, 10) + + return grid + +def test_get_grid_representation(grid): + + assert np.allclose(get_grid_representation(grid, "real").grid, real_part) + assert np.allclose(get_grid_representation(grid, "imag").grid, imag_part) + assert np.allclose(get_grid_representation(grid, "mod").grid, np.sqrt(real_part**2 + imag_part**2)) + assert np.allclose(get_grid_representation(grid, "phase").grid, np.arctan2(imag_part, real_part)) + assert 
np.allclose(get_grid_representation(grid, "rad_phase").grid, np.arctan2(imag_part, real_part)) + assert np.allclose(get_grid_representation(grid, "deg_phase").grid, np.arctan2(imag_part, real_part) * 180 / np.pi) + +def test_tile_grid(grid): + + # By default it is not tiled + tiled = tile_grid(grid) + assert isinstance(tiled, Grid) + assert tiled.shape == grid.shape + assert np.allclose(tiled.grid, grid.grid) + + # Now tile it + tiled = tile_grid(grid, (1, 2, 1)) + assert isinstance(tiled, Grid) + assert tiled.shape == (grid.shape[0], grid.shape[1] * 2, grid.shape[2]) + assert np.allclose(tiled.grid[:, :grid.shape[1]], grid.grid) + assert np.allclose(tiled.grid[:, grid.shape[1]:], grid.grid) + +def test_transform_grid_cell(grid, skewed): + + # Convert to a cartesian cell + new_grid = transform_grid_cell(grid, cell=np.eye(3), output_shape=(10, 10, 10)) + + assert new_grid.shape == (10, 10, 10) + assert new_grid.lattice.is_cartesian() + + if not skewed: + assert np.allclose(new_grid.lattice.cell, grid.lattice.cell) + assert np.allclose(new_grid.grid, grid.grid) + else: + assert not np.allclose(new_grid.grid, grid.grid) + + assert not np.allclose(new_grid.grid, 0) + + # Convert to a skewed cell + directions = np.array([[1, 2, 3], [-1, 2, -4], [2, -1, 1]]) + new_grid = transform_grid_cell(grid, cell=directions, output_shape=(5, 5, 5)) + + assert new_grid.shape == (5, 5, 5) + for i in range(3): + n = new_grid.lattice.cell[i] / directions[i] + assert np.allclose(n, n[0]) + +@pytest.mark.parametrize("interp", [1, 2]) +def test_orthogonalize_grid(grid, interp, skewed): + + ort_grid = orthogonalize_grid(grid, interp=(interp, interp, interp)) + + assert ort_grid.shape == (10, 10, 10) if interp == 1 else (20, 20, 20) + assert ort_grid.lattice.is_cartesian() + + if not skewed: + assert np.allclose(ort_grid.lattice.cell, grid.lattice.cell) + if interp == 1: + assert np.allclose(ort_grid.grid, grid.grid) + else: + if interp == 1: + assert not np.allclose(ort_grid.grid, 
grid.grid) + + assert not np.allclose(ort_grid.grid, 0) + +def test_should_transform_grid_cell_plotting(grid, skewed): + + assert should_transform_grid_cell_plotting(grid, axes=["x", "y"]) == skewed + assert should_transform_grid_cell_plotting(grid, axes=["z"]) == False + +@pytest.mark.parametrize("interp", [1, 2]) +def test_orthogonalize_grid_if_needed(grid, skewed, interp): + + # Orthogonalize the skewed cell, since it is xy skewed. + ort_grid = orthogonalize_grid_if_needed(grid, axes=["x", "y"], interp=(interp, interp, interp)) + + assert ort_grid.shape == (10, 10, 10) if interp == 1 else (20, 20, 20) + assert ort_grid.lattice.is_cartesian() + + if not skewed: + assert np.allclose(ort_grid.lattice.cell, grid.lattice.cell) + if interp == 1: + assert np.allclose(ort_grid.grid, grid.grid) + else: + if interp == 1: + assert not np.allclose(ort_grid.grid, grid.grid) + + assert not np.allclose(ort_grid.grid, 0) + + # Do not orthogonalize the skewed cell, since it is not z skewed. + ort_grid = orthogonalize_grid_if_needed(grid, axes=["z"], interp=(interp, interp, interp)) + + assert ort_grid.shape == (10, 10, 10) if interp == 1 else (20, 20, 20) + + if skewed: + assert not ort_grid.lattice.is_cartesian() + + assert np.allclose(ort_grid.lattice.cell, grid.lattice.cell) + assert np.allclose(ort_grid.grid, grid.grid) + +def test_apply_transforms(grid): + + # Apply a function + transf = apply_transforms(grid, transforms=[np.sqrt]) + assert np.allclose(transf.grid, np.sqrt(grid.grid)) + + # Apply a numpy function specifying a string + transf = apply_transforms(grid, transforms=["sqrt"]) + assert np.allclose(transf.grid, np.sqrt(grid.grid)) + + # Apply two consecutive functions + transf = apply_transforms(grid, transforms=[np.angle, "sqrt"]) + assert np.allclose(transf.grid, np.sqrt(np.angle(grid.grid))) + +@pytest.mark.parametrize("reduce_method", ["sum", "mean"]) +def test_reduce_grid(grid, reduce_method): + + reduce_func = { + "sum": np.sum, + "mean": np.mean + 
}[reduce_method] + + reduced = reduce_grid(grid, reduce_method, keep_axes=[0, 1]) + + assert reduced.shape == (10, 10, 1) + assert np.allclose(reduced.grid[:, :, 0], reduce_func(grid.grid, axis=2)) + +@pytest.mark.parametrize("direction", ["x", "y", "z"]) +def test_sub_grid(grid, skewed, direction): + + coord_ax = "xyz".index(direction) + kwargs = {f"{direction}_range": (0.5, 1.5)} + + if skewed and direction != "z": + with pytest.raises(ValueError): + sub = sub_grid(grid, **kwargs, cart_tol=1e-3) + else: + sub = sub_grid(grid, **kwargs, cart_tol=1e-3) + + # Check that the lattice has been reduced to contain the requested range, + # taking into account that the bounds of the range might not be exactly + # on the grid points. + assert 1 + sub.dcell[:, coord_ax].sum()*2 >= sub.lattice.cell[:, coord_ax].sum() >= 1 - sub.dcell[:, coord_ax].sum()*2 + +def test_interpolate_grid(grid): + + interp = interpolate_grid(grid, (20, 20, 20)) + + # Check that the shape has been augmented + assert np.all(interp.shape == np.array((20, 20, 20)) * grid.shape) + + # The integral over the grid should be the same (or very similar) + assert (grid.grid.sum() * 20**3 - interp.grid.sum()) < 1e-3 + +def test_grid_geometry(grid): + + assert grid_geometry(grid) is grid.geometry + + geom_copy = grid.geometry.copy() + + assert grid_geometry(grid, geom_copy) is geom_copy + +def test_get_grid_axes(grid, skewed): + + assert get_grid_axes(grid, ['x', 'y', 'z']) == [0, 1, 2] + # This function doesn't care about what the axes are in 3D + assert get_grid_axes(grid, ['y', '-x', 'z']) == [0, 1, 2] + + if skewed: + with pytest.raises(ValueError): + get_grid_axes(grid, ['x', 'y']) + else: + assert get_grid_axes(grid, ['x', 'y']) == [0, 1] + assert get_grid_axes(grid, ['y', 'x']) == [1, 0] + +def test_get_ax_vals(grid, skewed): + + r = get_ax_vals(grid, "x", nsc=(1, 1, 1)) + + assert isinstance(r, np.ndarray) + assert r.shape == (grid.shape[0], ) + + if not skewed: + assert r[0] == 0 + assert abs(r[-1] - 
(grid.lattice.cell[0, 0] - grid.dcell[0, 0])) < 1e-3 + + r = get_ax_vals(grid, "a", nsc=(2, 1, 1)) + + assert isinstance(r, np.ndarray) + assert r.shape == (grid.shape[0], ) + + assert r[0] == 0 + assert abs(r[-1] - 2) < 1e-3 + +def test_get_offset(grid): + + assert get_offset(grid, "x") == 0 + assert get_offset(grid, "b") == 0 + assert get_offset(grid, 2) == 0 + + off_grid = grid.copy() + off_grid.lattice.origin = [1, 2, 3] + + assert get_offset(off_grid, "x") == 1 + assert get_offset(off_grid, "b") == 0 + assert get_offset(off_grid, 2) == 0 + +def test_grid_to_dataarray(grid, skewed): + # Test 1D + av_grid = grid.average(0).average(1) + + arr = grid_to_dataarray(av_grid, ['z'], [2], nsc=(1,1,1)) + + assert isinstance(arr, xr.DataArray) + assert len(arr.coords) == 1 + assert "x" in arr.coords + assert arr.x.shape == (grid.shape[2], ) + + assert np.allclose(arr.values, av_grid.grid[0, 0, :]) + + if skewed: + return + + # Test 2D + av_grid = grid.average(0) + + arr = grid_to_dataarray(av_grid, ['y', 'z'], [1, 2], nsc=(1,1,1)) + + assert isinstance(arr, xr.DataArray) + assert len(arr.coords) == 2 + assert "x" in arr.coords + assert arr.x.shape == (grid.shape[1], ) + assert "y" in arr.coords + assert arr.y.shape == (grid.shape[2], ) + + assert np.allclose(arr.values, av_grid.grid[0, :, :]) + + # Test 2D with unordered axes + av_grid = grid.average(0) + + arr = grid_to_dataarray(av_grid, ['z', 'y'], [2, 1], nsc=(1,1,1)) + + assert isinstance(arr, xr.DataArray) + assert len(arr.coords) == 2 + assert "x" in arr.coords + assert arr.x.shape == (grid.shape[2], ) + assert "y" in arr.coords + assert arr.y.shape == (grid.shape[1], ) + + assert np.allclose(arr.values, av_grid.grid[0, :, :].T) + + # Test 3D + av_grid = grid + + arr = grid_to_dataarray(av_grid, ['x', 'y', 'z'], [0, 1, 2], nsc=(1,1,1)) + + assert isinstance(arr, xr.DataArray) + assert len(arr.coords) == 3 + assert "x" in arr.coords + assert arr.x.shape == (grid.shape[0], ) + assert "y" in arr.coords + assert 
arr.y.shape == (grid.shape[1], ) + assert "z" in arr.coords + assert arr.z.shape == (grid.shape[2], ) + + assert np.allclose(arr.values, av_grid.grid) + +def test_get_isos(grid, skewed): + + if skewed: + return + + # Test isocontours (2D) + arr = grid_to_dataarray(grid.average(2), ['x', 'y'], [0, 1, 2], nsc=(1,1,1)) + + assert get_isos(arr, []) == [] + + contours = get_isos(arr, [{'frac': 0.5}]) + + assert isinstance(contours, list) + assert len(contours) == 1 + assert isinstance(contours[0], dict) + assert "x" in contours[0] + assert isinstance(contours[0]["x"], list) + assert "y" in contours[0] + assert isinstance(contours[0]["y"], list) + assert "z" not in contours[0] + + # Test isosurfaces (3D) + arr = grid_to_dataarray(grid, ['x', 'y', 'z'], [0, 1, 2], nsc=(1,1,1)) + + surfs = get_isos(arr, []) + + assert isinstance(surfs, list) + assert len(surfs) == 2 + assert isinstance(surfs[0], dict) + + # Sanity checks on the first surface + assert "color" in surfs[0] + assert surfs[0]["color"] is None + assert "opacity" in surfs[0] + assert surfs[0]["opacity"] is None + assert "name" in surfs[0] + assert isinstance(surfs[0]["name"], str) + assert "vertices" in surfs[0] + assert isinstance(surfs[0]["vertices"], np.ndarray) + assert surfs[0]["vertices"].dtype == np.float64 + assert surfs[0]["vertices"].shape[1] == 3 + assert "faces" in surfs[0] + assert isinstance(surfs[0]["faces"], np.ndarray) + assert surfs[0]["faces"].dtype == np.int32 + assert surfs[0]["faces"].shape[1] == 3 + + surfs = get_isos(arr, [{'val': 3, "color": "red", "opacity": 0.5, "name": "test"}]) + + assert isinstance(surfs, list) + assert len(surfs) == 1 + assert isinstance(surfs[0], dict) + assert "color" in surfs[0] + assert surfs[0]["color"] == "red" + assert "opacity" in surfs[0] + assert surfs[0]["opacity"] == 0.5 + assert "name" in surfs[0] + assert surfs[0]["name"] == "test" + assert "vertices" in surfs[0] + assert isinstance(surfs[0]["vertices"], np.ndarray) + assert surfs[0]["vertices"].dtype 
== np.float64 + assert surfs[0]["vertices"].shape[1] == 3 + assert "faces" in surfs[0] + assert isinstance(surfs[0]["faces"], np.ndarray) + assert surfs[0]["faces"].dtype == np.int32 + assert surfs[0]["faces"].shape[1] == 3 + + + + + + diff --git a/src/sisl/viz/processors/tests/test_groupreduce.py b/src/sisl/viz/processors/tests/test_groupreduce.py new file mode 100644 index 0000000000..5236067688 --- /dev/null +++ b/src/sisl/viz/processors/tests/test_groupreduce.py @@ -0,0 +1,317 @@ +import numpy as np +import pytest +import xarray as xr + +from sisl.viz.processors.xarray import group_reduce + + +@pytest.fixture(scope="module") +def dataarray(): + return xr.DataArray([ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10, 11, 12], + ], coords=[("x", [0,1,2,3]), ("y", [0,1,2])], name="vals" + ) + +@pytest.fixture(scope="module") +def dataset(): + arr = xr.DataArray([ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10, 11, 12], + ], coords=[("x", [0,1,2,3]), ("y", [0,1,2])], name="vals" + ) + + arr2 = arr * 2 + return xr.Dataset({"vals": arr, "double": arr2}) + +def test_dataarray(dataarray): + + new = group_reduce(dataarray, [{"selector": [0,1]}, {"selector": [2, 3]}], reduce_dim="x", reduce_func=np.sum, groups_dim="selection") + + assert isinstance(new, xr.DataArray) + assert "x" not in new.dims + assert "selection" in new.dims + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == [0,1] + assert new.sel(selection=0).sum() == 1 + 2 + 3 + 4 + 5 + 6 + assert new.sel(selection=1).sum() == 7 + 8 + 9 + 10 + 11 + 12 + +def test_dataarray_multidim(dataarray): + + new = group_reduce(dataarray, [{"selector": ([0,1], [0,1])}, {"selector": ([2, 3], [0,1])}], reduce_dim=("x", "y"), reduce_func=np.sum, groups_dim="selection") + + assert isinstance(new, xr.DataArray) + assert "x" not in new.dims + assert "y" not in new.dims + assert "selection" in new.dims + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == [0,1] + assert 
new.sel(selection=0).sum() == 1 + 2 + 4 + 5 + assert new.sel(selection=1).sum() == 7 + 8 + 10 + 11 + +def test_dataarray_multidim_multireduce(dataarray): + + new = group_reduce(dataarray, [{"selector": ([0,1], [0,1])}, {"selector": ([2, 3], [0,1])}], reduce_dim=("x", "y"), + reduce_func=(np.sum, np.mean), groups_dim="selection") + + assert isinstance(new, xr.DataArray) + assert "x" not in new.dims + assert "y" not in new.dims + assert "selection" in new.dims + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == [0,1] + assert new.sel(selection=0).sum() == (1 + 4) / 2 + (2 + 5) / 2 + assert new.sel(selection=1).sum() == (7 + 10) / 2 + (8 + 11) / 2 + +def test_dataarray_sangroup(dataarray): + # We use sanitize group to simply set all selectors to [0,1] + new = group_reduce(dataarray, [{"selector": [0,1]}, {"selector": [2, 3]}], + reduce_dim="x", reduce_func=np.sum, groups_dim="selection", + sanitize_group=lambda group: {**group, "selector": [0,1]} + ) + + assert isinstance(new, xr.DataArray) + assert "x" not in new.dims + assert "selection" in new.dims + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == [0,1] + assert new.sel(selection=0).sum() == 1 + 2 + 3 + 4 + 5 + 6 + assert new.sel(selection=1).sum() == 1 + 2 + 3 + 4 + 5 + 6 + +def test_dataarray_names(dataarray): + + new = group_reduce(dataarray, [{"selector": [0,1], "name": "first"}, {"selector": [2, 3], "name": "second"}], + reduce_dim="x", reduce_func=np.sum, groups_dim="selection") + + assert isinstance(new, xr.DataArray) + assert "x" not in new.dims + assert "selection" in new.dims + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == ["first", "second"] + assert new.sel(selection="first").sum() == 1 + 2 + 3 + 4 + 5 + 6 + assert new.sel(selection="second").sum() == 7 + 8 + 9 + 10 + 11 + 12 + +def test_dataarray_groupvars(dataarray): + + new = group_reduce(dataarray, + [ + {"selector": [0,1], "name": "first", 
"color": "red", "size": 3}, + {"selector": [2, 3], "name": "second", "color": "blue", "size": 4} + ], + reduce_dim="x", reduce_func=np.sum, groups_dim="selection", group_vars=["color", "size"] + ) + + assert isinstance(new, xr.Dataset) + assert "x" not in new.dims + assert "selection" in new.dims + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == ["first", "second"] + + assert "vals" in new + assert new.vals.sel(selection="first").sum() == 1 + 2 + 3 + 4 + 5 + 6 + assert new.vals.sel(selection="second").sum() == 7 + 8 + 9 + 10 + 11 + 12 + + for k, vals in {"color": ["red", "blue"], "size": [3, 4]}.items(): + assert k in new + + k_data = getattr(new, k) + assert "selection" in k_data.dims + assert list(k_data.coords["selection"]) == ["first", "second"] + assert list(k_data) == vals + +@pytest.mark.parametrize("drop", [True, False]) +def test_dataarray_empty_selector(dataarray, drop): + + new = group_reduce(dataarray, [{"selector": []}, {"selector": [2, 3]}], reduce_dim="x", reduce_func=np.sum, groups_dim="selection", drop_empty=drop, fill_empty=0.) + + assert isinstance(new, xr.DataArray) + assert "x" not in new.dims + assert "selection" in new.dims + if drop: + assert len(new.coords["selection"]) == 1 + assert list(new.coords["selection"]) == [1] + else: + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == [0,1] + + if not drop: + assert new.sel(selection=0).sum() == 0. + assert new.sel(selection=1).sum() == 7 + 8 + 9 + 10 + 11 + 12 + +def test_dataarray_empty_selector_0d(): + """When reducing an array along its only dimension, you get a 0d array. + + This was creating an error when filling empty selections. This test + ensures that it doesn' happen again + """ + + new = group_reduce( + xr.DataArray([1,2,3], coords=[("x", [0,1,2])]), + [{"selector": []}, {"selector": [1, 2]}], reduce_dim="x", reduce_func=np.sum, + groups_dim="selection", drop_empty=False, fill_empty=0. 
+ ) + + assert isinstance(new, xr.DataArray) + assert "x" not in new.dims + assert "selection" in new.dims + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == [0,1] + + assert new.sel(selection=0).sum() == 0. + assert new.sel(selection=1).sum() == 2 + 3 + +def test_dataset(dataset): + + new = group_reduce(dataset, [{"selector": [0,1]}, {"selector": [2, 3]}], reduce_dim="x", reduce_func=np.sum, groups_dim="selection") + + assert isinstance(new, xr.Dataset) + assert "x" not in new.dims + assert "selection" in new.dims + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == [0, 1] + + assert "vals" in new + assert "double" in new + assert new.sel(selection=0).sum() == (1 + 2 + 3 + 4 + 5 + 6) * 3 + assert new.sel(selection=1).sum() == (7 + 8 + 9 + 10 + 11 + 12) * 3 + +def test_dataset_multidim(dataset): + + new = group_reduce(dataset, [{"selector": ([0,1], [0,1])}, {"selector": ([2, 3], [0,1])}], + reduce_dim=("x", "y"), reduce_func=np.sum, groups_dim="selection") + + assert isinstance(new, xr.Dataset) + assert "x" not in new.dims + assert "y" not in new.dims + assert "selection" in new.dims + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == [0, 1] + + assert "vals" in new + assert "double" in new + assert new.sel(selection=0).sum() == (1 + 2 + 4 + 5) * 3 + assert new.sel(selection=1).sum() == (7 + 8 + 10 + 11) * 3 + +def test_dataset_multidim_multireduce(dataset): + + new = group_reduce(dataset, [{"selector": ([0,1], [0,1])}, {"selector": ([2, 3], [0,1])}], + reduce_dim=("x", "y"), reduce_func=(np.sum, np.mean), groups_dim="selection") + + assert isinstance(new, xr.Dataset) + assert "x" not in new.dims + assert "y" not in new.dims + assert "selection" in new.dims + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == [0, 1] + + assert "vals" in new + assert "double" in new + assert new.sel(selection=0).sum() == ((1 + 4) / 2 + (2 + 5) / 2 ) * 3 
+ assert new.sel(selection=1).sum() == ((7 + 10) / 2 + (8 + 11) / 2) * 3 + +def test_dataarray_multidim_multireduce(dataarray): + + new = group_reduce(dataarray, [{"selector": ([0,1], [0,1])}, {"selector": ([2, 3], [0,1])}], reduce_dim=("x", "y"), + reduce_func=(np.sum, np.mean), groups_dim="selection") + + assert isinstance(new, xr.DataArray) + assert "x" not in new.dims + assert "y" not in new.dims + assert "selection" in new.dims + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == [0,1] + assert new.sel(selection=0).sum() == (1 + 4) / 2 + (2 + 5) / 2 + assert new.sel(selection=1).sum() == (7 + 10) / 2 + (8 + 11) / 2 + +def test_dataset_names(dataset): + + new = group_reduce(dataset, [{"selector": [0,1], "name": "first"}, {"selector": [2, 3], "name": "second"}], reduce_dim="x", reduce_func=np.sum, groups_dim="selection") + + assert isinstance(new, xr.Dataset) + assert "x" not in new.dims + assert "selection" in new.dims + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == ["first", "second"] + + assert "vals" in new + assert "double" in new + assert new.sel(selection="first").sum() == (1 + 2 + 3 + 4 + 5 + 6) * 3 + assert new.sel(selection="second").sum() == (7 + 8 + 9 + 10 + 11 + 12) * 3 + +def test_dataset_groupvars(dataset): + + new = group_reduce(dataset, + [ + {"selector": [0,1], "name": "first", "color": "red", "size": 3}, + {"selector": [2, 3], "name": "second", "color": "blue", "size": 4} + ], + reduce_dim="x", reduce_func=np.sum, groups_dim="selection", group_vars=["color", "size"] + ) + + assert isinstance(new, xr.Dataset) + assert "x" not in new.dims + assert "selection" in new.dims + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == ["first", "second"] + + assert "vals" in new + assert "double" in new + assert new.double.sel(selection="first").sum() == (1 + 2 + 3 + 4 + 5 + 6) * 2 + assert new.double.sel(selection="second").sum() == (7 + 8 + 9 + 10 + 11 + 
12) * 2 + + for k, vals in {"color": ["red", "blue"], "size": [3, 4]}.items(): + assert k in new + + k_data = getattr(new, k) + assert "selection" in k_data.dims + assert list(k_data.coords["selection"]) == ["first", "second"] + assert list(k_data) == vals + +@pytest.mark.parametrize("drop", [True, False]) +def test_dataset_empty_selector(dataset, drop): + + new = group_reduce(dataset, [{"selector": []}, {"selector": [2, 3]}], reduce_dim="x", reduce_func=np.sum, groups_dim="selection", drop_empty=drop, fill_empty=0.) + + assert isinstance(new, xr.Dataset) + assert "x" not in new.dims + assert "selection" in new.dims + if drop: + assert len(new.coords["selection"]) == 1 + assert list(new.coords["selection"]) == [1] + else: + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == [0,1] + + assert "vals" in new + assert "double" in new + if not drop: + assert new.sel(selection=0).sum() == 0. + assert new.sel(selection=1).sum() == (7 + 8 + 9 + 10 + 11 + 12) * 3 + +def test_dataaset_empty_selector_0d(): + """When reducing an array along its only dimension, you get a 0d array. + + This was creating an error when filling empty selections. This test + ensures that it doesn' happen again + """ + dataset = xr.Dataset({"vals": (["x"], [1,2,3])}) + + new = group_reduce( + dataset, + [{"selector": []}, {"selector": [1, 2]}], reduce_dim="x", reduce_func=np.sum, + groups_dim="selection", drop_empty=False, fill_empty=0. + ) + + assert isinstance(new, xr.Dataset) + assert "x" not in new.dims + assert "selection" in new.dims + assert len(new.coords["selection"]) == 2 + assert list(new.coords["selection"]) == [0,1] + + assert "vals" in new + assert new.sel(selection=0).sum() == 0. 
+ assert new.sel(selection=1).sum() == 2 + 3 \ No newline at end of file diff --git a/src/sisl/viz/processors/tests/test_logic.py b/src/sisl/viz/processors/tests/test_logic.py new file mode 100644 index 0000000000..00a6cb7361 --- /dev/null +++ b/src/sisl/viz/processors/tests/test_logic.py @@ -0,0 +1,33 @@ +import pytest + +from sisl.viz.processors.logic import matches, swap, switch + + +def test_swap(): + + assert swap(1, (1, 2)) == 2 + assert swap(2, (1, 2)) == 1 + + with pytest.raises(ValueError): + swap(3, (1, 2)) + +def test_matches(): + + assert matches(1, 1) == True + assert matches(1, 2) == False + + assert matches(1, 1, "a", "b") == "a" + assert matches(1, 2, "a", "b") == "b" + + assert matches(1, 1, "a") == "a" + assert matches(1, 2, "a") == False + + assert matches(1, 1, ret_false="b") == True + assert matches(1, 2, ret_false="b") == "b" + +def test_switch(): + + assert switch(True, "a", "b") == "a" + assert switch(False, "a", "b") == "b" + + \ No newline at end of file diff --git a/src/sisl/viz/processors/tests/test_orbital.py b/src/sisl/viz/processors/tests/test_orbital.py new file mode 100644 index 0000000000..a540ab7349 --- /dev/null +++ b/src/sisl/viz/processors/tests/test_orbital.py @@ -0,0 +1,224 @@ +import numpy as np +import pytest +import xarray as xr + +import sisl +from sisl import AtomicOrbital, Geometry +from sisl.messages import SislError +from sisl.viz.data import PDOSData +from sisl.viz.processors.orbital import ( + OrbitalQueriesManager, + atom_data_from_orbital_data, + reduce_orbital_data, +) + + +@pytest.fixture(scope="module") +def geometry(): + + orbs = [ + AtomicOrbital("2sZ1"), AtomicOrbital("2sZ2"), + AtomicOrbital("2pxZ1"), AtomicOrbital("2pyZ1"), AtomicOrbital("2pzZ1"), + AtomicOrbital("2pxZ2"), AtomicOrbital("2pyZ2"), AtomicOrbital("2pzZ2"), + ] + + atoms = [ + sisl.Atom(5, orbs), + sisl.Atom(7, orbs), + ] + return sisl.geom.graphene(atoms=atoms) + +@pytest.fixture(scope="module", params=["unpolarized", "polarized", 
"noncolinear", "spinorbit"]) +def spin(request): + return sisl.Spin(request.param) + +@pytest.fixture(scope="module") +def orb_manager(geometry, spin): + return OrbitalQueriesManager(geometry, spin=spin) + +def test_get_orbitals(orb_manager, geometry: Geometry): + + orbs = orb_manager.get_orbitals({"atoms": [0]}) + assert len(orbs) == geometry.atoms.atom[0].no + assert np.all(orbs == np.arange(geometry.atoms.atom[0].no)) + + orbs = orb_manager.get_orbitals({"orbitals": [0, 1]}) + assert len(orbs) == 2 + assert np.all(orbs == np.array([0,1])) + +def test_get_atoms(orb_manager, geometry: Geometry): + + ats = orb_manager.get_atoms({"atoms": [0]}) + assert len(ats) == 1 + assert ats[0] == 0 + + if geometry.orbitals[0] > 1: + at_orbitals = [0, 1] + else: + at_orbitals = [0] + + ats = orb_manager.get_atoms({"orbitals": at_orbitals}) + assert len(ats) == 1 + assert np.all(ats == np.array([0])) + + +def test_split(orb_manager, geometry: Geometry): + + # Check that it can split over species + queries = orb_manager.generate_queries(split="species") + + assert len(queries) == geometry.atoms.nspecie + + atom_tags = [atom.tag for atom in geometry.atoms.atom] + + for query in queries: + assert isinstance(query, dict), f"Query is not a dict: {query}" + + assert "species" in query, f"Query does not have species: {query}" + assert isinstance(query["species"], list) + assert len(query["species"]) == 1 + assert query["species"][0] in atom_tags + + # Check that it can split over atoms + queries = orb_manager.generate_queries(split="atoms") + + assert len(queries) == geometry.na + + for i_atom in range(geometry.na): + + query = queries[i_atom] + + assert isinstance(query, dict), f"Query is not a dict: {query}" + + assert "atoms" in query, f"Query does not have atoms: {query}" + assert isinstance(query["atoms"], list) + assert len(query["atoms"]) == 1 + assert query["atoms"][0] == i_atom + + # Check that it can split over n + queries = orb_manager.generate_queries(split="l") + + assert 
len(queries) != 0 + + for query in queries: + + assert isinstance(query, dict), f"Query is not a dict: {query}" + + assert "l" in query, f"Query does not have l: {query}" + assert isinstance(query["l"], list) + assert len(query["l"]) == 1 + assert isinstance(query["l"][0], int) + +def test_double_split(orb_manager): + # Check that it can split over two things at the same time + queries = orb_manager.generate_queries(split="l+m") + + assert len(queries) != 0 + + for query in queries: + assert isinstance(query, dict), f"Query is not a dict: {query}" + + assert "l" in query, f"Query does not have l: {query}" + assert isinstance(query["l"], list) + assert len(query["l"]) == 1 + assert isinstance(query["l"][0], int) + + assert "l" in query, f"Query does not have l: {query}" + assert isinstance(query["m"], list) + assert len(query["m"]) == 1 + assert isinstance(query["m"][0], int) + + assert abs(query["m"][0]) <= query["l"][0] + +def test_split_only(orb_manager, geometry): + + queries = orb_manager.generate_queries(split="species", only=[geometry.atoms.atom[0].tag]) + + assert len(queries) == 1 + assert queries[0]['species'] == [geometry.atoms.atom[0].tag] + +def test_split_exclude(orb_manager, geometry): + + queries = orb_manager.generate_queries(split="species", exclude=[geometry.atoms.atom[0].tag]) + + assert len(queries) == geometry.atoms.nspecie - 1 + assert geometry.atoms.atom[0].tag not in [query['species'][0] for query in queries] + +def test_constrained_split(orb_manager, geometry): + + queries = orb_manager.generate_queries(split="species", atoms=[0]) + + assert len(queries) == 1 + assert queries[0]['species'] == [geometry.atoms.atom[0].tag] + +def test_split_name(orb_manager, geometry): + + queries = orb_manager.generate_queries(split="species", name="Tag: $species") + + assert len(queries) == geometry.atoms.nspecie + + for query in queries: + assert "name" in query, f"Query does not have name: {query}" + assert query['name'] == f"Tag: {query['species'][0]}" + 
+def test_sanitize_query(orb_manager, geometry): + + san_query = orb_manager.sanitize_query({"atoms": [0]}) + + atom_orbitals = geometry.atoms.atom[0].orbitals + + assert len(san_query['orbitals']) == len(atom_orbitals) + assert np.all(san_query['orbitals'] == np.arange(len(atom_orbitals))) + +def test_reduce_orbital_data(geometry, spin): + + data = PDOSData.toy_example(geometry=geometry, spin=spin)._data + + reduced = reduce_orbital_data(data, [{"name": "all"}] ) + + assert isinstance(reduced, xr.DataArray) + + for dim in data.dims: + if dim == "orb": + assert dim not in reduced.dims + else: + assert dim in reduced.dims + assert len(data[dim]) == len(reduced[dim]) + + assert "group" in reduced.dims + assert len(reduced.group) == 1 + assert reduced.group[0] == "all" + assert np.allclose(reduced.sel(group="all").values, data.sum("orb").values) + + data_no_geometry = data.copy() + data_no_geometry.attrs.pop("geometry") + + with pytest.raises(SislError): + reduced = reduce_orbital_data(data_no_geometry, [{"name": "all"}] ) + +def test_atom_data_from_orbital_data(geometry: Geometry, spin): + + data = PDOSData.toy_example(geometry=geometry, spin=spin)._data + + atom_data = atom_data_from_orbital_data(data, geometry) + + assert isinstance(atom_data, xr.DataArray) + + for dim in data.dims: + if dim == "orb": + assert dim not in atom_data.dims + else: + assert dim in atom_data.dims + assert len(data[dim]) == len(atom_data[dim]) + + assert "atom" in atom_data.dims + assert len(atom_data.atom) == geometry.na + assert np.all(atom_data.atom == np.arange(geometry.na)) + + atom_values = [] + firsto = geometry.firsto + lasto = geometry.lasto + for i in range(geometry.na): + atom_values.append(data.sel(orb=slice(firsto[i], lasto[i])).sum("orb").values) + + assert np.allclose(atom_data.values, np.array(atom_values)) diff --git a/src/sisl/viz/processors/tests/test_sci_groupreduce.py b/src/sisl/viz/processors/tests/test_sci_groupreduce.py new file mode 100644 index 
0000000000..df2ad50afc --- /dev/null +++ b/src/sisl/viz/processors/tests/test_sci_groupreduce.py @@ -0,0 +1,97 @@ +import numpy as np +import pytest +import xarray as xr + +import sisl +from sisl.viz.processors.atom import reduce_atom_data +from sisl.viz.processors.orbital import reduce_orbital_data + + +@pytest.fixture(scope="module") +def atom_x(): + geom = sisl.geom.graphene().tile(20,1).tile(20,0) + + return xr.DataArray(geom.xyz[:, 0], coords=[("atom", range(geom.na))], attrs={"geometry": geom}) + +@pytest.fixture(scope="module") +def atom_xyz(): + geom = sisl.geom.graphene().tile(20,1).tile(20,0) + + return xr.Dataset({ + "x": xr.DataArray(geom.xyz[:, 0], coords=[("atom", range(geom.na))]), + "y": xr.DataArray(geom.xyz[:, 1], coords=[("atom", range(geom.na))]), + "z": xr.DataArray(geom.xyz[:, 2], coords=[("atom", range(geom.na))]), + }, attrs={"geometry": geom}) + +def test_reduce_atom_dataarray(atom_x): + + grouped = reduce_atom_data( + atom_x, + [{"atoms": [0,1], "name": "first"}, {"atoms": [5, 6], "name": "second"}], + reduce_func=np.sum, groups_dim="group" + ) + + assert isinstance(grouped, xr.DataArray) + assert float(grouped.sel(group="first")) == np.sum(atom_x.values[0:2].sum()) + assert float(grouped.sel(group="second")) == np.sum(atom_x.values[5:7].sum()) + +def test_reduce_atom_dataarray_cat(atom_x): + """We test that the atoms field is correctly sanitized using the geometry attached to the xarray object.""" + grouped = reduce_atom_data( + atom_x, + [{"atoms": {"x": (0, 10)}, "name": "first"}, {"atoms": {"x": (10, None)}, "name": "second"}], + reduce_func=np.max, groups_dim="group" + ) + + assert isinstance(grouped, xr.DataArray) + assert float(grouped.sel(group="first")) <= 10 + assert float(grouped.sel(group="second")) == atom_x.max() + +def test_reduce_atom_cat_nogeom(atom_x): + """We test that the atoms field is correctly sanitized using the geometry attached to the xarray object.""" + atom_x = atom_x.copy() + geometry = atom_x.attrs["geometry"] 
+ + # Remove the geometry + atom_x.attrs = {} + + # Without a geometry, it should fail to sanitize atoms specifications + with pytest.raises(Exception): + grouped = reduce_atom_data( + atom_x, + [{"atoms": {"x": (0, 10)}, "name": "first"}, {"atoms": {"x": (10, None)}, "name": "second"}], + reduce_func=np.max, groups_dim="group" + ) + + # If we explicitly pass the geometry it should again be able to sanitize the atoms + grouped = reduce_atom_data( + atom_x, + [{"atoms": {"x": (0, 10)}, "name": "first"}, {"atoms": {"x": (10, None)}, "name": "second"}], + geometry=geometry, reduce_func=np.max, groups_dim="group" + ) + + assert isinstance(grouped, xr.DataArray) + assert float(grouped.sel(group="first")) <= 10 + assert float(grouped.sel(group="second")) == atom_x.max() + +def test_reduce_atom_dataset_cat(atom_xyz): + """We test that the atoms field is correctly sanitized using the geometry attached to the xarray object.""" + grouped = reduce_atom_data( + atom_xyz, + [{"atoms": {"x": (0, 10), "y": (1, 3)}, "name": "first"}, {"atoms": {"x": (10, None)}, "name": "second"}], + reduce_func=np.max, groups_dim="group" + ) + + assert isinstance(grouped, xr.Dataset) + assert float(grouped.sel(group="first").x) <= 10 + assert float(grouped.sel(group="second").x) == atom_xyz.x.max() + assert float(grouped.sel(group="first").y) <= 3 + assert float(grouped.sel(group="second").y) == atom_xyz.y.max() + + + + + + + + diff --git a/src/sisl/viz/processors/tests/test_spin.py b/src/sisl/viz/processors/tests/test_spin.py new file mode 100644 index 0000000000..917274c601 --- /dev/null +++ b/src/sisl/viz/processors/tests/test_spin.py @@ -0,0 +1,26 @@ +from sisl.viz.processors.spin import get_spin_options + + +def test_get_spin_options(): + + # Unpolarized spin + assert len(get_spin_options("unpolarized")) == 0 + + # Polarized spin + options = get_spin_options("polarized") + assert len(options) == 4 + assert 0 in options + assert 1 in options + assert "total" in options + assert "z" in options 
+ + # Non colinear spin + options = get_spin_options("noncolinear") + assert len(options) == 4 + assert "total" in options + assert "x" in options + assert "y" in options + assert "z" in options + + options = get_spin_options("noncolinear", only_if_polarized=True) + assert len(options) == 0 \ No newline at end of file diff --git a/src/sisl/viz/processors/wavefunction.py b/src/sisl/viz/processors/wavefunction.py new file mode 100644 index 0000000000..407a2d0404 --- /dev/null +++ b/src/sisl/viz/processors/wavefunction.py @@ -0,0 +1,124 @@ +from typing import Optional + +import numpy as np + +import sisl +from sisl.geometry import Geometry +from sisl.grid import Grid +from sisl.physics.electron import EigenstateElectron, wavefunction +from sisl.physics.hamiltonian import Hamiltonian +from sisl.physics.spin import Spin +from sisl.viz.nodes.data_sources.file.siesta import FileDataSIESTA +from sisl.viz.nodes.node import Node + +from .grid import GridDataNode + + +@Node.from_func +def get_ith_eigenstate(eigenstate: EigenstateElectron, i: int): + """Gets the ith eigenstate. + + This is useful because an EigenstateElectron contains all the eigenstates. + Sometimes a post-processing tool calculates only a subset of eigenstates, + and this is what you have inside the EigenstateElectron. + therefore getting eigenstate[0] does not mean that + + Parameters + ---------- + eigenstate : EigenstateElectron + The object containing all eigenstates. + i : int + The index of the eigenstate to get. + + Returns + ---------- + EigenstateElectron + The ith eigenstate. + """ + + if "index" in eigenstate.info: + wf_i = np.nonzero(eigenstate.info["index"] == i)[0] + if len(wf_i) == 0: + raise ValueError(f"Wavefunction with index {i} is not present in the eigenstate. Available indices: {eigenstate.info['index']}.") + wf_i = wf_i[0] + else: + max_index = len(eigenstate) + if i > max_index: + raise ValueError(f"Wavefunction with index {i} is not present in the eigenstate. 
Available range: [0, {max_index}].") + wf_i = i + + return eigenstate[wf_i] + +class WavefunctionDataNode(GridDataNode): + ... + +@WavefunctionDataNode.register +def eigenstate_wf(eigenstate: EigenstateElectron, i: int, grid: Optional[Grid] = None, geometry: Optional[Geometry] = None, + k = [0,0,0], grid_prec: float = 0.2, spin: Optional[Spin] = None +): + if geometry is None: + if isinstance(eigenstate.parent, Geometry): + geometry = eigenstate.parent + else: + geometry = getattr(eigenstate.parent, "geometry", None) + if geometry is None: + raise ValueError('No geometry was provided and we need it the basis orbitals to build the wavefunctions from the coefficients!') + + if spin is None: + spin = getattr(eigenstate.parent, "spin", Spin()) + + if grid is None: + dtype = eigenstate.dtype + grid = Grid(grid_prec, geometry=geometry, dtype=dtype) + + # GridPlot's after_read basically sets the x_range, y_range and z_range options + # which need to know what the grid is, that's why we are calling it here + # super()._after_read() + + # Get the particular WF that we want from the eigenstate object + wf_state = get_ith_eigenstate(eigenstate, i) + + # Ensure we are dealing with the R gauge + wf_state.change_gauge('R') + + # Finally, insert the wavefunction values into the grid. 
+ wavefunction( + wf_state.state, grid, geometry=geometry, + k=k, spinor=0, spin=spin + ) + + return grid + + +@WavefunctionDataNode.register +def hamiltonian_wf(H: Hamiltonian, i: int, grid: Optional[Grid] = None, geometry: Optional[Geometry] = None, + k = [0,0,0], grid_prec: float = 0.2, spin: int = 0 +): + eigenstate = H.eigenstate(k=k, spin=spin) + + return eigenstate_wf(eigenstate, i, grid, geometry, k, grid_prec, spin) + +@WavefunctionDataNode.register +def wfsx_wf(fdf, wfsx_file, i: int, grid: Optional[Grid] = None, geometry: Optional[Geometry] = None, + k = [0,0,0], grid_prec: float = 0.2, spin: int = 0 +): + fdf = FileDataSIESTA(path=fdf) + geometry = fdf.read_geometry(output=True) + + # Get the WFSX file. If not provided, it is inferred from the fdf. + wfsx = FileDataSIESTA(fdf=fdf, path=wfsx_file, cls=sisl.io.wfsxSileSiesta) + + # Now that we have the file, read the spin size and create a fake Hamiltonian + sizes = wfsx.read_sizes() + H = sisl.Hamiltonian(geometry, dim=sizes.nspin) + + # Read the wfsx again, this time passing the Hamiltonian as the parent + wfsx = sisl.get_sile(wfsx.file, parent=H) + + # Try to find the eigenstate that we need + eigenstate = wfsx.read_eigenstate(k=k, spin=spin) + if eigenstate is None: + # We have not found it. 
+ raise ValueError(f"A state with k={k} was not found in file {wfsx.file}.") + + return eigenstate_wf(eigenstate, i, grid, geometry, k, grid_prec) \ No newline at end of file diff --git a/src/sisl/viz/processors/xarray.py b/src/sisl/viz/processors/xarray.py new file mode 100644 index 0000000000..8966af80f6 --- /dev/null +++ b/src/sisl/viz/processors/xarray.py @@ -0,0 +1,197 @@ +from collections import defaultdict +from functools import singledispatchmethod +from typing import Any, Callable, Optional, Sequence, Tuple, TypedDict, Union + +import numpy as np +import xarray as xr +from xarray import DataArray, Dataset + +from sisl import Geometry +from sisl.messages import SislError + + +class XarrayData: + @singledispatchmethod + def __init__(self, data: Union[DataArray, Dataset]): + if isinstance(data, self.__class__): + data = data._data + + self._data = data + + def __getattr__(self, key): + return getattr(self._data, key) + + def __dir__(self): + return dir(self._data) + +class Group(TypedDict, total=False): + name: str + selector: Any + reduce_func: Optional[Callable] + ... + +def group_reduce(data: Union[DataArray, Dataset, XarrayData], groups: Sequence[Group], + reduce_dim: Union[str, Tuple[str, ...]], reduce_func: Union[Callable, Tuple[Callable, ...]] = np.mean, groups_dim: str = "group", + sanitize_group: Callable = lambda x: x, group_vars: Optional[Sequence[str]] = None, + drop_empty: bool = False, fill_empty: Any = 0. +) -> Union[DataArray, Dataset]: + """Groups contributions of orbitals into a new dimension. + + Given an xarray object containing orbital information and the specification of groups of orbitals, this function + computes the total contribution for each group of orbitals. It therefore removes the orbitals dimension and + creates a new one to account for the groups. + + It can also reduce spin in the same go if requested. In that case, groups can also specify particular spin components. 
+ + Parameters + ---------- + data : DataArray or Dataset + The xarray object to reduce. + groups : Sequence[Group] + A sequence containing the specifications for each group of orbitals. See ``Group``. + reduce_func : Callable or tuple of Callable, optional + The function that will compute the reduction along the reduced dimension once the selection is done. + This could be for example ``numpy.mean`` or ``numpy.sum``. + Notice that this will only be used in case the group specification doesn't specify a particular function + in its "reduce_func" field, which will take preference. + If ``reduce_dim`` is a tuple, this can also be a tuple to indicate different reducing methods for each + dimension. + reduce_dim: str or tuple of str, optional + Name of the dimension that should be reduced. If a tuple is provided, multiple dimensions will be reduced. + groups_dim: str, optional + Name of the new dimension that will be created for the groups. + sanitize_group: Callable, optional + A function that will be used to sanitize the group specification before it is used. + group_vars: Sequence[str], optional + If set, this argument specifies extra variables that depend on the group and the user would like to + introduce in the new xarray object. These variables will be searched as fields for each group specification. + A data variable will be created for each group_var and they will be added to the final xarray object. + Note that this forces the returned object to be a Dataset, even if the input data is a DataArray. + drop_empty: bool, optional + If set to `True`, group specifications that do not result in any matches will not appear in the final + returned object. + fill_empty: Any, optional + If ``drop_empty`` is set to ``False``, this argument specifies the value to use for group specifications + that do not result in any matches. + + Returns + ---------- + DataArray or Dataset + The new xarray object with the grouped and reduced dataarray object. 
+ """ + if len(groups) == 0: + if isinstance(data, Dataset): + return data.drop_dims(reduce_dim) + else: + raise ValueError("Must specify at least one group.") + + input_is_dataarray = isinstance(data, DataArray) + + if not isinstance(reduce_dim, tuple): + reduce_dim = (reduce_dim,) + + group_vars_dict = defaultdict(list) + groups_vals = [] + names = [] + for i_group, group in enumerate(groups): + group = sanitize_group(group) + # Get the orbitals of the group + selector = group['selector'] + if not isinstance(selector, tuple): + selector = (selector,) + + # Select the data we are interested in + group_vals = data.sel(**{dim: sel for dim, sel in zip(reduce_dim, selector) if sel is not None}) + + empty = False + for dim in reduce_dim: + selected = getattr(group_vals, dim, []) + empty = len(selected) == 0 + if empty: + break + + if empty: + # Handle the case where the selection found no matches. + if drop_empty: + continue + else: + group_vals = data.isel({dim: 0 for dim in reduce_dim}, drop=True).copy(deep=True) + if input_is_dataarray: + group_vals[...] = fill_empty + else: + for da in group_vals.values(): + da[...] = fill_empty + + else: + # If it did find matches, reduce the data. + reduce_funcs = group.get("reduce_func", reduce_func) + if not isinstance(reduce_funcs, tuple): + reduce_funcs = tuple([reduce_funcs] * len(reduce_dim)) + for dim, func in zip(reduce_dim, reduce_funcs): + group_vals = group_vals.reduce(func, dim=dim) + + + # Assign the name to this group and add it to the list of groups. + name = group.get('name') or i_group + names.append(name) + if input_is_dataarray: + group_vals.name = name + groups_vals.append(group_vals) + + # Add the extra variables to the group. + if group_vars is not None: + for var in group_vars: + group_vars_dict[var].append(group.get(var)) + + # Concatenate all the groups into a single xarray object creating a new coordinate. 
+ new_obj = xr.concat(groups_vals, dim=groups_dim).assign_coords({groups_dim: names}) + if input_is_dataarray: + new_obj.name = data.name + # Set the attributes of the passed array to the new one. + new_obj.attrs = {**data.attrs, **new_obj.attrs} + + # If there were extra group variables, then create a Dataset with them + if group_vars is not None: + + if isinstance(new_obj, DataArray): + new_obj = new_obj.to_dataset() + + new_obj = new_obj.assign({ + k: DataArray(v, dims=[groups_dim], name=k) for k,v in group_vars_dict.items() + }) + + return new_obj + +def scale_variable(dataset: Dataset, var: str, scale: float = 1, default_value: Union[float, None] = None, allow_not_present: bool = False) -> Dataset: + new = dataset.copy() + + if var not in new: + if allow_not_present: + return new + else: + raise ValueError(f"Variable {var} not present in dataset.") + + try: + new[var] = new[var] * scale + except TypeError: + if default_value is not None: + new[var][new[var] == None] = default_value * scale + return new + +def select(dataset: Dataset, dim: str, selector: Any) -> Dataset: + if selector is not None: + dataset = dataset.sel(**{dim: selector}) + return dataset + +def filter_energy_range( + data: Union[DataArray, Dataset], Erange: Optional[Tuple[float, float]]=None, E0: float = 0 +) -> Union[DataArray, Dataset]: + # Shift the energies + E_data = data.assign_coords(E=data.E - E0) + # Select a given energy range + if Erange is not None: + # Get the energy range that has been asked for. + Emin, Emax = Erange + E_data = data.sel(E=slice(Emin, Emax)) + + return E_data diff --git a/src/sisl/viz/session.py b/src/sisl/viz/session.py deleted file mode 100644 index 426da01496..0000000000 --- a/src/sisl/viz/session.py +++ /dev/null @@ -1,814 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-import os -import uuid -from copy import copy, deepcopy -from pathlib import Path - -from sisl._environ import get_environ_variable -from sisl.messages import warn - -from ._shortcuts import ShortCutable -from .configurable import Configurable, vizplotly_settings -from .input_fields import ( - Array1DInput, - BoolInput, - FilePathInput, - RangeSliderInput, - TextInput, -) -from .plot import Plot -from .plotutils import ( - call_method_if_present, - find_files, - find_plotable_siles, - get_plot_classes, -) - -__all__ = ["Session"] - - -class Warehouse: - """ Class to store everything related to a session. - - A warehouse can be shared between multiple sessions. - THIS SHOULD ONLY CONTAIN PLOTS!!!! - (The rest: tabs, structures and plotables should be session-specific) - """ - - def __init__(self): - self._warehouse = { - "plots": {}, - "structs": {}, - "plotables": {}, - "tabs": [] - } - - def __getitem__(self, item): - """ Gets an item from the warehouse """ - return self._warehouse[item] - - def __setitem__(self, item, value): - """ Stores an item to the warehouse """ - self._warehouse[item] = value - - -class Session(Configurable, ShortCutable): - """ Represents a session of the graphical interface - - Plots are organized in different tabs and each tab has a layout - that defines how plots are displayed in the dashboard. - - Contains different methods that help managing the session and that are - directly called by the front end of the graphical interface. - Therefore: IF A METHOD NAME IS CHANGED, THE SAME CHANGE MUST BE DONE - IN THE GUI FILE "apis/PythonApi". - - Parameters - ----------- - root_dir: str, optional - - file_storage_dir: str, optional - Directory where files uploaded in the GUI will be stored - keep_uploaded: bool, optional - Whether uploaded files should be kept in disk or directly removed - after plotting them. 
- search_depth: array-like of shape (2,), optional - Determines the depth limits of the search for structures (from the - root directory). - showTooltips: bool, optional - Tooltips help you understand how something works or what something - will do.If you are already familiar with the interface, you can - turn this off. - listenForUpdates: bool, optional - Determines whether the session updates plots when files change - This is very useful to track progress. It is only meaningful in the - GUI. - updateInterval: int, optional - The time in ms between consecutive checks for updates. - plot_dims: array-like, optional - The initial width and height of a new plot. Width is in columns - (out of a total of 12). For height, you really should try what works - best for you - plot_preset: str, optional - Preset that is passed directly to each plot initialization - plotly_template: str, optional - Plotly template that should be used as the default for this session - """ - - _param_groups = ( - - { - "key": "gui", - "name": "Interface tweaks", - "icon": "aspect_ratio", - "description": "There's a spanish saying that goes something like: for each taste there's a color. Since we know this is true, you can tweak these parameters too make the interface feel as comfortable as possible. " - }, - - { - "key": "filesystem", - "name": "File system settings", - "icon": "folder", - "description": "Your computer is pretty big and most of it is not important for the analysis of simulations (e.g. the folder with your holidays pictures). Also, everyone likes to store things differently. Please indicate how exactly do you want the interface to look for simulations results in your filesystem. " - }, - - ) - - _parameters = ( - - TextInput( - key = "root_dir", name = "Root directory", - group = "filesystem", - default = os.getcwd(), - params = { - "placeholder": "Write the path here..." 
- } - ), - - FilePathInput( - key="file_storage_dir", name="File storage directory", - group="filesystem", - default= get_environ_variable("SISL_TMP"), - params={ - "placeholder": "Write the path here..." - }, - help="Directory where files uploaded in the GUI will be stored" - ), - - BoolInput( - key="keep_uploaded", name="Keep uploaded files", - group="filesystem", - default=False, - help="Whether uploaded files should be kept in disk or directly removed after plotting them." - ), - - RangeSliderInput( - key = "search_depth", name = "Search depth", - group = "filesystem", - default = [0, 3], - params = { - "min": 0, - "max": 15, - "allowCross": False, - "step": 1, - "marks": {i: str(i) for i in range(0, 16)}, - "updatemode": "drag", - "units": "eV", - }, - help = "Determines the depth limits of the search for structures (from the root directory)." - ), - - BoolInput( - key = "showTooltips", name = "Show Tooltips", - group = "gui", - default = True, - params = { - "offLabel": "No", - "onLabel": "Yes" - }, - help = "Tooltips help you understand how something works or what something will do.
If you are already familiar with the interface, you can turn this off." - ), - - BoolInput( - key = "listenForUpdates", name = "Listen for updates", - group = "gui", - default = True, - params = { - "offLabel": "No", - "onLabel": "Yes" - }, - help = "Determines whether the session updates plots when files change
This is very useful to track progress. It is only meaningful in the GUI." - ), - - Array1DInput( - key="plot_dims", name="Initial plot dimensions", - default=[4, 30], - group="gui", - help="""The initial width and height of a new plot.
Width is in columns (out of a total of 12). For height, you really should try what works best for you""" - ), - - TextInput(key="plot_preset", name="Plot presets", - default=None, - help="Preset that is passed directly to each plot initialization" - ), - - TextInput(key="plotly_template", name="Plotly template", - default=None, - help="Plotly template that should be used as the default for this session" - ) - - ) - - @vizplotly_settings('before', init=True) - def __init__(self, *args, **kwargs): - self.id = str(uuid.uuid4()) - - self.before_plot_update = None - self.on_plot_change = None - self.on_plot_change_error = None - - self.warehouse = Warehouse() - - # Initialize shortcut management - ShortCutable.__init__(self) - - call_method_if_present(self, "_after_init") - - #----------------------------------------- - # PLOT MANAGEMENT - #----------------------------------------- - - @property - def plots(self): - """ The plots that this session contains """ - return self.warehouse["plots"] - - def plot(self, plotID): - """ Method to get a plot that is already in the session's warehouse - - Arguments - ----------- - plotID: str - The ID of the desired plot - - Returns - --------- - plot: sisl.viz.plotly.Plot() - The instance of the desired plot - """ - plot = self.plots[plotID] - - if not hasattr(plot, "grid_dims"): - plot.grid_dims = self.get_setting("plot_dims") - - return plot - - @staticmethod - def get_plot_classes(): - """ This method provides all the plot subclasses, even the nested ones - - Returns - ------- - list - all the plot classes that the module is aware of. - """ - return get_plot_classes() - - def add_plot(self, plot, tabID = None, noTab = False): - """ Adds an already initialized plot object to the session - - Parameters - ----- - plot: Plot() - the plot object that we want to add to the session - tab: str, optional - the name of the tab where we want to add the plot or - the ID of the tab where we want to add the plot. 
- - If neither tab or tabID are provided, it will be appended to the first tab - noTab: boolean, optional - if set to true, prevents the plot from being added to a tab - """ - self.warehouse["plots"][plot.id] = plot - - if not noTab: - tabID = self._tab_id(tabID) if tabID is not None else self.tabs[0]["id"] - - self._add_plot_to_tab(plot.id, tabID) - - call_method_if_present(self, "_on_plot_added", plot, tabID) - - return self - - def new_plot(self, plotClass=None, tabID=None, structID=None, plotable_path=None, animation = False, **kwargs): - """ Get a new plot from the specified class - - Arguments - ----------- - plotClass: str, optional - The name of the desired class. - If not provided, the session will try to initialize from the `Plot` parent class. - This may be useful if the keyword argument filename is provided, for example, to let - `Plot` guess which type of plot to use. - tabID: str, optional - Tab where the plot should be stored - structID: str, optional - The ID of the structure for which we want the plot. - plotableID: str, optional - The ID of the plotable file that we want to plot. - animation: bool, optional - Whether the initialized plot should be an animation. 
- - If true, it uses the `Plot.animated` method to initialize the plot - **kwargs: - Passed directly to plot initialization - - Returns - ----------- - new_plot: sisl.viz.plotly.Plot() - The initialized new plot - """ - args = [] - - if plotClass is None: - ReqPlotClass = Plot - else: - for PlotClass in self.get_plot_classes(): - if PlotClass.__name__ == plotClass: - ReqPlotClass = PlotClass - break - else: - raise ValueError(f"Didn't find the desired plot class: {plotClass}") - - if plotable_path is not None: - args = (plotable_path,) - if structID: - kwargs = {**kwargs, "root_fdf": self.warehouse["structs"][structID]["path"]} - - if animation: - wdir = self.warehouse["structs"][structID]["path"].parent if structID else self.get_setting("root_dir") - new_plot = ReqPlotClass.animated(wdir = wdir) - else: - plot_preset = self.get_setting("plot_preset") - if plot_preset is not None: - kwargs["presets"] = [*[plot_preset], *kwargs.get("presets", [])] - plotly_template = self.get_setting("plotly_template") - if plotly_template is not None: - layout = kwargs.get("layout", {}) - template = layout.get("template", "") - kwargs["layout"] = {"template": f'{plotly_template}{"+" + template if template else ""}', **layout} - new_plot = ReqPlotClass(*args, **kwargs) - - self.add_plot(new_plot, tabID) - - return self.plot(new_plot.id) - - def update_plot(self, plotID, newSettings): - """ Method to update the settings of a plot that is in the session's warehouse - - Arguments - ----------- - plotID: str - The ID of the plot whose settings need to be updated - newSettings: dict - Dictionary with the key and new value of all the settings that need to be updated. 
- - Returns - --------- - plot: sisl.viz.plotly.Plot() - The instance of the updated plot - """ - return self.plot(plotID).update_settings(**newSettings) - - def undo_plot_settings(self, plotID): - """ Method undo the settings of a plot that is in the session's warehouse - - Arguments - ----------- - plotID: str - The ID of the plot whose settings need to be undone - - Returns - --------- - plot: sisl.viz.plotly.Plot() - The instance of the plot with the settings rolled back. - """ - return self.plot(plotID).undo_settings() - - def remove_plot_from_tab(self, plotID, tab): - """ Method to remove a plot only from a given tab. - - Parameters - ----------- - plotID: str - the ID of the plot that you want to remove. - tab: str - the ID or name of the tab that you want to remove the plot from. - """ - tab = self.tab(tab) - - tab["plots"] = [plot for plot in tab["plots"] if plot != plotID] - - def remove_plot(self, plotID): - """ Method to remove a plot from all tabs. - - Parameters - ---------- - plotID: str - the ID of the plot that you want to remove. - """ - plot = self.plot(plotID) - - self.warehouse["plots"] = {ID: plot for ID, plot in self.plots.items() if ID != plotID} - - self.remove_plot_from_all_tabs(plotID) - - call_method_if_present(self, "_on_plot_removed", plot) - - return self - - def merge_plots(self, plots, to="multiple", tab=None, remove=True, **kwargs): - """ Merges two or more plots present in the session using `Plot.merge`. - - Parameters - ----------- - plots: array-like of (str and/or Plot) - A list with the ids of the plots (or the actual plots) that you want to merge. - Note that THE PLOTS PASSED HERE ARE NOT NECESSARILY IN THE SESSION beforehand. - to: {"multiple", "subplots", "animation"}, optional - the merge method. Each option results in a different way of putting all the plots - together: - - "multiple": All plots are shown in the same canvas at the same time. Useful for direct - comparison. 
- - "subplots": The layout is divided in different subplots. - - "animation": Each plot is converted into the frame of an animation. - tab: str, optional - the name or id of the tab where you want the new plot to go. - If not provided it will go to the tab where the first plot belongs. - remove: boolean, optional - whether the plots used to do the merging should be removed from the session's layout. - Remember that you are always in time to split the merged plots into individual plots - again. - **kwargs: - go directly extra arguments that are directly passed to `MultiplePlot`, `Subplots` - or `Animation` initialization. (see `Plot.merge`) - """ - # Get the plots if ids where passed. Note that we can accept plots that are not in the warehouse yet - plots = [self.plot(plot) if isinstance(plot, str) else plot for plot in plots] - - merged = plots[0].merge(plots[1:], to=to, **kwargs) - - if tab is None: - for session_tab in self.tabs: - if plots[0].id in session_tab["plots"]: - tab = session_tab["id"] - break - - if remove: - for plot in plots: - self.remove_plot(plot.id) - - self.add_plot(merged, tabID=tab) - - return self - - def updates_available(self): - """ Checks if the session's plots have pending updates due to changes in files. - - Returns - --------- - list - the ids of the plots where an update is available. - """ - updates_avail = [plotID for plotID, plot in self.plots.items() if plot.updates_available()] - - return updates_avail - - def commit_updates(self): - """ Updates the plots that can be updated according to `updates_available`. - - Note that this method can be safely called since it has no effect when no updates are available. 
- """ - for plotID in self.updates_available(): - try: - self.plots[plotID].read_data(update_fig=True) - except Exception as e: - warn(f"Could not update plot {plotID}.\nError: {e}") - - return self - - def listen(self, forever=False): - """ Listens for updates in the followed files (see the `updates_available` method) - - Parameters - --------- - forever: boolean, optional - whether to keep listening after the first plot updates. - """ - from threading import Event - - exit_event = Event() - - while exit_event.is_set(): - - exit_event.wait(1) - - updates_avail = self.updates_available() - - if len(updates_avail) != 0: - - for plotID in updates_avail: - self.plots[plotID].read_data(update_fig=True) - - if not forever: - exit_event.set() - - def figures_only(self): - """ Removes all plot data from this session's plots except the actual figure. - - This is very useful to save just for display, since it can decrease the size of the session - DRAMATICALLY. - """ - for plotID, plot in self.plots.items(): - - plot = Plot.from_plotly(plot.figure) - plot.id = plotID - - self.warehouse["plots"][plotID] = plot - - def _run_plot_method(self, plotID, method_name, *args, **kwargs): - """ Generic private method to run methods on plots that belong to this session. - - Any public method that runs plot methods should use this private method under the hood. - - In this way, the session will be able to consistently respond to plot updates. E.g. - """ - plot = self.plot(plotID) - - method = getattr(plot.autosync, method_name) - - return method(*args, **kwargs) - - #----------------------------------------- - # TABS MANAGEMENT - #----------------------------------------- - @property - def tabs(self): - """ The tabs that this session contains """ - return self.warehouse["tabs"] - - def tab(self, tab): - """ Get a tab by its name or ID. 
- - If it does not exist, it will be created (this acts as a shortcut for add_tab in that case) - - Parameters - -------- - tab: str - The name or ID of the tab you want to get - """ - tab_str = tab - - tabID = self._tab_id(tab_str) - - for tab in self.tabs: - if tab["id"] == tabID: - return tab - else: - self.add_tab(tab_str) - return self.tab(tab_str) - - def add_tab(self, name="New tab", plots=[]): - """ Adds a new tab to the session - - Arguments - ---------- - name: optional, str ("New tab") - The name of the new tab - plots: optional, array-like - Array of ids (as strings) that identify the plots that you want to put inside your tab. - Keep in mind that the plots with these ids must be present in self.plots. - """ - new_tab = {"id": str(uuid.uuid4()), "name": name, "plots": deepcopy(plots), "layouts": {"lg": []}} - - self.tabs.append(new_tab) - - return self - - def update_tab(self, tabID, newParams={}, **kwargs): - """ Method to update the parameters of a given tab """ - tab = self.tab(tabID) - - for key, val in {**newParams, **kwargs}.items(): - tab[key] = val - - return self - - def remove_tab(self, tabID): - """ Removes a tab from the current session """ - tabID = self._tab_id(tabID) - - for iTab, tab in enumerate(self.warehouse["tabs"]): - if tab["id"] == tabID: - del self.warehouse["tabs"][iTab] - break - - return self - - def move_plot(self, plot, tab, keep=False): - """ Moves a plot to a tab - - Parameters - ---------- - plot: str or sisl.viz.plotly.Plot - the plot's ID or the plot's instance - tab: str - the tab's id or the tab's name. - keep: boolean, optional - if True the plot is also kept in the previous tab. 
- This doesn't waste any additional memory, - since the tabs only hold references of the plots they have, - each plot is stored only once - """ - plotID = plot - if isinstance(plot, Plot): - plotID = plot.id - - if not keep: - self.remove_plot_from_all_tabs(plotID) - - self._add_plot_to_tab(plotID, tab) - - return self - - def _add_plot_to_tab(self, plot, tab): - """ Adds a plot to the requested tab. - - If the plot is not part of the session already, it will be added. - - Parameters - ---------- - plot: str or sisl.viz.plotly.Plot - the plot's ID or the plot's instance - tab: str - the tab's id or the tab's name. - """ - if isinstance(plot, Plot): - plotID = plot.id - if plotID not in self.plots: - self.add_plot(plot, tab) - else: - plotID = plot - - tab = self.tab(tab) - - tab["plots"] = [*tab["plots"], plotID] - - return self - - def remove_plot_from_all_tabs(self, plotID): - """ Removes a given plot from all tabs where it is located. - - Parameters - ----------- - plotID: str - the id of the plot you want to remove. - """ - for tab in self.tabs: - self.remove_plot_from_tab(plotID, tab["id"]) - - return self - - def get_tab_plots(self, tab): - """ Returns all the plots of a given tab. - - Parameters - ------------ - tab: str - the id or name of the tab. - """ - tab = self.tab(tab) - - return [self.plot(plotID) for plotID in tab["plots"]] if tab else None - - def set_tab_plots(self, tab, plots): - """ Sets the plots list of a tab - - Parameters - -------- - tab: str - tab's id or name - plots: array-like of str or sisl.viz.plotly.Plot (or combination of the two) - plots ids or plot instances. - """ - tab = self.tab(tab) - - tab["plots"] = [] - - for plot in plots: - self.add_plot(plot, tab) - - def tab_id(self, tab_name): - """ Gets the id of a given tab. - - Parameters - ----------- - tab_name: str - the name of the tab - - Returns - --------- - str or None. - the ID of the tab. None if there's no such tab. 
- """ - for tab in self.tabs: - if tab["name"] == tab_name: - return tab["id"] - - def _tab_id(self, tab_id_or_name): - """ Gets the id of a given tab. - - Parameters - ----------- - tab_id_or_name: str - the id or name of the tab. - - Returns - --------- - str or None. - the ID of the tab. None if there's no such tab. - """ - try: - uuid.UUID(str(tab_id_or_name)) - return tab_id_or_name - except Exception: - return self.tab_id(tab_id_or_name) - - #----------------------------------------- - # STRUCTURES MANAGEMENT - #----------------------------------------- - - def get_structures(self, path=None, root_dir=".", search_depth=None): - """ Gets all the structures that are in the scope of this session - - Parameters - ----------- - path: str, optional - the path where to start looking for structures. - - If not provided, the session's "root_dir" will be used. - - Returns - ---------- - dict - keys are the structure ID and values are info about each structure. - """ - path = Path(path or root_dir) - - #Get the structures - self.warehouse["structs"] = { - str(uuid.uuid4()): {"name": path.name, "path": path} for path in find_files(root_dir, "*fdf", search_depth) - } - - #Avoid passing unnecessary info to the browser. - return {structID: {"id": structID, **{k: struct[k] for k in ["name", "path"]}} for structID, struct in self.warehouse["structs"].items()} - - def get_plotables(self, path=None, root_dir=".", search_depth=None): - """ Gets all the plotables that are in the scope of this session. - - Parameters - ----------- - path: str, optional - the path where to start looking for plotables. - - If not provided, the session's "root_dir" will be used. - - Returns - ---------- - dict - keys are the plotable ID and values are info about each structure. 
- """ - # Empty the plotables dictionary - plotables = {} - path = Path(path or root_dir) - - # Get all the files that correspond to registered plotable siles - files = find_plotable_siles(path, search_depth) - - for SileClass, filepaths in files.items(): - - avail_plots = list(SileClass.plot._dispatchs) - default_plot = SileClass.plot._default - - # Extend the plotables dict with the files that we find that belong to this sile - plotables = {**plotables, **{ - str(uuid.uuid4()): {"name": path.name, "path": path, "plots": avail_plots, "default_plot": default_plot} for path in filepaths - }} - - self.warehouse["plotables"] = plotables - - #Avoid passing unnecessary info to the browser. - return {id: {"id": id, **{k: plotable[k] for k in ["name", "path", "plots", "default_plot"]}, "chosenPlots": [plotable["default_plot"]]} - for id, plotable in self.warehouse["plotables"].items()} - - def save(self, path, figs_only=False): - """ Stores the session in disk. - - Parameters - ---------- - path: str - Path where the session should be saved. - figs_only: boolean, optional - Whether only figures should be saved, the rest of plot's data will be ignored. - """ - import dill - session = copy(self) - - if figs_only: - session.figures_only() - - with open(path, 'wb') as handle: - dill.dump(session, handle, protocol=dill.HIGHEST_PROTOCOL) - - return self diff --git a/src/sisl/viz/sessions/__init__.py b/src/sisl/viz/sessions/__init__.py deleted file mode 100644 index 90472e8612..0000000000 --- a/src/sisl/viz/sessions/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. 
-from .blank import BlankSession diff --git a/src/sisl/viz/sessions/blank.py b/src/sisl/viz/sessions/blank.py deleted file mode 100644 index 789e91193c..0000000000 --- a/src/sisl/viz/sessions/blank.py +++ /dev/null @@ -1,51 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import os - -from ..session import Session - - -class BlankSession(Session): - """ - The most basic session one could have, really. - - Parameters - ------------ - root_dir: str, optional - - file_storage_dir: str, optional - Directory where files uploaded in the GUI will be stored - keep_uploaded: bool, optional - Whether uploaded files should be kept in disk or directly removed - after plotting them. - searchDepth: array-like of shape (2,), optional - Determines the depth limits of the search for structures (from the - root directory). - showTooltips: bool, optional - Tooltips help you understand how something works or what something - will do.If you are already familiar with the interface, you can - turn this off. - listenForUpdates: bool, optional - Determines whether the session updates plots when files change - This is very useful to track progress. It is only meaningful in the - GUI. - updateInterval: int, optional - The time in ms between consecutive checks for updates. - plotDims: array-like, optional - The initial width and height of a new plot. Width is in columns - (out of a total of 12). For height, you really should try what works - best for you - plot_preset: str, optional - Preset that is passed directly to each plot initialization - plotly_template: str, optional - Plotly template that should be used as the default for this session - """ - - _sessionName = "Blank session" - - _description = "The most basic session one could have, really." 
- - def _after_init(self): - # Add a first tab so that the user can see something :) - self.add_tab("First tab") diff --git a/src/sisl/viz/sessions/tests/__init__.py b/src/sisl/viz/sessions/tests/__init__.py deleted file mode 100644 index 448bb8652d..0000000000 --- a/src/sisl/viz/sessions/tests/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. diff --git a/src/sisl/viz/sessions/tests/test_sessions.py b/src/sisl/viz/sessions/tests/test_sessions.py deleted file mode 100644 index 8920271e19..0000000000 --- a/src/sisl/viz/sessions/tests/test_sessions.py +++ /dev/null @@ -1,24 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -""" -These tests check that all session subclasses fulfill at least the most basic stuff -More tests should be run on each session, but these are the most basic ones to -ensure that at least they do not break basic session functionality. -""" -import pytest - -from sisl.viz import Session -from sisl.viz.sessions import * -from sisl.viz.tests.test_session import _TestSessionClass - -pytestmark = [pytest.mark.viz, pytest.mark.plotly] - - -@pytest.fixture(autouse=True, scope="class", params=Session.__subclasses__()) -def plot_class(request): - request.cls._cls = request.param - - -class TestSessionSubClass(_TestSessionClass): - pass diff --git a/src/sisl/viz/tests/__init__.py b/src/sisl/viz/tests/__init__.py deleted file mode 100644 index 448bb8652d..0000000000 --- a/src/sisl/viz/tests/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. 
If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. diff --git a/src/sisl/viz/tests/test_namedhistory.py b/src/sisl/viz/tests/test_namedhistory.py deleted file mode 100644 index db67e1316b..0000000000 --- a/src/sisl/viz/tests/test_namedhistory.py +++ /dev/null @@ -1,77 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -import numpy as np -import pytest - -from sisl.viz.configurable import NamedHistory - -pytestmark = [pytest.mark.viz, pytest.mark.plotly] - - -def test_named_history(): - - test_keys = ["hey", "nope"] - val_key, def_key = test_keys - - s = NamedHistory({val_key: 2}, defaults={def_key: 5}) - - #Check that all keys have been incorporated - assert np.all([key in s for key in test_keys]) - assert np.all([len(s._vals[key]) == 1 for key in test_keys]) - assert s.current[val_key] == 2 - - # Check that the history updates succesfully - s.update(**{val_key: 3}) - assert s.last_updated == [val_key] - assert s.diff_keys(1, 0) == [val_key] - assert s.last_update_for(val_key) == 1 - assert len(s._vals[val_key]) == 2 - assert s.current[val_key] == 3 - assert val_key in s.last_delta - assert s.last_delta[val_key]["before"] == 2 - - # Check that it can correctly undo settings - s.undo() - assert len(s._vals[val_key]) == 2 - assert s._vals[val_key][1] is None - assert s.current[val_key] == 2 - - # One last check with multiple updates - s.update(**{val_key: 5}) - s.update(**{val_key: 6}) - assert len(s._vals[val_key]) == 4 - - -def test_history_item_getting(): - - test_keys = ["hey", "nope"] - val_key, def_key = test_keys - - s = NamedHistory({val_key: 2}, defaults={def_key: 5}) - s.update(**{val_key: 6}) - - assert s[-1][val_key] == 6 - - assert len(s[[-1, -2]]) == 2 - assert isinstance(s[:], dict) - assert len(s[0:1][val_key]) == 1 - - assert 
s[val_key] == [2, 6] - assert isinstance(s[test_keys], dict) - assert len(s[test_keys][val_key]) == len(s) - - -def test_update_array(): - - # Here we are just checking that we can update with numpy arrays - # without an error. This is because comparing two numpy arrays - # raises an Exception, so we need to make sure this doesn't happen - - test_keys = ["hey", "nope"] - val_key, def_key = test_keys - - s = NamedHistory({val_key: 2}, defaults={def_key: 5}) - - s.update(**{val_key: np.array([1, 2, 3])}) - s.update(**{val_key: np.array([1, 2, 3, 4])}) diff --git a/src/sisl/viz/tests/test_plot.py b/src/sisl/viz/tests/test_plot.py deleted file mode 100644 index 0ee9a03671..0000000000 --- a/src/sisl/viz/tests/test_plot.py +++ /dev/null @@ -1,233 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -""" - -This file tests general Plot behavior. 
- -""" -import os -import warnings -from copy import deepcopy - -import numpy as np -import pytest - -import sisl -from sisl.messages import SislInfo, SislWarning -from sisl.viz._presets import PRESETS -from sisl.viz.plot import Animation, MultiplePlot, Plot, SubPlots -from sisl.viz.plots import * -from sisl.viz.plotutils import load - -try: - import dill - skip_dill = pytest.mark.skipif(False, reason="dill not available") -except ImportError: - skip_dill = pytest.mark.skipif(True, reason="dill not available") - -# ------------------------------------------------------------ -# Checks that will be available to be used on any plot class -# ------------------------------------------------------------ - -pytestmark = [pytest.mark.viz, pytest.mark.plotly] - - -class _TestPlotClass: - - _cls = Plot - - def _init_plot_without_warnings(self, *args, **kwargs): - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - return self._cls(*args, **kwargs) - - def test_documentation(self): - - doc = self._cls.__doc__ - - # Check that it has documentation - assert doc is not None, f'{self._cls.__name__} does not have documentation' - - # Check that all params are in the documentation - params = [param.key for param in self._cls._get_class_params()[0]] - missing_params = list(filter(lambda key: key not in doc, params)) - assert len(missing_params) == 0, f"The following parameters are missing in the documentation of {self._cls.__name__}: {missing_params}" - - missing_help = list(map(lambda p: p.key, - filter(lambda p: not getattr(p, "help", None), self._cls._parameters) - )) - assert len(missing_help) == 0, f"Parameters {missing_help} in {self._cls.__name__} are missing a help message. Don't be lazy!" 
- - def test_plot_settings(self): - - plot = self._init_plot_without_warnings() - # Check that all the parameters have been passed to the settings - assert np.all([param.key in plot.settings for param in self._cls._parameters]) - # Build some test settings - new_settings = {'root_fdf': 'Test'} - # Update settings and check they have been succesfully updated - old_settings = deepcopy(plot.settings) - plot.update_settings(**new_settings, run_updates=False) - assert np.all([plot.settings[key] == val for key, val in new_settings.items()]) - # Undo settings and check if they go back to the previous ones - plot.undo_settings(run_updates=False) - - assert np.all([plot.settings[key] == - val for key, val in old_settings.items()]) - - # Build a plot directly with test settings and check if it works - plot = self._init_plot_without_warnings(**new_settings) - assert np.all([plot.settings[key] == val for key, val in new_settings.items()]) - - def test_plot_shortcuts(self): - - plot = self._init_plot_without_warnings() - # Build a fake shortcut and test it. 
- def dumb_shortcut(a=2): - plot.a_value = a - # Add it without extra parameters - plot.add_shortcut("ctrl+a", "Dumb shortcut", dumb_shortcut) - # Call it without extra parameters - plot.call_shortcut("ctrl+a") - assert plot.a_value == 2 - # Call it with extra parameters - plot.call_shortcut("ctrl+a", a=5) - assert plot.a_value == 5 - # Add the shortcut directly with extra parameters - plot.add_shortcut("ctrl+alt+a", "Dumb shortcut 2", dumb_shortcut, a=8) - # And test that it works - plot.call_shortcut("ctrl+alt+a") - assert plot.a_value == 8 - - def test_presets(self): - - plot = self._init_plot_without_warnings(presets="dark") - - assert np.all([key not in plot.settings or plot.settings[key] == val for key, val in PRESETS["dark"].items()]) - - @skip_dill - def test_save_and_load(self, obj=None): - - file_name = "./__sislsaving_test" - - if obj is None: - obj = self._init_plot_without_warnings() - - obj.save(file_name) - - try: - plot = load(file_name) - except Exception as e: - os.remove(file_name) - raise e - - os.remove(file_name) - -# ------------------------------------------------------------ -# Actual tests on the Plot parent class -# ------------------------------------------------------------ - - -class TestPlot(_TestPlotClass): - - _cls = Plot - - -# ------------------------------------------------------------ -# Tests for the MultiplePlot class -# ------------------------------------------------------------ - -class TestMultiplePlot(_TestPlotClass): - - _cls = MultiplePlot - - def test_init_from_kw(self): - - kw = MultiplePlot._kw_from_cls(self._cls) - - geom = sisl.geom.graphene() - - multiple_plot = geom.plot(show_cell=["box", False, False], backend=None, axes=[0, 1], **{kw: "show_cell"}) - - assert isinstance(multiple_plot, self._cls), f"{self._cls} was not correctly initialized using the {kw} keyword argument" - assert len(multiple_plot.children) == 3, "Child plots were not properly generated" - - def test_object_sharing(self): - - kw = 
MultiplePlot._kw_from_cls(self._cls) - - geom = sisl.geom.graphene() - - multiple_plot = geom.plot(show_cell=["box", False, False], backend=None, axes=[0, 1], **{kw: "show_cell"}) - geoms_ids = [id(plot.geometry) for plot in multiple_plot] - assert len(set(geoms_ids)) == 1, f"{self._cls} is not properly sharing objects" - - multiple_plot = GeometryPlot(geometry=[sisl.geom.graphene(bond=bond) for bond in (1.2, 1.6)], backend=None, axes=[0, 1], **{kw: "geometry"}) - geoms_ids = [id(plot.geometry) for plot in multiple_plot] - assert len(set(geoms_ids)) > 1, f"{self._cls} is sharing objects that should not be shared" - - def test_update_settings(self): - - kw = MultiplePlot._kw_from_cls(self._cls) - - geom = sisl.geom.graphene() - show_cell = ["box", False, False] - - multiple_plot = geom.plot(show_cell=show_cell, backend=None, axes=[0, 1], **{kw: "show_cell"}) - assert len(multiple_plot.children) == 3 - - for i, show_cell_val in enumerate(show_cell): - assert multiple_plot[i]._for_backend["show_cell"] == show_cell_val - - multiple_plot.update_children_settings(show_cell="box", children_sel=[1]) - for i, show_cell_val in enumerate(show_cell): - if i == 1: - show_cell_val = "box" - assert multiple_plot[i]._for_backend["show_cell"] == show_cell_val - -# ------------------------------------------------------------ -# Tests for the SubPlots class -# ------------------------------------------------------------ - - -class TestSubPlots(TestMultiplePlot): - - _cls = SubPlots - - def test_subplots_arrangement(self): - - geom = sisl.geom.graphene() - - # We are going to try some things here and check that they don't fail - # as we have no way of checking the actual layout of the subplots - plot = GeometryPlot.subplots('show_bonds', [True, False], backend=None, - fixed={'geometry': geom, 'axes': [0, 1], "backend": None}, _debug=True) - - plot.update_settings(cols=2) - - plot.update_settings(rows=2) - - # This should issue a warning stating that one plot will be missing - with 
pytest.warns(SislWarning): - plot.update_settings(cols=1, rows=1) - - plot.update_settings(cols=None, rows=None, arrange='square') - -# ------------------------------------------------------------ -# Tests for the Animation class -# ------------------------------------------------------------ - - -class _TestAnimation(TestMultiplePlot): - - PlotClass = Animation - - -def test_calling_Plot(): - # Just check that it doesn't raise any error - with pytest.warns(SislInfo): - plot = Plot("nonexistent.LDOS") - - assert isinstance(plot, GridPlot) diff --git a/src/sisl/viz/tests/test_session.py b/src/sisl/viz/tests/test_session.py deleted file mode 100644 index fdd06e2970..0000000000 --- a/src/sisl/viz/tests/test_session.py +++ /dev/null @@ -1,64 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from copy import deepcopy - -import numpy as np -import pytest - -from sisl.viz import Session -from sisl.viz.plots import * -from sisl.viz.tests.test_plot import _TestPlotClass - -try: - import dill - skip_dill = pytest.mark.skipif(False, reason="dill not available") -except ImportError: - skip_dill = pytest.mark.skipif(True, reason="dill not available") - -# This file tests general session behavior - -# ------------------------------------------------------------ -# Checks that will be available to be used on any session class -# ------------------------------------------------------------ - -pytestmark = [pytest.mark.viz, pytest.mark.plotly] - - -class _TestSessionClass: - - _cls = Session - - def test_session_settings(self): - - session = self._cls() - # Check that all the parameters have been passed to the settings - assert np.all([param.key in session.settings for param in self._cls._parameters]) - # Build some test settings - new_settings = {'root_dir': 'Test', 'search_depth': [4, 6]} - # Update settings and check 
they have been succesfully updated - old_settings = deepcopy(session.settings) - session.update_settings(**new_settings, run_updates=False) - assert np.all([session.settings[key] == val for key, val in new_settings.items()]) - # Undo settings and check if they go back to the previous ones - session.undo_settings(run_updates=False) - assert np.all([session.settings[key] == - val for key, val in old_settings.items()]) - - # Build a session directly with test settings and check if it works - session = self._cls(**new_settings) - assert np.all([session.settings[key] == val for key, val in new_settings.items()]) - - @skip_dill - def test_save_and_load(self): - - _TestPlotClass.test_save_and_load(self, obj=self._cls()) - - -# ------------------------------------------------------------ -# Actual tests on the Session parent class -# ------------------------------------------------------------ - -class TestSession(_TestSessionClass): - - _cls = Session diff --git a/src/sisl/viz/types.py b/src/sisl/viz/types.py new file mode 100644 index 0000000000..ceaf3027ca --- /dev/null +++ b/src/sisl/viz/types.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal, NewType, Optional, Sequence, TypedDict, Union + +import numpy as np +import numpy.typing as npt + +import sisl +from sisl.geometry import AtomCategory, Geometry +from sisl.io.sile import BaseSile +from sisl.lattice import Lattice, LatticeChild +from sisl.typing import AtomsArgument + +PathLike = Union[str, Path, BaseSile] + +Color = NewType("Color", str) + +# NOTE: a duplicate ``GeometryLike = Union[sisl.Geometry, Any]`` alias was removed here; the refined definition (using PathLike) follows below. + +Axis = Union[Literal["x", "y", "z", "-x", "-y", "-z", "a", "b", "c", "-a", "-b", "-c"], Sequence[float]] +Axes = Sequence[Axis] + +GeometryLike = Union[Geometry, PathLike] + +@dataclass +class StyleSpec: + color: Optional[Color] = None + size: Optional[float] = None + opacity: Optional[float] = 1 + dash: Optional[Literal["solid", 
"dot", "dash", "longdash", "dashdot", "longdashdot"]] = None + +@dataclass +class AtomsStyleSpec(StyleSpec): + atoms: AtomsArgument = None + vertices: Optional[float] = 15 + +class AtomsStyleSpecDict(TypedDict): + atoms: AtomsArgument + color: Optional[Color] + size: Optional[float] + opacity: Optional[float] + vertices: Optional[float] + +@dataclass +class Query: + active: bool = True + name: str = "" + +Queries = Sequence[Query] + +SpeciesSpec = NewType("SpeciesSpec", Optional[Sequence[str]]) + +OrbitalsNames = NewType("OrbitalsNames", Optional[Sequence[str]]) +SpinIndex = NewType("SpinIndex", Optional[Sequence[Literal[0, 1]]]) + +@dataclass +class OrbitalQuery(Query): + atoms: AtomsArgument = None + species: SpeciesSpec = None + orbitals: OrbitalsNames = None + n: Optional[Sequence[int]] = None + l: Optional[Sequence[int]] = None + m: Optional[Sequence[int]] = None + spin: SpinIndex = None + scale: float = 1 + reduce: Literal["mean", "sum"] = "sum" + spin_reduce: Literal["mean", "sum"] = "sum" + +@dataclass +class OrbitalStyleQuery(StyleSpec, OrbitalQuery): + ... + +OrbitalQueries = Sequence[OrbitalQuery] +OrbitalStyleQueries = Sequence[OrbitalStyleQuery] + +CellLike = Union[npt.NDArray[np.float64], Lattice, LatticeChild] + +@dataclass +class ArrowSpec: + scale: float = 1. + color: Any = None + width: float = 1. + opacity: float = 1. + name: str = "arrow" + annotate: bool = False + arrowhead_scale: float = 0.2 + arrowhead_angle: float = 20 + +@dataclass +class AtomArrowSpec: + data: Any + atoms: AtomsArgument = None + scale: float = 1. + color: Any = None + width: float = 1. + opacity: float = 1. + name: str = "arrow" + annotate: bool = False + arrowhead_scale: float = 0.2 + arrowhead_angle: float = 20 +