# Source code for scanspec.plot

"""`plot_spec` to visualize a scan."""

from collections.abc import Iterable, Iterator
from itertools import cycle
from typing import Any

import numpy as np
import numpy.typing as npt
from matplotlib import colors, patches
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from mpl_toolkits.mplot3d import Axes3D, proj3d  # type: ignore
from scipy import interpolate  # type: ignore

from .core import stack2dimension
from .regions import Circle, Ellipse, Polygon, Rectangle, Region, find_regions
from .specs import Spec

__all__ = ["plot_spec"]


def _plot_arrays(axes: Axes, arrays: list[npt.NDArray[np.float64]], **kwargs: Any):
    """Draw ``arrays`` on ``axes``, picking the plot call from how many there are.

    The arrays are passed to matplotlib in reverse order, so the last array
    becomes the first positional (x) argument:

    - more than two arrays: ``plot3D(arrays[2], arrays[1], arrays[0])``
    - exactly two arrays: ``plot(arrays[1], arrays[0])``
    - otherwise: ``arrays[0]`` is plotted against an equally long run of zeros

    Any extra keyword arguments are forwarded unchanged to the plot call.
    """
    dimensions = len(arrays)
    if dimensions == 2:
        y_values, x_values = arrays
        axes.plot(x_values, y_values, **kwargs)  # type: ignore
        return
    if dimensions > 2:
        # Only the first three arrays are used
        z_values, y_values, x_values = arrays[0], arrays[1], arrays[2]
        axes.plot3D(x_values, y_values, z_values, **kwargs)  # type: ignore
        return
    # Single axis: draw it along a flat line at zero
    x_values = arrays[0]
    axes.plot(x_values, np.zeros(len(x_values)), **kwargs)  # type: ignore


# Adapted from https://stackoverflow.com/a/11156353
class Arrow3D(patches.FancyArrowPatch):
    """A `FancyArrowPatch` whose two end points are stored as 3D coordinates.

    The 3D end points are kept on the instance and projected to 2D in
    `do_3d_projection`, which then updates the patch positions.
    """

    def __init__(
        self,
        xs: npt.NDArray[np.float64],
        ys: npt.NDArray[np.float64],
        zs: npt.NDArray[np.float64],
        *args: Any,
        **kwargs: Any,
    ):
        # The 2D end points given here are placeholders; do_3d_projection
        # replaces them with the projected positions
        placeholder = (0, 0)
        super().__init__(placeholder, placeholder, *args, **kwargs)  # type: ignore
        self._verts3d = xs, ys, zs

    # Added here because of https://github.com/matplotlib/matplotlib/issues/21688
    def do_3d_projection(self, renderer: Any = None):  # type: ignore
        """Project the stored 3D end points and return the smallest depth.

        ``renderer`` is accepted but not used.
        """
        flat_x, flat_y, depth = proj3d.proj_transform(  # type: ignore
            *self._verts3d, self.axes.M
        )
        start = (flat_x[0], flat_y[0])
        end = (flat_x[1], flat_y[1])
        self.set_positions(start, end)  # type: ignore

        return np.min(depth)  # type: ignore

    @property
    def verts3d(
        self,
    ) -> tuple[
        npt.NDArray[np.float64],
        npt.NDArray[np.float64],
        npt.NDArray[np.float64],
    ]:
        """The ``(xs, ys, zs)`` end point arrays given at construction."""
        return self._verts3d


def _plot_arrow(axes: Axes, arrays: list[npt.NDArray[np.float64]]):
    """Draw a light grey arrow head at the end of the path held in ``arrays``.

    The arrow points along the direction given by the last two entries of each
    array. A single array is treated as 2D by pairing it with a constant-zero
    second coordinate. Two arrays are drawn with ``axes.annotate``; three
    arrays are drawn with an `Arrow3D` artist. Any other count draws nothing.
    """
    if len(arrays) == 1:
        # Promote 1D to 2D with a flat second coordinate
        arrays = [np.array([0, 0]), *arrays]

    dimensions = len(arrays)
    if dimensions == 2:
        # Arrays come in reverse plotting order: the last one is horizontal
        vertical, horizontal = arrays
        tip = (horizontal[-1], vertical[-1])
        # Start the arrow a tenth of the final step back from the tip
        base = (
            horizontal[-1] - (horizontal[-1] - horizontal[-2]) * 0.1,
            vertical[-1] - (vertical[-1] - vertical[-2]) * 0.1,
        )
        axes.annotate(  # type: ignore
            "",
            tip,
            base,
            arrowprops={"color": "lightgrey", "arrowstyle": "-|>"},
        )
    elif dimensions == 3:
        # Last two points of each axis, in plotting (reversed) order
        end_points = [coords[-2:] for coords in arrays[::-1]]
        arrow = Arrow3D(
            *end_points, mutation_scale=10, arrowstyle="-|>", color="lightgrey"
        )
        axes.add_artist(arrow)


