Source code for ophyd_async.core.flyer

import asyncio
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import (
    AsyncIterator,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    Sequence,
    TypeVar,
)

from bluesky.protocols import (
    Asset,
    Collectable,
    Descriptor,
    Flyable,
    HasHints,
    Hints,
    Preparable,
    Reading,
    Stageable,
    WritesExternalAssets,
)

from .async_status import AsyncStatus
from .detector import DetectorControl, DetectorTrigger, DetectorWriter
from .device import Device
from .signal import SignalR
from .utils import DEFAULT_TIMEOUT, gather_list, merge_gathered_dicts

T = TypeVar("T")


[docs] @dataclass(frozen=True) class TriggerInfo: #: Number of triggers that will be sent num: int #: Sort of triggers that will be sent trigger: DetectorTrigger #: What is the minimum deadtime between triggers deadtime: float #: What is the maximum high time of the triggers livetime: float
[docs] class DetectorGroupLogic(ABC): # Read multipliers here, exposure is set in the plan
[docs] @abstractmethod async def open(self) -> Dict[str, Descriptor]: """Open all writers, wait for them to be open and return their descriptors"""
[docs] @abstractmethod async def ensure_armed(self, trigger_info: TriggerInfo): """Ensure the detectors are armed, return AsyncStatus that waits for disarm."""
[docs] @abstractmethod def collect_asset_docs(self) -> AsyncIterator[Asset]: """Collect asset docs from all writers"""
[docs] @abstractmethod async def wait_for_index( self, index: int, timeout: Optional[float] = DEFAULT_TIMEOUT ): """Wait until a specific index is ready to be collected"""
[docs] @abstractmethod async def disarm(self): """Disarm detectors"""
[docs] @abstractmethod async def close(self): """Close all writers and wait for them to be closed"""
[docs] @abstractmethod def hints(self) -> Hints: """Produce hints specifying which dataset(s) are most important"""
[docs] class SameTriggerDetectorGroupLogic(DetectorGroupLogic): def __init__( self, controllers: Sequence[DetectorControl], writers: Sequence[DetectorWriter], ) -> None: self._controllers = controllers self._writers = writers self._arm_statuses: Sequence[AsyncStatus] = () self._trigger_info: Optional[TriggerInfo] = None async def open(self) -> Dict[str, Descriptor]: return await merge_gathered_dicts(writer.open() for writer in self._writers) async def ensure_armed(self, trigger_info: TriggerInfo): if ( not self._arm_statuses or any(status.done for status in self._arm_statuses) or trigger_info != self._trigger_info ): # We need to re-arm await self.disarm() for controller in self._controllers: required = controller.get_deadtime(trigger_info.livetime) assert required <= trigger_info.deadtime, ( f"Detector {controller} needs at least {required}s deadtime, " f"but trigger logic provides only {trigger_info.deadtime}s" ) self._arm_statuses = await gather_list( controller.arm( num=trigger_info.num, trigger=trigger_info.trigger, exposure=trigger_info.livetime, ) for controller in self._controllers ) self._trigger_info = trigger_info async def collect_asset_docs(self) -> AsyncIterator[Asset]: # the below is confusing: gather_list does return an awaitable, but it itself # is a coroutine so we must call await twice... indices_written = min( await gather_list(writer.get_indices_written() for writer in self._writers) ) for writer in self._writers: async for doc in writer.collect_stream_docs(indices_written): yield doc async def wait_for_index( self, index: int, timeout: Optional[float] = DEFAULT_TIMEOUT ): await gather_list( writer.wait_for_index(index, timeout=timeout) for writer in self._writers ) async def disarm(self): await gather_list(controller.disarm() for controller in self._controllers) await gather_list(self._arm_statuses) async def close(self): await gather_list(writer.close() for writer in self._writers) def hints(self) -> Hints: return { "fields": [ field for writer in self._writers if hasattr(writer, "hints") for field in writer.hints.get("fields") ] }
[docs] class TriggerLogic(ABC, Generic[T]):
[docs] @abstractmethod def trigger_info(self, value: T) -> TriggerInfo: """Return info about triggers that will be produced for a given value"""
[docs] @abstractmethod async def prepare(self, value: T): """Move to the start of the flyscan"""
[docs] @abstractmethod async def start(self): """Start the flyscan"""
[docs] @abstractmethod async def stop(self): """Stop flying and wait everything to be stopped"""
[docs] class HardwareTriggeredFlyable( Device, Preparable, Stageable, Flyable, Collectable, WritesExternalAssets, HasHints, Generic[T], ): def __init__( self, detector_group_logic: DetectorGroupLogic, trigger_logic: TriggerLogic[T], configuration_signals: Sequence[SignalR], trigger_to_frame_timeout: Optional[float] = DEFAULT_TIMEOUT, name: str = "", ): self._detector_group_logic = detector_group_logic self._trigger_logic = trigger_logic self._configuration_signals = tuple(configuration_signals) self._describe: Dict[str, Descriptor] = {} self._watchers: List[Callable] = [] self._fly_status: Optional[AsyncStatus] = None self._fly_start = 0.0 self._offset = 0 # Add this to index to get frame number self._current_frame = 0 # The current frame we are on self._last_frame = 0 # The last frame that will be emitted self._trigger_to_frame_timeout = trigger_to_frame_timeout super().__init__(name=name) @AsyncStatus.wrap async def stage(self) -> None: await self.unstage() self._describe = await self._detector_group_logic.open() self._offset = 0 self._current_frame = 0
[docs] def prepare(self, value: T) -> AsyncStatus: """Arm detectors and setup trajectories""" # index + offset = current_frame, but starting a new scan so want it to be 0 # so subtract current_frame from both sides return AsyncStatus(self._prepare(value))
async def _prepare(self, value: T) -> None: self._offset -= self._current_frame self._current_frame = 0 trigger_info = self._trigger_logic.trigger_info(value) # Move to start and setup the flyscan, and arm dets in parallel await asyncio.gather( self._detector_group_logic.ensure_armed(trigger_info), self._trigger_logic.prepare(value), ) self._last_frame = self._current_frame + trigger_info.num async def describe_configuration(self) -> Dict[str, Descriptor]: return await merge_gathered_dicts( [sig.describe() for sig in self._configuration_signals] ) async def read_configuration(self) -> Dict[str, Reading]: return await merge_gathered_dicts( [sig.read() for sig in self._configuration_signals] ) async def describe_collect(self) -> Dict[str, Descriptor]: return self._describe @AsyncStatus.wrap async def kickoff(self) -> None: self._watchers = [] self._fly_status = AsyncStatus(self._fly(), self._watchers) self._fly_start = time.monotonic() async def _fly(self) -> None: await self._trigger_logic.start() # Wait for all detectors to have written up to a particular frame await self._detector_group_logic.wait_for_index( self._last_frame - self._offset, timeout=self._trigger_to_frame_timeout ) async def collect_asset_docs(self) -> AsyncIterator[Asset]: current_frame = self._current_frame stream_datums: List[Asset] = [] async for asset in self._detector_group_logic.collect_asset_docs(): name, doc = asset if name == "stream_datum": current_frame = doc["indices"]["stop"] + self._offset # Defer stream_datums until all stream_resources have been produced # In a single collect, all the stream_resources are produced first # followed by their stream_datums stream_datums.append(asset) else: yield asset for asset in stream_datums: yield asset if current_frame != self._current_frame: self._current_frame = current_frame for watcher in self._watchers: watcher( name=self.name, current=current_frame, initial=0, target=self._last_frame, unit="", precision=0, time_elapsed=time.monotonic() - self._fly_start, ) def complete(self) -> AsyncStatus: assert self._fly_status, "Kickoff not run" return self._fly_status @AsyncStatus.wrap async def unstage(self) -> None: await asyncio.gather( self._trigger_logic.stop(), self._detector_group_logic.close(), self._detector_group_logic.disarm(), ) @property def hints(self) -> Hints: return self._detector_group_logic.hints()