Source code for scanspec.regions

"""`Region` and its subclasses.

.. inheritance-diagram:: scanspec.regions
    :top-classes: scanspec.regions.Region
    :parts: 1
"""

from __future__ import annotations

from collections.abc import Iterator, Mapping
from dataclasses import is_dataclass
from typing import Any, Generic, cast

import numpy as np
import numpy.typing as npt
from pydantic import BaseModel, Field, TypeAdapter
from pydantic.dataclasses import dataclass

from .core import (
    AxesPoints,
    Axis,
    StrictConfig,
    discriminated_union_of_subclasses,
    if_instance_do,
)

__all__ = [
    "Region",
    "get_mask",
    "CombinationOf",
    "UnionOf",
    "IntersectionOf",
    "DifferenceOf",
    "SymmetricDifferenceOf",
    "Range",
    "Rectangle",
    "Polygon",
    "Circle",
    "Ellipse",
    "find_regions",
]

NpMask = npt.NDArray[np.bool_]


[docs] @discriminated_union_of_subclasses class Region(Generic[Axis]): """Abstract baseclass for a Region that can `Mask` a `Spec`. Supports operators: - ``|``: `UnionOf` two Regions, midpoints present in either - ``&``: `IntersectionOf` two Regions, midpoints present in both - ``-``: `DifferenceOf` two Regions, midpoints present in first not second - ``^``: `SymmetricDifferenceOf` two Regions, midpoints present in one not both """
[docs] def axis_sets(self) -> list[set[Axis]]: # noqa: D102 """Produce the non-overlapping sets of axes this region spans.""" raise NotImplementedError(self)
[docs] def mask(self, points: AxesPoints[Axis]) -> NpMask: # noqa: D102 """Produce a mask of which points are in the region.""" raise NotImplementedError(self)
def __or__(self, other: Region[Axis]) -> UnionOf[Axis]: return if_instance_do(other, Region, lambda o: UnionOf(self, o)) def __and__(self, other: Region[Axis]) -> IntersectionOf[Axis]: return if_instance_do(other, Region, lambda o: IntersectionOf(self, o)) def __sub__(self, other: Region[Axis]) -> DifferenceOf[Axis]: return if_instance_do(other, Region, lambda o: DifferenceOf(self, o)) def __xor__(self, other: Region[Axis]) -> SymmetricDifferenceOf[Axis]: return if_instance_do(other, Region, lambda o: SymmetricDifferenceOf(self, o))
[docs] def serialize(self) -> Mapping[str, Any]: """Serialize the Region to a dictionary.""" return TypeAdapter(Region[Any]).dump_python(self)
[docs] @staticmethod def deserialize(obj: Any) -> Region[Any]: """Deserialize a Region from a dictionary.""" return TypeAdapter(Region[Any]).validate_python(obj)
[docs] def get_mask(region: Region[Axis], points: AxesPoints[Axis]) -> NpMask: """Return a mask of the points inside the region. If there is an overlap of axes of region and points return a mask of the points in the region, otherwise return all ones """ axes = set(points) needs_mask = any(ks & axes for ks in region.axis_sets()) if needs_mask: return region.mask(points) else: return np.ones(len(list(points.values())[0]), dtype=np.bool_)
def _merge_axis_sets(axis_sets: list[set[Axis]]) -> Iterator[set[Axis]]: # Take overlapping axis sets and merge any that overlap into each # other for ks in axis_sets: # ks = key_sets - left over from a previous naming standard axis_set = ks.copy() # Empty matching sets into this axis_set for ks in axis_sets: if ks & axis_set: while ks: axis_set.add(ks.pop()) # It might be emptied already, only yield if it isn't if axis_set: yield axis_set
[docs] @dataclass(config=StrictConfig) class CombinationOf(Region[Axis]): """Abstract baseclass for a combination of two regions, left and right.""" left: Region[Axis] = Field(description="The left-hand Region to combine") right: Region[Axis] = Field(description="The right-hand Region to combine")
[docs] def axis_sets(self) -> list[set[Axis]]: # noqa: D102 axis_sets = list( _merge_axis_sets(self.left.axis_sets() + self.right.axis_sets()) ) return axis_sets
# Naming so we don't clash with typing.Union
[docs] @dataclass(config=StrictConfig) class UnionOf(CombinationOf[Axis]): """A point is in UnionOf(a, b) if in either a or b. Typically created with the ``|`` operator >>> r = Range("x", 0.5, 2.5) | Range("x", 1.5, 3.5) >>> r.mask({"x": np.array([0, 1, 2, 3, 4])}) array([False, True, True, True, False]) """
[docs] def mask(self, points: AxesPoints[Axis]) -> NpMask: # noqa: D102 mask = get_mask(self.left, points) | get_mask(self.right, points) return mask
[docs] @dataclass(config=StrictConfig) class IntersectionOf(CombinationOf[Axis]): """A point is in IntersectionOf(a, b) if in both a and b. Typically created with the ``&`` operator. >>> r = Range("x", 0.5, 2.5) & Range("x", 1.5, 3.5) >>> r.mask({"x": np.array([0, 1, 2, 3, 4])}) array([False, False, True, False, False]) """
[docs] def mask(self, points: AxesPoints[Axis]) -> NpMask: # noqa: D102 mask = get_mask(self.left, points) & get_mask(self.right, points) return mask
[docs] @dataclass(config=StrictConfig) class DifferenceOf(CombinationOf[Axis]): """A point is in DifferenceOf(a, b) if in a and not in b. Typically created with the ``-`` operator. >>> r = Range("x", 0.5, 2.5) - Range("x", 1.5, 3.5) >>> r.mask({"x": np.array([0, 1, 2, 3, 4])}) array([False, True, False, False, False]) """
[docs] def mask(self, points: AxesPoints[Axis]) -> NpMask: # noqa: D102 left_mask = get_mask(self.left, points) # Return the xor restricted to the left region mask = left_mask ^ get_mask(self.right, points) & left_mask return mask
[docs] @dataclass(config=StrictConfig) class SymmetricDifferenceOf(CombinationOf[Axis]): """A point is in SymmetricDifferenceOf(a, b) if in either a or b, but not both. Typically created with the ``^`` operator. >>> r = Range("x", 0.5, 2.5) ^ Range("x", 1.5, 3.5) >>> r.mask({"x": np.array([0, 1, 2, 3, 4])}) array([False, True, False, True, False]) """
[docs] def mask(self, points: AxesPoints[Axis]) -> NpMask: # noqa: D102 mask = get_mask(self.left, points) ^ get_mask(self.right, points) return mask
[docs] @dataclass(config=StrictConfig) class Range(Region[Axis]): """Mask contains points of axis >= min and <= max. >>> r = Range("x", 1, 2) >>> r.mask({"x": np.array([0, 1, 2, 3, 4])}) array([False, True, True, False, False]) """ axis: Axis = Field(description="The name matching the axis to mask in spec") min: float = Field(description="The minimum inclusive value in the region") max: float = Field(description="The minimum inclusive value in the region")
[docs] def axis_sets(self) -> list[set[Axis]]: # noqa: D102 return [{self.axis}]
[docs] def mask(self, points: AxesPoints[Axis]) -> NpMask: # noqa: D102 v = points[self.axis] mask = np.bitwise_and(v >= self.min, v <= self.max) return mask
[docs] @dataclass(config=StrictConfig) class Rectangle(Region[Axis]): """Mask contains points of axis within a rotated xy rectangle. .. example_spec:: from scanspec.regions import Rectangle from scanspec.specs import Line grid = Line("y", 1, 3, 10) * ~Line("x", 0, 2, 10) spec = grid & Rectangle("x", "y", 0, 1.1, 1.5, 2.1, 30) """ x_axis: Axis = Field(description="The name matching the x axis of the spec") y_axis: Axis = Field(description="The name matching the y axis of the spec") x_min: float = Field(description="Minimum inclusive x value in the region") y_min: float = Field(description="Minimum inclusive y value in the region") x_max: float = Field(description="Maximum inclusive x value in the region") y_max: float = Field(description="Maximum inclusive y value in the region") angle: float = Field( description="Clockwise rotation angle of the rectangle", default=0.0 )
[docs] def axis_sets(self) -> list[set[Axis]]: # noqa: D102 return [{self.x_axis, self.y_axis}]
[docs] def mask(self, points: AxesPoints[Axis]) -> NpMask: # noqa: D102 x = points[self.x_axis] - self.x_min y = points[self.y_axis] - self.y_min if self.angle != 0: # Rotate src points by -angle phi = np.radians(-self.angle) rx = x * np.cos(phi) - y * np.sin(phi) ry = x * np.sin(phi) + y * np.cos(phi) x = rx y = ry mask_x = np.bitwise_and(x >= 0, x <= (self.x_max - self.x_min)) mask_y = np.bitwise_and(y >= 0, y <= (self.y_max - self.y_min)) return mask_x & mask_y
[docs] @dataclass(config=StrictConfig) class Polygon(Region[Axis]): """Mask contains points of axis within a rotated xy polygon. .. example_spec:: from scanspec.regions import Polygon from scanspec.specs import Line grid = Line("y", 3, 8, 10) * ~Line("x", 1 ,8, 10) spec = grid & Polygon("x", "y", [1.0, 6.0, 8.0, 2.0], [4.0, 10.0, 6.0, 1.0]) """ x_axis: Axis = Field(description="The name matching the x axis of the spec") y_axis: Axis = Field(description="The name matching the y axis of the spec") x_verts: list[float] = Field( description="The Nx1 x coordinates of the polygons vertices", min_length=3 ) y_verts: list[float] = Field( description="The Nx1 y coordinates of the polygons vertices", min_length=3 )
[docs] def axis_sets(self) -> list[set[Axis]]: # noqa: D102 return [{self.x_axis, self.y_axis}]
[docs] def mask(self, points: AxesPoints[Axis]) -> NpMask: # noqa: D102 x = points[self.x_axis] y = points[self.y_axis] v1x, v1y = self.x_verts[-1], self.y_verts[-1] mask = np.full(len(x), False, dtype=np.bool_) for v2x, v2y in zip(self.x_verts, self.y_verts, strict=False): # skip horizontal edges if v2y != v1y: vmask = np.full(len(x), False, dtype=np.bool_) vmask |= (y < v2y) & (y >= v1y) vmask |= (y < v1y) & (y >= v2y) t = (y - v1y) / (v2y - v1y) vmask &= x < v1x + t * (v2x - v1x) mask ^= vmask v1x, v1y = v2x, v2y return mask
[docs] @dataclass(config=StrictConfig) class Circle(Region[Axis]): """Mask contains points of axis within an xy circle of given radius. .. example_spec:: from scanspec.regions import Circle from scanspec.specs import Line grid = Line("y", 1, 3, 10) * ~Line("x", 0, 2, 10) spec = grid & Circle("x", "y", 1, 2, 0.9) """ x_axis: Axis = Field(description="The name matching the x axis of the spec") y_axis: Axis = Field(description="The name matching the y axis of the spec") x_middle: float = Field(description="The central x point of the circle") y_middle: float = Field(description="The central y point of the circle") radius: float = Field(description="Radius of the circle", gt=0.0)
[docs] def axis_sets(self) -> list[set[Axis]]: # noqa: D102 return [{self.x_axis, self.y_axis}]
[docs] def mask(self, points: AxesPoints[Axis]) -> NpMask: # noqa: D102 x = points[self.x_axis] - self.x_middle y = points[self.y_axis] - self.y_middle mask = x * x + y * y <= (self.radius * self.radius) return mask
[docs] @dataclass(config=StrictConfig) class Ellipse(Region[Axis]): """Mask contains points of axis within an xy ellipse of given radius. .. example_spec:: from scanspec.regions import Ellipse from scanspec.specs import Line grid = Line("y", 3, 8, 10) * ~Line("x", 1 ,8, 10) spec = grid & Ellipse("x", "y", 5, 5, 2, 3, 75) """ x_axis: Axis = Field(description="The name matching the x axis of the spec") y_axis: Axis = Field(description="The name matching the y axis of the spec") x_middle: float = Field(description="The central x point of the ellipse") y_middle: float = Field(description="The central y point of the ellipse") x_radius: float = Field( description="The radius along the x axis of the ellipse", gt=0.0 ) y_radius: float = Field( description="The radius along the y axis of the ellipse", gt=0.0 ) angle: float = Field(description="The angle of the ellipse (degrees)", default=0.0)
[docs] def axis_sets(self) -> list[set[Axis]]: # noqa: D102 return [{self.x_axis, self.y_axis}]
[docs] def mask(self, points: AxesPoints[Axis]) -> NpMask: # noqa: D102 x = points[self.x_axis] - self.x_middle y = points[self.y_axis] - self.y_middle if self.angle != 0: # Rotate src points by -angle phi = np.radians(-self.angle) tx = x * np.cos(phi) - y * np.sin(phi) ty = x * np.sin(phi) + y * np.cos(phi) x = tx y = ty mask = (x / self.x_radius) ** 2 + (y / self.y_radius) ** 2 <= 1 return mask
[docs] def find_regions(obj: Any) -> Iterator[Region[Any]]: # noqa: D102 """Recursively yield Regions from obj and its children.""" if ( hasattr(obj, "__pydantic_model__") and issubclass(obj.__pydantic_model__, BaseModel) or is_dataclass(obj) ): if isinstance(obj, Region): yield obj for name in obj.__dict__.keys(): regions: Iterator[Region[Any]] = find_regions(getattr(cast(Any, obj), name)) yield from regions