def _plot_spline(
    axes: Axes,
    ranges: list[float],
    arrays: list[npt.NDArray[np.float64]],
    index_colours: dict[int, str],
) -> Iterable[list[npt.NDArray[np.float64]]]:
    scaled_arrays = [a / r for a, r in zip(arrays, ranges, strict=False)]
    # Define curves parametrically
    t = np.zeros(len(arrays[0]))
    t[1:] = np.sqrt(sum((arr[1:] - arr[:-1]) ** 2 for arr in scaled_arrays))
    t = np.cumsum(t)
    if t[-1] > 0:
        # Can't make a spline that starts and ends in the same place, so add a small
        # delta
        for s, r in zip(scaled_arrays, ranges, strict=False):
            if s[0] == s[-1]:
                s += np.linspace(0, r * 1e-7, len(s))
        # There are no duplicated points, plot a spline
        t /= t[-1]
        # Scale the arrays so splines don't favour larger scaled axes
        tck, _ = interpolate.splprep(scaled_arrays, k=2, s=0)  # type: ignore
        starts = sorted(index_colours)
        stops = starts[1:] + [len(arrays[0]) - 1]
        for start, stop in zip(starts, stops, strict=False):
            start_value: float = t[start]
            stop_value: float = t[stop]
            tnew = np.linspace(start_value, stop_value, num=1001)
            spline: npt.NDArray[np.float64] = interpolate.splev(tnew, tck)  # type: ignore
            # Scale the splines back to the original scaling
            unscaled_splines = [a * r for a, r in zip(spline, ranges, strict=False)]
            _plot_arrays(axes, list(unscaled_splines), color=index_colours[start])  # type: ignore
            yield unscaled_splines  # type: ignore


