# Source code for scanspec.plot

"""`plot_spec` to visualize a scan."""

from collections.abc import Generator, Iterable
from itertools import cycle
from typing import Any

import numpy as np
import numpy.typing as npt
from matplotlib import colors, patches
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from mpl_toolkits.mplot3d import Axes3D, proj3d  # type: ignore
from scipy import interpolate  # type: ignore

from .core import stack2dimension
from .specs import Ellipse, Polygon, Spec

# plot_spec is the only name exported by ``from scanspec.plot import *``
__all__ = ["plot_spec"]


def _plot_arrays(axes: Axes, arrays: list[npt.NDArray[np.float64]], **kwargs: Any):
    """Plot a single line through the given per-axis coordinate arrays.

    ``arrays`` holds one coordinate array per scan axis, in scan-axis order, so
    the last array is the one drawn horizontally:

    - three or more arrays: a 3D line using the last three arrays as x, y, z
    - two arrays: a 2D line with the last array on x and the first on y
    - one array: a line along x at y = 0

    Any ``kwargs`` are passed straight through to matplotlib's ``plot`` /
    ``plot3D``.
    """
    if len(arrays) > 2:
        # Take the last three arrays so the data matches the axis labels that
        # plot_spec sets from axes[-1], axes[-2] and axes[-3]. Indexing from the
        # front would mislabel plots of specs with more than three axes.
        axes.plot3D(arrays[-1], arrays[-2], arrays[-3], **kwargs)  # type: ignore
    elif len(arrays) == 2:
        axes.plot(arrays[1], arrays[0], **kwargs)  # type: ignore
    else:
        axes.plot(arrays[0], np.zeros(len(arrays[0])), **kwargs)  # type: ignore


# https://stackoverflow.com/a/11156353
class Arrow3D(patches.FancyArrowPatch):
    """A ``FancyArrowPatch`` whose two end points live in 3D data space.

    The 3D end points are kept in ``_verts3d``. ``do_3d_projection`` projects
    them into 2D with the owning axes' projection matrix (``self.axes.M``) and
    moves the arrow there.
    """

    def __init__(
        self,
        xs: npt.NDArray[np.float64],
        ys: npt.NDArray[np.float64],
        zs: npt.NDArray[np.float64],
        *args: Any,
        **kwargs: Any,
    ):
        # The real 2D positions are set on projection, so start at the origin
        super().__init__((0, 0), (0, 0), *args, **kwargs)  # type: ignore
        self._verts3d = xs, ys, zs

    # Added here because of https://github.com/matplotlib/matplotlib/issues/21688
    def do_3d_projection(self, renderer: Any = None):  # type: ignore
        """Project the 3D end points to 2D and return their smallest depth.

        The returned value is presumably used by mplot3d for depth ordering.
        """
        x3d, y3d, z3d = self._verts3d
        px, py, pz = proj3d.proj_transform(x3d, y3d, z3d, self.axes.M)  # type: ignore
        start_xy = (px[0], py[0])
        end_xy = (px[1], py[1])
        self.set_positions(start_xy, end_xy)  # type: ignore

        return np.min(pz)  # type: ignore

    @property
    def verts3d(
        self,
    ) -> tuple[
        npt.NDArray[np.float64],
        npt.NDArray[np.float64],
        npt.NDArray[np.float64],
    ]:
        """The ``(xs, ys, zs)`` arrays the arrow was constructed with."""
        return self._verts3d


def _plot_arrow(axes: Axes, arrays: list[npt.NDArray[np.float64]]):
    """Draw a light grey arrowhead on the final step of a path.

    ``arrays`` holds one coordinate array per axis, and only the last two
    points of each are used. With one or two arrays the arrow is a 2D
    annotation; a single array is drawn along x at y = 0. With exactly three
    arrays it is an ``Arrow3D``. Any other number of arrays draws nothing.
    """
    if len(arrays) == 1:
        # Pad 1D data with a zero vertical coordinate so it takes the 2D path
        arrays = [np.array([0, 0])] + arrays
    n_arrays = len(arrays)
    if n_arrays == 2:
        # The last array is the horizontal axis, so walk them back to front
        horizontal_first = list(reversed(arrays))
        tip = [coords[-1] for coords in horizontal_first]
        # The arrow starts a tenth of the final step back from its tip
        base = [
            coords[-1] - (coords[-1] - coords[-2]) * 0.1
            for coords in horizontal_first
        ]
        axes.annotate(  # type: ignore
            "",
            tuple(tip[:2]),
            tuple(base[:2]),
            arrowprops={"color": "lightgrey", "arrowstyle": "-|>"},
        )
    elif n_arrays == 3:
        last_steps = [coords[-2:] for coords in reversed(arrays)]
        arrow = Arrow3D(
            *last_steps[:3], mutation_scale=10, arrowstyle="-|>", color="lightgrey"
        )
        axes.add_artist(arrow)


