"""Equivalent of bluesky.protocols.Status for asynchronous tasks."""
import asyncio
import functools
import time
from dataclasses import asdict, replace
from typing import AsyncIterator, Awaitable, Callable, Generic, Type, TypeVar, cast
from bluesky.protocols import Status
from ._protocol import Watcher
from ._utils import Callback, P, T, WatcherUpdate
# Type variables bound to the status classes, used to annotate the ``wrap``
# classmethods as returning an instance of the class they are called on.
AS = TypeVar("AS", bound="AsyncStatus")
WAS = TypeVar("WAS", bound="WatchableAsyncStatus")
class AsyncStatusBase(Status):
    """Convert asyncio awaitable to bluesky Status interface.

    The awaitable is tracked through an ``asyncio.Task`` stored on
    ``self.task``. Callbacks registered with :meth:`add_callback` are called
    with this status object once that task is done.
    """

    def __init__(self, awaitable: Awaitable):
        """Adopt or create the task that this status tracks.

        Args:
            awaitable: An existing ``asyncio.Task``, which is used as-is, or
                otherwise an object handed to ``asyncio.create_task`` (which
                expects a coroutine and a running event loop).
        """
        # Create the callback list before hooking the task, so that
        # ``_run_callbacks`` can never see a partially initialised instance.
        self._callbacks: list[Callback[Status]] = []
        if isinstance(awaitable, asyncio.Task):
            self.task = awaitable
        else:
            self.task = asyncio.create_task(awaitable)
        self.task.add_done_callback(self._run_callbacks)

    def __await__(self):
        """Allow ``await status`` by delegating to the underlying task."""
        return self.task.__await__()

    def add_callback(self, callback: Callback[Status]) -> None:
        """Register ``callback`` to be called with this status on completion.

        If the task is already done the callback is invoked immediately;
        otherwise it is stored and invoked when the task finishes.
        """
        if self.done:
            callback(self)
        else:
            self._callbacks.append(callback)

    def _run_callbacks(self, task: asyncio.Task) -> None:
        """Done-callback of ``self.task``: call every stored callback.

        ``task`` is supplied by asyncio and is not used; callbacks receive
        this status object instead.
        """
        for callback in self._callbacks:
            callback(self)

    def exception(self, timeout: float | None = 0.0) -> BaseException | None:
        """Return the exception that ended the task, if there is one.

        Args:
            timeout: Must be ``0.0``. This method never blocks, so no other
                value (including ``None``) can be honoured.

        Returns:
            The exception raised by the task, the ``CancelledError`` if the
            task was cancelled, or ``None`` if the task is still running or
            finished without error.

        Raises:
            ValueError: If ``timeout`` is anything other than ``0.0``.
        """
        if timeout != 0.0:
            raise ValueError(
                "cannot honour any timeout other than 0 in an asynchronous function"
            )
        if self.task.done():
            try:
                return self.task.exception()
            except asyncio.CancelledError as e:
                # Task.exception() raises for a cancelled task; report the
                # cancellation as the exception instead of propagating it.
                return e
        return None

    @property
    def done(self) -> bool:
        """Whether the underlying task has finished (in any way)."""
        return self.task.done()

    @property
    def success(self) -> bool:
        """Whether the task finished without being cancelled or raising."""
        return (
            self.task.done()
            and not self.task.cancelled()
            and self.task.exception() is None
        )

    def __repr__(self) -> str:
        """Describe the status as pending, done, or errored with its cause."""
        if not self.done:
            status = "pending"
        # Compare against None explicitly: an exception object may define
        # ``__bool__``/``__len__`` and be falsy while still being an error.
        elif (e := self.exception()) is not None:
            status = f"errored: {repr(e)}"
        else:
            status = "done"
        return f"<{type(self).__name__}, task: {self.task.get_coro()}, {status}>"

    __str__ = __repr__
class AsyncStatus(AsyncStatusBase):
    """Status for a single awaitable, with a decorator-style constructor."""

    @classmethod
    def wrap(cls: Type[AS], f: Callable[P, Awaitable]) -> Callable[P, AS]:
        """Return a version of ``f`` whose result is wrapped in ``cls``.

        Calling the returned function calls ``f`` with the same arguments and
        passes the awaitable it produces to ``cls``. The metadata of ``f``
        (name, docstring, ...) is copied onto the returned function.
        """

        def make_status(*args: P.args, **kwargs: P.kwargs) -> AS:
            awaitable = f(*args, **kwargs)
            return cls(awaitable)

        decorated = functools.wraps(f)(make_status)
        # The precise type is functools._Wrapped[P, Awaitable, P, AS], but
        # functools._Wrapped is not guaranteed to exist, hence the cast.
        return cast(Callable[P, AS], decorated)
class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):
    """Convert AsyncIterator of WatcherUpdates to bluesky Status interface.

    The iterator is consumed in the status's task. Each update it yields is
    remembered as the latest update and forwarded to every registered watcher.
    """

    def __init__(self, iterator: AsyncIterator[WatcherUpdate[T]]):
        """Start consuming ``iterator`` and record the start time.

        Args:
            iterator: Source of ``WatcherUpdate`` objects to forward to
                watchers.
        """
        self._watchers: list[Watcher] = []
        # Reference point for filling in ``time_elapsed`` on updates.
        self._start = time.monotonic()
        self._last_update: WatcherUpdate[T] | None = None
        super().__init__(self._notify_watchers_from(iterator))

    async def _notify_watchers_from(
        self, iterator: AsyncIterator[WatcherUpdate[T]]
    ) -> None:
        """Forward every update from ``iterator`` to the current watchers.

        Updates that arrive without ``time_elapsed`` get it filled in with the
        time since this status was created.
        """
        async for update in iterator:
            if update.time_elapsed is None:
                update = replace(update, time_elapsed=time.monotonic() - self._start)
            self._last_update = update
            for watcher in self._watchers:
                self._update_watcher(watcher, self._last_update)

    def _update_watcher(self, watcher: Watcher, update: WatcherUpdate[T]) -> None:
        """Call ``watcher`` with the fields of ``update`` that are not None."""
        vals = asdict(
            update, dict_factory=lambda d: {k: v for k, v in d if v is not None}
        )
        watcher(**vals)

    def watch(self, watcher: Watcher) -> None:
        """Register ``watcher`` to receive subsequent updates.

        If an update has already been received, it is replayed to the new
        watcher immediately.
        """
        self._watchers.append(watcher)
        # Explicit None check: "no update yet" must not be confused with an
        # update object that happens to evaluate as falsy.
        if self._last_update is not None:
            self._update_watcher(watcher, self._last_update)

    @classmethod
    def wrap(
        cls: Type[WAS],
        f: Callable[P, AsyncIterator[WatcherUpdate[T]]],
    ) -> Callable[P, WAS]:
        """Wrap an AsyncIterator in a WatchableAsyncStatus.

        Calling the returned function calls ``f`` with the same arguments and
        passes the iterator it produces to ``cls``.
        """

        @functools.wraps(f)
        def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS:
            return cls(f(*args, **kwargs))

        return cast(Callable[P, WAS], wrap_f)