# Source code for ophyd_async.core._status

"""Equivalent of bluesky.protocols.Status for asynchronous tasks."""

from __future__ import annotations

import asyncio
import contextlib
import functools
import time
from asyncio import CancelledError
from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine
from dataclasses import asdict, replace
from typing import Generic

from bluesky.protocols import Status

from ._device import Device
from ._protocol import Watcher
from ._utils import Callback, P, T, WatcherUpdate


class AsyncStatusBase(Status, Awaitable[None]):
    """Convert asyncio awaitable to bluesky Status interface.

    Can be used as an async context manager to automatically cancel the calling
    task when the status completes. This is useful for bounding loop execution:
    when the status completes, the calling task is cancelled, causing the loop
    to exit. If the loop completes first, the status task is automatically cancelled.

    :param awaitable:
        The coroutine or task to track. A coroutine is wrapped in a new task via
        `asyncio.create_task`, so this must be constructed with a running loop.
    :param name: The name of the device, if available. Used in the cancellation
        message and in ``repr``.
    """

    def __init__(self, awaitable: Coroutine | asyncio.Task, name: str | None = None):
        # Initialise all state before any done-callback is registered, so that a
        # callback can never observe a half-constructed status
        self._callbacks: list[Callback[Status]] = []
        self._name = name
        # Set by _cancel_calling_task when it cancels the `async with` caller
        # because this status finished, allowing __aexit__ to swallow that error
        self._cancelled_error_ok = False
        # The task currently inside `async with self`. None when no with block is
        # active, which disarms _cancel_calling_task
        self._calling_task: asyncio.Task | None = None
        if isinstance(awaitable, asyncio.Task):
            self.task = awaitable
        else:

            async def wait_with_error_message(awaitable):
                # Re-raise cancellation with a message naming what was being
                # awaited and on which device, so the origin is traceable
                try:
                    await awaitable
                except CancelledError as e:
                    raise CancelledError(
                        f"CancelledError while awaiting {awaitable} on {name}"
                    ) from e

            self.task = asyncio.create_task(wait_with_error_message(awaitable))
            # There is a small chance we could be cancelled before
            # wait_with_error_message starts.
            # Avoid complaints about awaitable not awaited if task is
            # pre-emptively cancelled, by ensuring it is always disposed
            self.task.add_done_callback(lambda _: awaitable.close())
        self.task.add_done_callback(self._run_callbacks)

    def __await__(self):
        """Await the underlying task, re-raising any error it raised."""
        return self.task.__await__()

    def add_callback(self, callback: Callback[Status]):
        """Call ``callback(self)`` when the status finishes.

        If the status has already finished the callback is called immediately.
        """
        if self.done:
            callback(self)
        else:
            self._callbacks.append(callback)

    def _run_callbacks(self, task: asyncio.Task):
        # Done-callback of self.task: notify everything registered via add_callback
        for callback in self._callbacks:
            callback(self)

    def exception(self, timeout: float | None = 0.0) -> BaseException | None:
        """Return any exception raised by the task.

        :param timeout:
            Taken for compatibility with the Status interface, but must be 0.0 as we
            cannot wait for an async function in a sync call.
        :returns:
            None if the task is still running or succeeded, the raised exception
            if it failed, or the CancelledError if it was cancelled.
        :raises ValueError: If ``timeout`` is anything other than 0.0.
        """
        if timeout != 0.0:
            raise ValueError(
                "cannot honour any timeout other than 0 in an asynchronous function"
            )
        if self.task.done():
            try:
                return self.task.exception()
            except asyncio.CancelledError as exc:
                # Task.exception() raises rather than returns for cancelled tasks
                return exc
        return None

    @property
    def done(self) -> bool:
        """Whether the underlying task has finished (in any way)."""
        return self.task.done()

    @property
    def success(self) -> bool:
        """Whether the task finished without being cancelled or raising."""
        return (
            self.task.done()
            and not self.task.cancelled()
            and self.task.exception() is None
        )

    def __repr__(self) -> str:
        if self.done:
            if e := self.exception():
                status = f"errored: {repr(e)}"
            else:
                status = "done"
        else:
            status = "pending"
        device_str = f"device: {self._name}, " if self._name else ""
        return (
            f"<{type(self).__name__}, {device_str}"
            f"task: {self.task.get_coro()}, {status}>"
        )

    async def __aenter__(self):
        """Arm cancellation of the calling task for when this status finishes.

        :raises RuntimeError: If not called from inside an asyncio task.
        """
        # Grab the calling task, the one that is doing `async with status`
        calling_task = asyncio.current_task()
        if calling_task is None:
            raise RuntimeError("Can only use in a context manager inside a task")
        self._calling_task = calling_task
        self._cancelled_error_ok = False
        # When our child task is done, then cancel the calling task
        self.task.add_done_callback(self._cancel_calling_task)
        return self

    def _cancel_calling_task(self, task: asyncio.Task):
        # Done-callback of self.task registered by __aenter__
        calling_task = self._calling_task
        if calling_task is None:
            # The with block has already been exited: a callback that was
            # scheduled late must not cancel the caller outside the block
            return
        # If no-one cancelled our child task, then it is expected
        # that we want to break out of the calling task's with block,
        # so mark that the CancelledError should be suppressed on exit
        self._cancelled_error_ok = not task.cancelled()
        calling_task.cancel()

    async def __aexit__(self, exc_type, exc, tb):
        """Cancel the status task and swallow the CancelledError we caused."""
        calling_task = self._calling_task
        # Disarm before cancelling, so that our own cancellation of the status
        # task below does not bounce back and cancel the calling task
        self._calling_task = None
        self.task.remove_done_callback(self._cancel_calling_task)
        self.task.cancel()
        # Need to await the task to suppress teardown warnings, but
        # we know it will raise CancelledError if we just cancelled it
        with contextlib.suppress(CancelledError):
            await self.task
        if exc_type is CancelledError and self._cancelled_error_ok:
            # Suppress error as we cancelled it in _cancel_calling_task.
            # On Python 3.11+ also balance the cancel() we issued, so that
            # Task.cancelling() stays accurate for asyncio.timeout/TaskGroup
            uncancel = getattr(calling_task, "uncancel", None)
            if uncancel is not None:
                uncancel()
            return True
        # Raise error as we didn't cause it
        return False

    __str__ = __repr__


