"""`plot_spec` to visualize a scan."""
from collections.abc import Generator, Iterable
from itertools import cycle
from typing import Any
import numpy as np
import numpy.typing as npt
from matplotlib import colors, patches
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from mpl_toolkits.mplot3d import Axes3D, proj3d # type: ignore
from scipy import interpolate # type: ignore
from .core import stack2dimension
from .specs import Ellipse, Polygon, Spec
# Only ``plot_spec`` is exported by ``from ... import *``; the other names in
# this module are plotting helpers.
__all__ = ["plot_spec"]
def _plot_arrays(axes: Axes, arrays: list[npt.NDArray[np.float64]], **kwargs: Any):
    """Plot one coordinate array per scan axis as a line on ``axes``.

    ``arrays`` are in the same order as ``spec.axes()`` (as built by
    ``plot_spec``), so the last array is plotted on x, the second-to-last on y
    and the third-to-last on z. This matches the axis labels that ``plot_spec``
    sets from ``axes[-1]``, ``axes[-2]`` and ``axes[-3]``. Any earlier (outer)
    arrays are not drawn. A single array is drawn along the line y=0.

    Args:
        axes: The axes to draw on. This must be a 3D axes when more than two
            arrays are given.
        arrays: One coordinate array per scan axis, all the same length.
        **kwargs: Passed straight through to ``plot``/``plot3D``.
    """
    if len(arrays) > 2:
        # Index from the end so that >3 dims still agree with the axis labels
        axes.plot3D(arrays[-1], arrays[-2], arrays[-3], **kwargs)  # type: ignore
    elif len(arrays) == 2:
        axes.plot(arrays[-1], arrays[-2], **kwargs)  # type: ignore
    else:
        axes.plot(arrays[0], np.zeros(len(arrays[0])), **kwargs)  # type: ignore
# https://stackoverflow.com/a/11156353
class Arrow3D(patches.FancyArrowPatch):
    """A ``FancyArrowPatch`` whose two end points are given in 3D data space.

    Each of ``xs``, ``ys`` and ``zs`` holds one axis' coordinate of the start
    (index 0) and end (index 1) of the arrow. The 2D positions handed to the
    base class are placeholders. ``do_3d_projection`` replaces them with the
    projected end points.
    """

    def __init__(
        self,
        xs: npt.NDArray[np.float64],
        ys: npt.NDArray[np.float64],
        zs: npt.NDArray[np.float64],
        *args: Any,
        **kwargs: Any,
    ):
        placeholder = (0, 0)
        super().__init__(placeholder, placeholder, *args, **kwargs)  # type: ignore
        # Remember the 3D end points so they can be re-projected on each draw
        self._verts3d = xs, ys, zs

    # Needed because of https://github.com/matplotlib/matplotlib/issues/21688
    def do_3d_projection(self, renderer: Any = None):  # type: ignore
        """Project the stored 3D end points onto the screen plane.

        Updates the arrow's 2D start/end positions and returns the smallest
        projected z. Presumably mplot3d uses this as the patch's depth for
        draw ordering; confirm against the matplotlib version in use.
        """
        proj_x, proj_y, proj_z = proj3d.proj_transform(  # type: ignore
            *self._verts3d, self.axes.M  # type: ignore
        )
        start_xy = (proj_x[0], proj_y[0])
        end_xy = (proj_x[1], proj_y[1])
        self.set_positions(start_xy, end_xy)  # type: ignore
        return np.min(proj_z)  # type: ignore

    @property
    def verts3d(
        self,
    ) -> tuple[
        npt.NDArray[np.float64],
        npt.NDArray[np.float64],
        npt.NDArray[np.float64],
    ]:
        """The stored ``(xs, ys, zs)`` 3D end points of the arrow."""
        return self._verts3d
def _plot_arrow(axes: Axes, arrays: list[npt.NDArray[np.float64]]):
    """Draw a light-grey arrowhead at the end of a path.

    Args:
        axes: The axes to draw on. This must be a 3D axes when three or more
            arrays are given.
        arrays: One coordinate array per scan axis, in ``spec.axes()`` order
            (the last array is plotted on x). Each must hold at least two
            points. The arrow follows the direction of the final two points.
    """
    if len(arrays) == 1:
        # 1D paths are drawn along y=0, so add an all-zero y component
        arrays = [np.array([0, 0])] + arrays
    if len(arrays) == 2:
        x_arr, y_arr = arrays[-1], arrays[-2]
        head = (x_arr[-1], y_arr[-1])
        # Start the arrow 10% of the way back along the last segment so that
        # it only shows the direction, not the whole segment
        tail = (
            x_arr[-1] - (x_arr[-1] - x_arr[-2]) * 0.1,
            y_arr[-1] - (y_arr[-1] - y_arr[-2]) * 0.1,
        )
        axes.annotate(  # type: ignore
            "",
            xy=head,
            xytext=tail,
            arrowprops={"color": "lightgrey", "arrowstyle": "-|>"},
        )
    elif len(arrays) >= 3:
        # 3D axes are used for more than two dims. The last three arrays supply
        # x, y and z, in line with the labels plot_spec puts on the axes.
        xs, ys, zs = (arrays[-1][-2:], arrays[-2][-2:], arrays[-3][-2:])
        arrow = Arrow3D(
            xs, ys, zs, mutation_scale=10, arrowstyle="-|>", color="lightgrey"
        )
        axes.add_artist(arrow)
