# Source code for sisl.viz.plots.matrix

# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from __future__ import annotations

from typing import Optional, Union

import numpy as np
from scipy.sparse import spmatrix

import sisl

from ..figure import Figure, get_figure
from ..plot import Plot
from ..plotters.grid import draw_grid, draw_grid_arrows
from ..plotters.matrix import draw_matrix_separators, set_matrix_axes
from ..plotters.plot_actions import combined
from ..processors.matrix import (
    determine_color_midpoint,
    get_geometry_from_matrix,
    get_matrix_mode,
    matrix_as_array,
    sanitize_matrix_arrows,
)
from ..types import Colorscale


def atomic_matrix_plot(
    matrix: Union[
        np.ndarray, sisl.SparseCSR, sisl.SparseAtom, sisl.SparseOrbital, spmatrix
    ],
    dim: int = 0,
    isc: Optional[int] = None,
    fill_value: Optional[float] = None,
    geometry: Union[sisl.Geometry, None] = None,
    atom_lines: Union[bool, dict] = False,
    orbital_lines: Union[bool, dict] = False,
    sc_lines: Union[bool, dict] = False,
    color_pixels: bool = True,
    colorscale: Optional[Colorscale] = "RdBu",
    crange: Optional[tuple[float, float]] = None,
    cmid: Optional[float] = None,
    text: Optional[str] = None,
    textfont: Optional[dict] = {},
    set_labels: bool = False,
    constrain_axes: bool = True,
    arrows: list[dict] = [],
    backend: str = "plotly",
) -> Figure:
    """Plots a (possibly sparse) matrix where rows and columns are either orbitals or atoms.

    The matrix is first converted to a plain array (selecting ``dim`` and ``isc``
    and filling missing entries), then the pixel grid, the arrows, the separator
    lines and the axes settings are generated as independent sets of drawing
    actions, which are finally combined and rendered by the requested backend.

    Parameters
    ----------
    matrix:
        the matrix, either as a numpy array or as a sisl sparse matrix.
    dim:
        If the matrix has a third dimension (e.g. spin), which index to
        plot in that third dimension.
    isc:
        If the matrix contains data for an auxiliary supercell, the index of the
        cell to plot. If None, the whole matrix is plotted.
    fill_value:
        If the matrix is sparse, the value to use for the missing entries.
    geometry:
        Only needed if the matrix does not contain a geometry (e.g. it is a numpy array)
        and separator lines or labels are requested.
    atom_lines:
        If a boolean, whether to draw lines separating atom blocks, using default styles.
        If a dict, draws the lines with the specified plotly line styles.
    orbital_lines:
        If a boolean, whether to draw lines separating blocks of orbital sets, using default styles.
        If a dict, draws the lines with the specified plotly line styles.
    sc_lines:
        If a boolean, whether to draw lines separating the supercells, using default styles.
        If a dict, draws the lines with the specified plotly line styles.
    color_pixels:
        Whether to color the pixels of the matrix according to the colorscale.
    colorscale:
        The colorscale to use to color the pixels.
    crange:
        The minimum and maximum values of the colorscale.
    cmid:
        The midpoint of the colorscale. If ``crange`` is provided, this is ignored.

        If None and crange is also None, the midpoint
        is set to 0 if the data contains both positive and negative values.
    text:
        If provided, show text of pixel value with the specified format.
        E.g. text=".3f" shows the value with three decimal places.
    textfont:
        The font to use for the text.
        This is a dictionary that may contain the keys "family", "size", "color".
    set_labels:
        Whether to set the axes labels to the atom/orbital that each row and column corresponds to.
        For orbitals the labels will be of the form "Atom: (l, m)", where `Atom` is the index of
        the atom and l and m are the quantum numbers of the orbital.
    constrain_axes:
        Whether to set the ranges of the axes to exactly fit the matrix.
    arrows:
        List of arrow specifications (one dict per specification) to draw on top
        of the matrix. The list is normalized by ``sanitize_matrix_arrows`` and
        drawn by ``draw_grid_arrows``; the accepted keys are defined by those
        helpers.
    backend:
        The backend to use for plotting.

    Returns
    -------
    Figure
        The figure produced by ``get_figure`` for ``backend`` from the combined
        drawing actions.
    """
    # NOTE(review): ``textfont`` and ``arrows`` have mutable default values. They
    # are kept as-is so that the signature stays unchanged; this function only
    # forwards them (``arrows`` is rebound, not modified, below). Confirm that the
    # helpers receiving them do not mutate them in place.

    geometry = get_geometry_from_matrix(matrix, geometry)
    mode = get_matrix_mode(matrix)

    matrix_array = matrix_as_array(matrix, dim=dim, isc=isc, fill_value=fill_value)

    # The automatic midpoint must be decided from the values that are actually
    # drawn (after selecting ``dim``/``isc`` and applying ``fill_value``), not from
    # the raw input, which may hold other components or supercells whose signs
    # are irrelevant for this plot.
    color_midpoint = determine_color_midpoint(matrix_array, cmid=cmid, crange=crange)

    matrix_actions = draw_grid(
        matrix_array,
        crange=crange,
        cmid=color_midpoint,
        color_pixels_2d=color_pixels,
        colorscale=colorscale,
        coloraxis_name="matrix_vals",
        textformat=text,
        textfont=textfont,
    )

    arrows = sanitize_matrix_arrows(arrows)

    arrow_actions = draw_grid_arrows(matrix_array, arrows)

    # When no supercell index is selected the whole matrix is plotted, so the
    # separators are requested for the supercells as well.
    draw_supercells = isc is None

    axes_actions = set_matrix_axes(
        matrix_array,
        geometry,
        matrix_mode=mode,
        constrain_axes=constrain_axes,
        set_labels=set_labels,
    )

    sc_lines_actions = draw_matrix_separators(
        sc_lines,
        geometry,
        matrix_mode=mode,
        separator_mode="supercells",
        draw_supercells=draw_supercells,
        showlegend=False,
    )

    atom_lines_actions = draw_matrix_separators(
        atom_lines,
        geometry,
        matrix_mode=mode,
        separator_mode="atoms",
        draw_supercells=draw_supercells,
        showlegend=False,
    )

    orbital_lines_actions = draw_matrix_separators(
        orbital_lines,
        geometry,
        matrix_mode=mode,
        separator_mode="orbitals",
        draw_supercells=draw_supercells,
        showlegend=False,
    )

    # The actions are combined in this order: pixels, arrows, orbital/atom/
    # supercell separators, and finally the axes settings.
    all_actions = combined(
        matrix_actions,
        arrow_actions,
        orbital_lines_actions,
        atom_lines_actions,
        sc_lines_actions,
        axes_actions,
    )

    return get_figure(backend, all_actions)


# Plot class whose plotting function is ``atomic_matrix_plot``. The function is
# stored as a staticmethod so that it is not bound to the instance when
# accessed through the class.
class AtomicMatrixPlot(Plot):
    function = staticmethod(atomic_matrix_plot)