def plot_spec(spec: Spec[Any], title: str | None = None) -> None:
    """Plot a spec, drawing the path taken through the scan.

    Uses a different colour for each frame, grey for the turnarounds, and
    marks the midpoints with a filled circle if there are less than 200 of
    them. If the scan is 2D then 2D regions are shown in black.

    Args:
        spec: The spec to calculate and plot.
        title: Title for the figure. If not given (or empty), a summary of
            the axes and length of each calculated dimension is used.

    Raises:
        TypeError: If matplotlib does not create an ``Axes3D`` for a scan
            with more than two axes.

    .. example_spec::

        from scanspec.specs import Line
        from scanspec.regions import Circle

        cube = Line("z", 1, 3, 3) * Line("y", 1, 3, 10) * ~Line("x", 0, 2, 10)
        spec = cube & Circle("x", "y", 1, 2, 0.9)
    """
    dims = spec.calculate()
    dim = stack2dimension(dims)
    axes = spec.axes()
    ndims = len(axes)

    # Setup axes
    if ndims > 2:
        plt.figure(figsize=(6, 6))  # type: ignore
        plt_axes: Axes = plt.axes(projection="3d")  # type: ignore
        plt_axes.grid(False)  # type: ignore
        if isinstance(plt_axes, Axes3D):
            plt_axes.set_zlabel(axes[-3])  # type: ignore
            plt_axes.set_ylabel(axes[-2])  # type: ignore
            plt_axes.view_init(elev=15)  # type: ignore
        else:
            raise TypeError(
                "Expected matplotlib to create an Axes3D object, "
                f"instead got: {plt_axes}"
            )
    elif ndims == 2:
        plt.figure(figsize=(6, 6))  # type: ignore
        plt_axes = plt.axes()  # type: ignore
        plt_axes.set_ylabel(axes[-2])  # type: ignore
    else:
        # 1D scans are drawn along a flat line, so the y axis carries no data
        plt.figure(figsize=(6, 2))  # type: ignore
        plt_axes = plt.axes()  # type: ignore
        plt_axes.yaxis.set_visible(False)
    plt_axes.set_xlabel(axes[-1])  # type: ignore

    # Title with dimension sizes
    title = title or ", ".join(f"Dim[{' '.join(d.axes())} len={len(d)}]" for d in dims)
    plt.title(title)  # type: ignore

    # Plot any Regions
    if ndims <= 2:
        regions: Iterator[Region[Any]] = find_regions(spec)
        for region in regions:
            if isinstance(region, Rectangle):
                xy = (region.x_min, region.y_min)
                width = region.x_max - region.x_min
                height = region.y_max - region.y_min
                plt_axes.add_patch(
                    patches.Rectangle(xy, width, height, angle=region.angle, fill=False)
                )
            elif isinstance(region, Circle):
                xy = (region.x_middle, region.y_middle)
                plt_axes.add_patch(patches.Circle(xy, region.radius, fill=False))
            elif isinstance(region, Ellipse):
                xy = (region.x_middle, region.y_middle)
                width = region.x_radius * 2
                height = region.y_radius * 2
                angle = region.angle
                plt_axes.add_patch(
                    patches.Ellipse(xy, width, height, angle=angle, fill=False)
                )
            elif isinstance(region, Polygon):
                # *xy_verts* is a numpy array with shape Nx2.
                xy_verts = np.column_stack((region.x_verts, region.y_verts))
                plt_axes.add_patch(patches.Polygon(xy_verts, fill=False))

    # Plot the splines
    tail: dict[str, npt.NDArray[np.float64] | None] = dict.fromkeys(axes)
    # Floor each range so axes that never move don't cause division by zero
    ranges = [max(float(np.max(v) - np.min(v)), 0.0001) for v in dim.midpoints.values()]
    seg_col = cycle(colors.TABLEAU_COLORS)
    last_index = 0
    splines = None
    # The first element of gap is undefined (as there is no previous frame)
    # so discard it
    gap_indices = list(np.nonzero(dim.gap[1:])[0] + 1)
    # Each pass handles one run of frames between gaps
    for index in gap_indices + [len(dim)]:
        num_points = index - last_index
        arrays: list[npt.NDArray[np.float64]] = []
        turnaround: list[npt.NDArray[np.float64]] = []
        for a in axes:
            # Add the midpoints and the lower and upper bounds
            arr = np.empty(num_points * 2 + 1)
            arr[:-1:2] = dim.lower[a][last_index:index]
            arr[1::2] = dim.midpoints[a][last_index:index]
            arr[-1] = dim.upper[a][index - 1]
            arrays.append(arr)
            # Add the turnaround
            axis_tail = tail[a]
            if axis_tail is not None:
                # Already had a tail, add lead in points
                axis_tail[2:] = np.linspace(-0.01, 0, 2) * (arr[1] - arr[0]) + arr[0]
                turnaround.append(axis_tail)
            # Add tail off points
            axis_tail = np.empty(4)
            axis_tail[:2] = np.linspace(0, 0.01, 2) * (arr[-1] - arr[-2]) + arr[-1]
            tail[a] = axis_tail
        last_index = index

        arrow_arr = None
        if turnaround:
            # If we didn't move then plot a straight line from start to stop
            if all(t[1] - t[0] == 0 for t in turnaround):
                for t in turnaround:
                    t[1] += (t[2] - t[1]) / 4
            if all(t[3] - t[2] == 0 for t in turnaround):
                for t in turnaround:
                    t[2] -= (t[2] - t[1]) / 4
            # Plot the turnaround. _plot_spline yields nothing when the
            # turnaround has no length (stationary before and after the gap,
            # at the same position), so don't assume there is a first element
            turnaround_splines = list(
                _plot_spline(plt_axes, ranges, turnaround, {0: "lightgrey"})
            )
            if turnaround_splines:
                arrow_arr = turnaround_splines[0]

        # Plot the points
        index_colours = {2 * i: next(seg_col) for i in range(num_points)}
        splines = list(_plot_spline(plt_axes, ranges, arrays, index_colours))

        if arrow_arr:
            # Plot the arrow on the turnaround
            _plot_arrow(plt_axes, arrow_arr)
        elif splines:
            # Plot the starting arrow in the direction of the first point
            arrow_arr = [np.array([2 * a[0] - a[1], a[0]]) for a in splines[0]]
            _plot_arrow(plt_axes, arrow_arr)
        else:
            # Start of this run isn't moving, put a right caret marker there.
            # For the first run this is dim.lower[a][0] for each axis
            _plot_arrays(
                plt_axes,
                [np.array([arr[0]]) for arr in arrays],
                marker=5,
                color="lightgrey",
            )

    # Plot the capture points
    if len(dim) < 200:
        arrays = [dim.midpoints[a] for a in axes]
        _plot_arrays(plt_axes, arrays, linestyle="", marker=".", color="k")

    # Plot the end
    _plot_arrays(
        plt_axes,
        [np.array([dim.upper[a][-1]]) for a in axes],
        marker="x",
        color="lightgrey",
    )

    plt.show()  # type: ignore