def _plot_spline(
    axes: Axes,
    ranges: list[float],
    arrays: list[npt.NDArray[np.float64]],
    index_colours: dict[int, str],
) -> Iterable[list[npt.NDArray[np.float64]]]:
    """Plot a smooth interpolating spline through a path, segment by segment.

    This is a generator: each segment is drawn as it is iterated.

    :param axes: Matplotlib axes to draw on.
    :param ranges: Extent of each axis. Each array is divided by its range
        before fitting so the spline does not favour axes with larger values.
    :param arrays: One coordinate array per axis describing the path.
    :param index_colours: Maps the point index where each segment starts to
        the colour it is drawn in. A segment runs to the next start index, or
        to the last point for the final segment.
    :yields: For each segment, in start-index order, a list of per-axis arrays
        of 1001 samples along the spline, in the original (unscaled) units.
        Nothing is drawn or yielded if the path has zero total length.
    """
    # Scale the arrays so splines don't favour larger scaled axes
    scaled_arrays = [a / r for a, r in zip(arrays, ranges, strict=False)]
    # Define curves parametrically: t is the cumulative chord length measured
    # in the scaled space
    t = np.zeros(len(arrays[0]))
    t[1:] = np.sqrt(sum((arr[1:] - arr[:-1]) ** 2 for arr in scaled_arrays))
    t = np.cumsum(t)
    if t[-1] > 0:
        # Can't make a spline that starts and ends in the same place, so add a
        # small delta. The arrays are already divided by their range, so the
        # delta is expressed in those normalised units (1e-7 of the axis range
        # once unscaled). Multiplying it by the range here as well would make
        # the distortion grow with the size of the axis.
        for s in scaled_arrays:
            if s[0] == s[-1]:
                s += np.linspace(0, 1e-7, len(s))
        # Normalise to [0, 1], matching splprep's default parameterisation
        # (normalised cumulative chord length)
        t /= t[-1]
        tck, _ = interpolate.splprep(scaled_arrays, k=2, s=0)  # type: ignore
        starts = sorted(index_colours)
        stops = starts[1:] + [len(arrays[0]) - 1]
        for start, stop in zip(starts, stops, strict=False):
            start_value: float = t[start]
            stop_value: float = t[stop]
            tnew = np.linspace(start_value, stop_value, num=1001)
            spline: npt.NDArray[np.float64] = interpolate.splev(tnew, tck)  # type: ignore
            # Scale the splines back to the original scaling
            unscaled_splines = [a * r for a, r in zip(spline, ranges, strict=False)]
            _plot_arrays(axes, list(unscaled_splines), color=index_colours[start])  # type: ignore
            yield unscaled_splines  # type: ignore


def _get_boundaries(spec: Spec[Any]) -> Generator[patches.Patch, None, None]:
    """Yield unfilled patches outlining the regions found in ``spec``.

    An ``Ellipse`` yields one ellipse patch and a ``Polygon`` yields one polygon
    patch. Any other spec is searched recursively through the ``Spec`` values
    stored in its instance ``__dict__``.
    """
    if isinstance(spec, Ellipse):
        centre = spec.x_centre, spec.y_centre
        yield patches.Ellipse(centre, spec.x_diameter, spec.y_diameter, fill=False)
        return
    if isinstance(spec, Polygon):
        yield patches.Polygon(spec.vertices, fill=False)
        return
    # Not a region itself: look for sub-specs held as attributes
    for attr_name in spec.__dict__:
        child = getattr(spec, attr_name)
        if isinstance(child, Spec):
            yield from _get_boundaries(child)  # type: ignore