def _plot_spline(
    axes: Axes,
    ranges: list[float],
    arrays: list[npt.NDArray[np.float64]],
    index_colours: dict[int, str],
) -> Iterable[list[npt.NDArray[np.float64]]]:
    """Draw a smooth quadratic spline through the points in ``arrays``.

    Args:
        axes: The axes to draw on.
        ranges: One scale factor per array. Each array is divided by its range
            before fitting so that axes with larger spans don't dominate the
            spline's shape. The results are multiplied back afterwards.
        arrays: One coordinate array per scan axis, all the same length.
        index_colours: Maps a point index to a colour. The spline is cut into
            sections starting at each key (in sorted order) and ending at the
            next key, or at the last point for the final key. Each section is
            drawn in the colour of its starting key.

    Yields:
        For each section, a list with one 1001-point array per axis, in the
        original (unscaled) units. Nothing is drawn or yielded if all the
        points coincide.
    """
    scaled_arrays = [a / r for a, r in zip(arrays, ranges, strict=False)]
    # Define the curve parametrically by cumulative distance in scaled space
    t = np.zeros(len(arrays[0]))
    t[1:] = np.sqrt(sum((arr[1:] - arr[:-1]) ** 2 for arr in scaled_arrays))
    t = np.cumsum(t)
    # Written as "not > 0" so that a NaN length also skips plotting
    if not t[-1] > 0:
        return
    # Can't make a spline that starts and ends in the same place, so add a small
    # delta. The arrays are already divided by their range, so the delta is in
    # scaled units and must not be multiplied by the range again.
    for s in scaled_arrays:
        if s[0] == s[-1]:
            s += np.linspace(0, 1e-7, len(s))
    t /= t[-1]
    tck, _ = interpolate.splprep(scaled_arrays, k=2, s=0)  # type: ignore
    starts = sorted(index_colours)
    stops = starts[1:] + [len(arrays[0]) - 1]
    for start, stop in zip(starts, stops, strict=False):
        start_value: float = t[start]
        stop_value: float = t[stop]
        tnew = np.linspace(start_value, stop_value, num=1001)
        spline: npt.NDArray[np.float64] = interpolate.splev(tnew, tck)  # type: ignore
        # Scale the splines back to the original scaling
        unscaled_splines = [a * r for a, r in zip(spline, ranges, strict=False)]
        _plot_arrays(axes, unscaled_splines, color=index_colours[start])  # type: ignore
        yield unscaled_splines  # type: ignore
def _get_boundaries(spec: Spec[Any]) -> Generator[patches.Patch, None, None]:
    """Yield unfilled matplotlib patches outlining the regions in ``spec``.

    An ``Ellipse`` or ``Polygon`` spec yields exactly one patch. Any other spec
    is searched recursively: each attribute in its instance ``__dict__`` whose
    value is itself a ``Spec`` is visited in turn.
    """
    if isinstance(spec, Ellipse):
        yield patches.Ellipse(
            (spec.x_centre, spec.y_centre),
            spec.x_diameter,
            spec.y_diameter,
            fill=False,
        )
        return
    if isinstance(spec, Polygon):
        yield patches.Polygon(spec.vertices, fill=False)
        return
    for attr_name in spec.__dict__:
        child = getattr(spec, attr_name)
        if isinstance(child, Spec):
            yield from _get_boundaries(child)  # type: ignore
