Source code for ophyd_async.core._status

"""Equivalent of bluesky.protocols.Status for asynchronous tasks."""

from __future__ import annotations

import asyncio
import functools
import time
from collections.abc import AsyncIterator, Callable, Coroutine
from dataclasses import asdict, replace
from typing import Generic

from bluesky.protocols import Status

from ._device import Device
from ._protocol import Watcher
from ._utils import Callback, P, T, WatcherUpdate


class AsyncStatusBase(Status):
    def __init__(self, awaitable: Coroutine | asyncio.Task, name: str | None = None):
        if isinstance(awaitable, asyncio.Task):
            self.task = awaitable
        else:
            self.task = asyncio.create_task(awaitable)
        self.task.add_done_callback(self._run_callbacks)
        self._callbacks: list[Callback[Status]] = []
        self._name = name

    def __await__(self):
        return self.task.__await__()

    def add_callback(self, callback: Callback[Status]):
        if self.done:
            callback(self)
        else:
            self._callbacks.append(callback)

    def _run_callbacks(self, task: asyncio.Task):
        for callback in self._callbacks:
            callback(self)

    def exception(self, timeout: float | None = 0.0) -> BaseException | None:
        """Return any exception raised by the task.

        :param timeout:
            Taken for compatibility with the Status interface, but must be 0.0 as we
            cannot wait for an async function in a sync call.
        """
        if timeout != 0.0:
            raise ValueError(
                "cannot honour any timeout other than 0 in an asynchronous function"
            )
        if self.task.done():
            try:
                return self.task.exception()
            except asyncio.CancelledError as e:
                return e
        return None

    @property
    def done(self) -> bool:
        return self.task.done()

    @property
    def success(self) -> bool:
        return (
            self.task.done()
            and not self.task.cancelled()
            and self.task.exception() is None
        )

    def __repr__(self) -> str:
        if self.done:
            if e := self.exception():
                status = f"errored: {repr(e)}"
            else:
                status = "done"
        else:
            status = "pending"
        device_str = f"device: {self._name}, " if self._name else ""
        return (
            f"<{type(self).__name__}, {device_str}"
            f"task: {self.task.get_coro()}, {status}>"
        )

    __str__ = __repr__


[docs] class AsyncStatus(AsyncStatusBase): """Convert an asyncio awaitable to bluesky Status interface. :param awaitable: The coroutine or task to await. :param name: The name of the device, if available. For example: ```python status = AsyncStatus(asyncio.sleep(1)) assert not status.done await status # waits for 1 second assert status.done ``` """
[docs] @classmethod def wrap(cls, f: Callable[P, Coroutine]) -> Callable[P, AsyncStatus]: """Wrap an async function in an AsyncStatus and return it. Used to make an async function conform to a bluesky protocol. For example: ```python class MyDevice(Device): @AsyncStatus.wrap async def trigger(self): await asyncio.sleep(1) ``` """ @functools.wraps(f) def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AsyncStatus: if args and isinstance(args[0], Device): name = args[0].name else: name = None return cls(f(*args, **kwargs), name=name) return wrap_f
[docs] class WatchableAsyncStatus(AsyncStatusBase, Generic[T]): """Convert an asyncio async iterable to bluesky Status and Watcher interface. :param iterator: The async iterable to await. :param name: The name of the device, if available. """ def __init__( self, iterator: AsyncIterator[WatcherUpdate[T]], name: str | None = None ): self._watchers: list[Watcher] = [] self._start = time.monotonic() self._last_update: WatcherUpdate[T] | None = None super().__init__(self._notify_watchers_from(iterator), name) async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]): async for update in iterator: self._last_update = ( update if update.time_elapsed is not None else replace(update, time_elapsed=time.monotonic() - self._start) ) for watcher in self._watchers: self._update_watcher(watcher, self._last_update) def _update_watcher(self, watcher: Watcher, update: WatcherUpdate[T]): vals = asdict( update, dict_factory=lambda d: {k: v for k, v in d if v is not None} ) watcher(**vals)
[docs] def watch(self, watcher: Watcher): """Add a watcher to the status. It is called: - immediately if there has already been an update - on every subsequent update """ self._watchers.append(watcher) if self._last_update: self._update_watcher(watcher, self._last_update)
[docs] @classmethod def wrap( cls, f: Callable[P, AsyncIterator[WatcherUpdate[T]]], ) -> Callable[P, WatchableAsyncStatus[T]]: """Wrap an AsyncIterator in a WatchableAsyncStatus. For example: ```python class MyDevice(Device): @WatchableAsyncStatus.wrap async def trigger(self): # sleep for a second, updating on progress every 0.1 seconds for i in range(10): yield WatcherUpdate(initial=0, current=i*0.1, target=1) await asyncio.sleep(0.1) ``` """ @functools.wraps(f) def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WatchableAsyncStatus[T]: if args and isinstance(args[0], Device): name = args[0].name else: name = None return cls(f(*args, **kwargs), name=name) return wrap_f
[docs] @AsyncStatus.wrap async def completed_status(exception: Exception | None = None): """Return a completed AsyncStatus. :param exception: If given, then raise this exception when awaited. """ if exception: raise exception return None