def plot_spec(spec: Spec[Any], title: str | None = None):
    """Plot a spec, drawing the path taken through the scan.

    Uses a different colour for each frame, grey for the turnarounds.

    .. example_spec::

        from scanspec.specs import Linspace, Ellipse

        spec = Linspace("z", 1, 3, 3) * Ellipse("x", 1, 1.8, 0.2, "y", 2, snake=True)

    A single axis is drawn horizontally and two axes are drawn in 2D. Three or
    more axes are drawn in 3D, labelled with the last three axes. Ellipse and
    Polygon regions are outlined on 1D/2D plots. ``title`` defaults to a
    summary of each dimension's axes and length. The figure is displayed with
    ``plt.show()``.

    Raises ``TypeError`` if matplotlib does not create an ``Axes3D`` for a 3D
    plot.
    """
    dims = spec.calculate()
    dim = stack2dimension(dims)
    axes = spec.axes()
    ndims = len(axes)

    # Setup axes
    if ndims > 2:
        plt.figure(figsize=(6, 6))  # type: ignore
        plt_axes: Axes = plt.axes(projection="3d")  # type: ignore
        plt_axes.grid(False)  # type: ignore
        if isinstance(plt_axes, Axes3D):
            plt_axes.set_zlabel(axes[-3])  # type: ignore
            plt_axes.set_ylabel(axes[-2])  # type: ignore
            plt_axes.view_init(elev=15)  # type: ignore
        else:
            raise TypeError(
                "Expected matplotlib to create an Axes3D object, "
                f"instead got: {plt_axes}"
            )
    elif ndims == 2:
        plt.figure(figsize=(6, 6))  # type: ignore
        plt_axes = plt.axes()  # type: ignore
        plt_axes.set_ylabel(axes[-2])  # type: ignore
    else:
        # 1D paths are drawn along x at y = 0, so the y axis carries no data
        plt.figure(figsize=(6, 2))  # type: ignore
        plt_axes = plt.axes()  # type: ignore
        plt_axes.yaxis.set_visible(False)
    plt_axes.set_xlabel(axes[-1])  # type: ignore

    # Title with dimension sizes
    title = title or ", ".join(
        f"Dim[{' '.join(d.axes())} len={len(d)}]" for d in dims
    )
    plt.title(title)  # type: ignore

    # Plot regions
    if ndims <= 2:
        for patch in _get_boundaries(spec):
            plt_axes.add_patch(patch)

    # Plot the splines
    tail: dict[str, npt.NDArray[np.float64] | None] = dict.fromkeys(axes)
    # One range per axis, in the same order as ``axes`` (and so the same order
    # as the per-axis arrays built below), floored so scaling never divides by 0
    ranges = [
        max(float(np.max(dim.midpoints[a]) - np.min(dim.midpoints[a])), 0.0001)
        for a in axes
    ]
    seg_col = cycle(colors.TABLEAU_COLORS)
    last_index = 0
    # The first element of gap is undefined (as there is no previous frame)
    # so discard it
    gap_indices = list(np.nonzero(dim.gap[1:])[0] + 1)
    for index in gap_indices + [len(dim)]:
        # Frames first_index..index-1 form one continuous (gap-free) segment
        first_index = last_index
        num_points = index - first_index
        arrays: list[npt.NDArray[np.float64]] = []
        turnaround: list[npt.NDArray[np.float64]] = []
        for a in axes:
            # Add the midpoints and the lower and upper bounds
            arr = np.empty(num_points * 2 + 1)
            arr[:-1:2] = dim.lower[a][first_index:index]
            arr[1::2] = dim.midpoints[a][first_index:index]
            arr[-1] = dim.upper[a][index - 1]
            arrays.append(arr)
            # Add the turnaround
            axis_tail = tail[a]
            if axis_tail is not None:
                # Already had a tail, add lead in points
                axis_tail[2:] = np.linspace(-0.01, 0, 2) * (arr[1] - arr[0]) + arr[0]
                turnaround.append(axis_tail)
            # Add tail off points
            axis_tail = np.empty(4)
            axis_tail[:2] = np.linspace(0, 0.01, 2) * (arr[-1] - arr[-2]) + arr[-1]
            tail[a] = axis_tail
        last_index = index

        arrow_arr: list[npt.NDArray[np.float64]] | None = None
        if turnaround:
            # If we didn't move then plot a straight line from start to stop
            if all(t[1] - t[0] == 0 for t in turnaround):
                for t in turnaround:
                    t[1] += (t[2] - t[1]) / 4
            if all(t[3] - t[2] == 0 for t in turnaround):
                for t in turnaround:
                    t[2] -= (t[2] - t[1]) / 4
            # Plot the turnaround. _plot_spline yields nothing for a zero-length
            # path (the scan resumes exactly where it stopped), so don't assume
            # a spline came back.
            turnaround_splines = list(
                _plot_spline(plt_axes, ranges, turnaround, {0: "lightgrey"})
            )
            if turnaround_splines:
                arrow_arr = turnaround_splines[0]

        # Plot the points
        index_colours = {2 * i: next(seg_col) for i in range(num_points)}
        splines = list(_plot_spline(plt_axes, ranges, arrays, index_colours))

        if arrow_arr is not None:
            # Plot the arrow on the turnaround
            _plot_arrow(plt_axes, arrow_arr)
        elif splines:
            # Plot the starting arrow in the direction of the first point
            arrow_arr = [
                np.array([2 * coords[0] - coords[1], coords[0]])
                for coords in splines[0]
            ]
            _plot_arrow(plt_axes, arrow_arr)
        else:
            # First point of this segment isn't moving, put a right caret marker
            _plot_arrays(
                plt_axes,
                [np.array([dim.lower[a][first_index]]) for a in axes],
                marker=5,
                color="lightgrey",
            )

    # Plot the capture points
    if len(dim) < 200:
        midpoint_arrays = [dim.midpoints[a] for a in axes]
        _plot_arrays(plt_axes, midpoint_arrays, linestyle="", marker=".", color="k")

    # Plot the end
    _plot_arrays(
        plt_axes,
        [np.array([dim.upper[a][-1]]) for a in axes],
        marker="x",
        color="lightgrey",
    )

    plt.show()  # type: ignore