def plot_spec(spec: Spec[Any], title: str | None = None):
    """Plot a spec, drawing the path taken through the scan.

    Uses a different colour for each frame, grey for the turnarounds.

    Regions (ellipses/polygons) are outlined for scans of up to two axes.
    Capture points are drawn as black dots when there are fewer than 200
    frames. The end of the scan is marked with a grey cross. The figure is
    shown with ``plt.show()``.

    Args:
        spec: The spec to calculate and plot.
        title: The plot title. If not given (or empty), a title listing each
            dimension's axes and length is generated.

    .. example_spec::

        from scanspec.specs import Linspace, Ellipse

        spec = Linspace("z", 1, 3, 3) * Ellipse("x", 1, 1.8, 0.2, "y", 2, snake=True)
    """
    dims = spec.calculate()
    dim = stack2dimension(dims)
    axes = spec.axes()
    ndims = len(axes)

    # Setup axes
    if ndims > 2:
        plt.figure(figsize=(6, 6))  # type: ignore
        plt_axes: Axes = plt.axes(projection="3d")  # type: ignore
        plt_axes.grid(False)  # type: ignore
        if isinstance(plt_axes, Axes3D):
            plt_axes.set_zlabel(axes[-3])  # type: ignore
            plt_axes.set_ylabel(axes[-2])  # type: ignore
            plt_axes.view_init(elev=15)  # type: ignore
        else:
            raise TypeError(
                "Expected matplotlib to create an Axes3D object, "
                f"instead got: {plt_axes}"
            )
    elif ndims == 2:
        plt.figure(figsize=(6, 6))  # type: ignore
        plt_axes = plt.axes()  # type: ignore
        plt_axes.set_ylabel(axes[-2])  # type: ignore
    else:
        plt.figure(figsize=(6, 2))  # type: ignore
        plt_axes = plt.axes()  # type: ignore
        plt_axes.yaxis.set_visible(False)
    plt_axes.set_xlabel(axes[-1])  # type: ignore

    # Title with dimension sizes
    title = title or ", ".join(
        f"Dim[{' '.join(d.axes())} len={len(d)}]" for d in dims
    )
    plt.title(title)  # type: ignore

    # Plot regions
    if ndims <= 2:
        for patch in _get_boundaries(spec):
            plt_axes.add_patch(patch)

    # Plot the splines
    tail: dict[str, npt.NDArray[np.float64] | None] = dict.fromkeys(axes)
    # Build ranges in `axes` order so they line up one-to-one with `arrays`
    # below, whatever the key order (or extra keys) of dim.midpoints
    ranges = [
        max(float(np.max(dim.midpoints[a]) - np.min(dim.midpoints[a])), 0.0001)
        for a in axes
    ]
    seg_col = cycle(colors.TABLEAU_COLORS)
    last_index = 0
    # The first element of gap is undefined (as there is no previous frame)
    # so discard it
    gap_indices = list(np.nonzero(dim.gap[1:])[0] + 1)
    for index in gap_indices + [len(dim)]:
        num_points = index - last_index
        arrays: list[npt.NDArray[np.float64]] = []
        turnaround: list[npt.NDArray[np.float64]] = []
        for a in axes:
            # Interleave lower bounds and midpoints, then close with the last
            # upper bound
            arr = np.empty(num_points * 2 + 1)
            arr[:-1:2] = dim.lower[a][last_index:index]
            arr[1::2] = dim.midpoints[a][last_index:index]
            arr[-1] = dim.upper[a][index - 1]
            arrays.append(arr)
            # Add the turnaround
            axis_tail = tail[a]
            if axis_tail is not None:
                # Already had a tail, add lead in points
                axis_tail[2:] = np.linspace(-0.01, 0, 2) * (arr[1] - arr[0]) + arr[0]
                turnaround.append(axis_tail)
            # Add tail off points, continuing the direction of the last segment
            axis_tail = np.empty(4)
            axis_tail[:2] = np.linspace(0, 0.01, 2) * (arr[-1] - arr[-2]) + arr[-1]
            tail[a] = axis_tail
        last_index = index

        arrow_arr = None
        if turnaround:
            # If we didn't move then plot a straight line from start to stop
            if all(t[1] - t[0] == 0 for t in turnaround):
                for t in turnaround:
                    t[1] += (t[2] - t[1]) / 4
            if all(t[3] - t[2] == 0 for t in turnaround):
                for t in turnaround:
                    t[2] -= (t[2] - t[1]) / 4
            # Plot the turnaround. _plot_spline yields nothing when every
            # turnaround point coincides, so don't assume a spline came back.
            turnaround_splines = list(
                _plot_spline(plt_axes, ranges, turnaround, {0: "lightgrey"})
            )
            if turnaround_splines:
                arrow_arr = turnaround_splines[0]

        # Plot the points
        index_colours = {2 * i: next(seg_col) for i in range(num_points)}
        splines = list(_plot_spline(plt_axes, ranges, arrays, index_colours))

        if arrow_arr is not None:
            # Plot the arrow on the turnaround
            _plot_arrow(plt_axes, arrow_arr)
        elif splines:
            # Plot the starting arrow in the direction of the first point
            arrow_arr = [np.array([2 * a[0] - a[1], a[0]]) for a in splines[0]]
            _plot_arrow(plt_axes, arrow_arr)
        else:
            # Segment start isn't moving, put a right caret marker on it
            # (arr[0] is this segment's first lower bound)
            _plot_arrays(
                plt_axes,
                [arr[:1] for arr in arrays],
                marker=5,
                color="lightgrey",
            )

    # Plot the capture points
    if len(dim) < 200:
        capture_arrays = [dim.midpoints[a] for a in axes]
        _plot_arrays(plt_axes, capture_arrays, linestyle="", marker=".", color="k")

    # Plot the end
    _plot_arrays(
        plt_axes,
        [np.array([dim.upper[a][-1]]) for a in axes],
        marker="x",
        color="lightgrey",
    )
    plt.show()  # type: ignore