"""Equivalent of bluesky.protocols.Status for asynchronous tasks."""
import asyncio
import functools
import time
from dataclasses import asdict, replace
from typing import (
AsyncIterator,
Awaitable,
Callable,
Generic,
Type,
TypeVar,
cast,
)
from bluesky.protocols import Status
from ..protocols import Watcher
from .utils import Callback, P, T, WatcherUpdate
# Type variable bound to AsyncStatus: lets ``AsyncStatus.wrap`` be annotated as
# producing instances of whichever (sub)class it is called on.
AS = TypeVar("AS", bound="AsyncStatus")
# Equivalent type variable for ``WatchableAsyncStatus.wrap``.
WAS = TypeVar("WAS", bound="WatchableAsyncStatus")
class AsyncStatusBase(Status):
    """Adapt an asyncio awaitable to the bluesky ``Status`` interface.

    The awaitable is tracked as an ``asyncio.Task`` stored on ``self.task``;
    ``done``, ``success`` and ``exception()`` are all derived from that task.
    """

    def __init__(self, awaitable: Awaitable):
        """Start tracking ``awaitable``.

        An existing ``asyncio.Task`` is used as-is; anything else is handed
        to ``asyncio.create_task`` to be scheduled as a new task.
        """
        self.task = (
            awaitable
            if isinstance(awaitable, asyncio.Task)
            else asyncio.create_task(awaitable)
        )
        self.task.add_done_callback(self._run_callbacks)
        self._callbacks: list[Callback[Status]] = []

    def __await__(self):
        """Support ``await status`` by delegating to the underlying task."""
        return self.task.__await__()

    def add_callback(self, callback: Callback[Status]):
        """Register ``callback`` to be called with this status on completion.

        If the task has already finished, the callback is invoked immediately
        instead of being stored.
        """
        if not self.done:
            self._callbacks.append(callback)
            return
        callback(self)

    def _run_callbacks(self, task: asyncio.Task):
        """Task done-callback: call every registered callback with ``self``."""
        for registered in self._callbacks:
            registered(self)

    def exception(self, timeout: float | None = 0.0) -> BaseException | None:
        """Return the exception the task finished with, if any.

        ``timeout`` is accepted only so the signature matches the Status
        interface; any value other than 0 raises ``ValueError``.

        Returns the task's exception, the ``CancelledError`` if the task was
        cancelled, or ``None`` if the task is still running or succeeded.
        """
        if timeout != 0.0:
            raise ValueError(
                "cannot honour any timeout other than 0 in an asynchronous function"
            )
        if not self.task.done():
            return None
        try:
            return self.task.exception()
        except asyncio.CancelledError as cancelled:
            # Task.exception() raises rather than returns on cancellation;
            # report the cancellation as the failure reason.
            return cancelled

    @property
    def done(self) -> bool:
        """Whether the underlying task has finished (in any way)."""
        return self.task.done()

    @property
    def success(self) -> bool:
        """Whether the task finished without being cancelled or raising."""
        task = self.task
        if not task.done() or task.cancelled():
            return False
        return task.exception() is None

    def __repr__(self) -> str:
        """Describe the wrapped coroutine and whether it is pending/done/errored."""
        if not self.done:
            state = "pending"
        elif err := self.exception():
            state = f"errored: {err!r}"
        else:
            state = "done"
        return f"<{type(self).__name__}, task: {self.task.get_coro()}, {state}>"

    __str__ = __repr__
# [docs] -- leftover documentation-link marker, not code
class AsyncStatus(AsyncStatusBase):
    """Status wrapping a single awaitable; see ``wrap`` for decorator use."""

    @classmethod
    def wrap(cls: Type[AS], f: Callable[P, Awaitable]) -> Callable[P, AS]:
        """Wrap an async function in an AsyncStatus.

        The returned callable takes the same arguments as ``f``; calling it
        invokes ``f`` and passes the resulting awaitable to ``cls``.
        """

        @functools.wraps(f)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> AS:
            awaitable = f(*args, **kwargs)
            return cls(awaitable)

        # type is actually functools._Wrapped[P, Awaitable, P, AS]
        # but functools._Wrapped is not necessarily available
        return cast(Callable[P, AS], wrapper)
# [docs] -- leftover documentation-link marker, not code
class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):
    """Convert AsyncIterator of WatcherUpdates to bluesky Status interface.

    Each update produced by the iterator is forwarded to every registered
    watcher; the status completes when the iterator is exhausted.
    """

    def __init__(self, iterator: AsyncIterator[WatcherUpdate[T]]):
        """Begin consuming ``iterator`` and relaying its updates to watchers."""
        # This state is read by _notify_watchers_from, so it is set up before
        # the base class turns that coroutine into a task.
        self._watchers: list[Watcher] = []
        self._start = time.monotonic()
        self._last_update: WatcherUpdate[T] | None = None
        super().__init__(self._notify_watchers_from(iterator))

    async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
        """Drain ``iterator``, recording each update and passing it to watchers."""
        async for update in iterator:
            if update.time_elapsed is None:
                # The producer supplied no elapsed time: fill it in with the
                # time since this status was created.
                elapsed = time.monotonic() - self._start
                stamped = replace(update, time_elapsed=elapsed)
            else:
                stamped = update
            self._last_update = stamped
            for watcher in self._watchers:
                self._update_watcher(watcher, self._last_update)

    def _update_watcher(self, watcher: Watcher, update: WatcherUpdate[T]):
        """Call ``watcher`` with the fields of ``update`` that are not None."""

        def drop_none(items):
            return {key: value for key, value in items if value is not None}

        watcher(**asdict(update, dict_factory=drop_none))

    def watch(self, watcher: Watcher):
        """Register ``watcher`` for future updates.

        If an update has already been received, the watcher is called
        straight away with the most recent one.
        """
        self._watchers.append(watcher)
        latest = self._last_update
        if latest:
            self._update_watcher(watcher, latest)

    @classmethod
    def wrap(
        cls: Type[WAS],
        f: Callable[P, AsyncIterator[WatcherUpdate[T]]],
    ) -> Callable[P, WAS]:
        """Wrap an AsyncIterator in a WatchableAsyncStatus.

        The returned callable takes the same arguments as ``f``; calling it
        invokes ``f`` and passes the resulting iterator to ``cls``.
        """

        @functools.wraps(f)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> WAS:
            updates = f(*args, **kwargs)
            return cls(updates)

        return cast(Callable[P, WAS], wrapper)