class AsyncStatus(AsyncStatusBase):
    """Wrap an asyncio awaitable so it satisfies the bluesky Status interface.

    :param awaitable: The coroutine or task to await.
    :param name: The name of the device, if available.

    Awaiting it behaves like awaiting a Task:

    ```python
    status = AsyncStatus(asyncio.sleep(1))
    assert not status.done
    await status  # waits for 1 second
    assert status.done
    ```

    As an async context manager it bounds the body: when the status finishes,
    the task running the body is cancelled, so loops inside it stop:

    ```python
    async with motor.set(target_position):
        async for value in observe_value(detector):
            process_reading(value)
    # Loop exits automatically when motor reaches position
    ```

    Conversely, if the body finishes first, the status task is cancelled:

    ```python
    async with AsyncStatus(long_operation()):
        for i in range(3):
            await process_step(i)
    # Loop completes, long_operation() is cancelled
    ```

    The body can only be interrupted at a suspension point such as `await` or
    `async for`; code between suspension points keeps running even after the
    status has finished.
    """

    @classmethod
    def wrap(cls, f: Callable[P, Coroutine]) -> Callable[P, AsyncStatus]:
        """Turn an async function into one that returns an AsyncStatus.

        Handy for making a coroutine method satisfy a bluesky protocol:

        ```python
        class MyDevice(Device):
            @AsyncStatus.wrap
            async def trigger(self):
                await asyncio.sleep(1)
        ```

        When the first positional argument is a Device (e.g. ``self``), its
        name is attached to the resulting status.
        """

        @functools.wraps(f)
        def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AsyncStatus:
            owner = args[0] if args else None
            device_name = owner.name if isinstance(owner, Device) else None
            return cls(f(*args, **kwargs), name=device_name)

        return wrap_f


class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):
    """Convert an asyncio async iterable to bluesky Status and Watcher interface.

    Each item yielded by the iterator is forwarded to every registered watcher.
    The status completes when the iterator is exhausted.

    :param iterator: The async iterable to await.
    :param name: The name of the device, if available.
    """

    def __init__(
        self, iterator: AsyncIterator[WatcherUpdate[T]], name: str | None = None
    ):
        self._watchers: list[Watcher] = []
        # Reference point for time_elapsed when an update does not provide one
        self._start = time.monotonic()
        self._last_update: WatcherUpdate[T] | None = None
        super().__init__(self._notify_watchers_from(iterator), name)

    async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
        # Body of the status task: consume the iterator, fan updates out
        async for update in iterator:
            if update.time_elapsed is None:
                update = replace(update, time_elapsed=time.monotonic() - self._start)
            self._last_update = update
            # Iterate over a snapshot: a watcher may register further watchers,
            # and watch() already brings those up to date with this update
            for watcher in list(self._watchers):
                self._update_watcher(watcher, update)

    def _update_watcher(self, watcher: Watcher, update: WatcherUpdate[T]):
        # Pass only the fields that are set (not None) as keyword arguments
        vals = asdict(
            update, dict_factory=lambda d: {k: v for k, v in d if v is not None}
        )
        watcher(**vals)

    def watch(self, watcher: Watcher):
        """Add a watcher to the status.

        It is called:
        - immediately if there has already been an update
        - on every subsequent update
        """
        self._watchers.append(watcher)
        if self._last_update is not None:
            self._update_watcher(watcher, self._last_update)

    @classmethod
    def wrap(
        cls,
        f: Callable[P, AsyncIterator[WatcherUpdate[T]]],
    ) -> Callable[P, WatchableAsyncStatus[T]]:
        """Wrap an AsyncIterator in a WatchableAsyncStatus.

        For example:

        ```python
        class MyDevice(Device):
            @WatchableAsyncStatus.wrap
            async def trigger(self):
                # sleep for a second, updating on progress every 0.1 seconds
                for i in range(10):
                    yield WatcherUpdate(initial=0, current=i*0.1, target=1)
                    await asyncio.sleep(0.1)
        ```

        When the first positional argument is a Device (e.g. ``self``), its
        name is attached to the resulting status.
        """

        @functools.wraps(f)
        def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WatchableAsyncStatus[T]:
            if args and isinstance(args[0], Device):
                name = args[0].name
            else:
                name = None
            return cls(f(*args, **kwargs), name=name)

        return wrap_f


@AsyncStatus.wrap
async def completed_status(exception: Exception | None = None):
    """Build an AsyncStatus whose work is already finished.

    The status still wraps a task, so it is marked done once the event loop
    has run that task.

    :param exception: If given, then raise this exception when awaited.
    """
    if not exception:
        return None
    raise exception