Source code for ophyd_async.core._utils

from __future__ import annotations

import asyncio
import logging
import math
from dataclasses import dataclass
from decimal import Decimal
from typing import (
    Awaitable,
    Callable,
    Dict,
    Generic,
    Iterable,
    List,
    Optional,
    ParamSpec,
    Type,
    TypeVar,
    Union,
)

import numpy as np
from bluesky.protocols import Reading

#: Generic type variable used by the callback aliases and helpers in this module
T = TypeVar("T")
#: ParamSpec for typing callable signatures; not referenced by the definitions
#: in this module — presumably re-used by importers (TODO confirm)
P = ParamSpec("P")
#: A function called with a single value of type T, returning nothing
Callback = Callable[[T], None]

#: A function that will be called with the Reading and value when the
#: monitor updates
ReadingValueCallback = Callable[[Reading, T], None]
#: Default timeout value; presumably in seconds and used when a caller gives
#: no explicit timeout — confirm against callers
DEFAULT_TIMEOUT = 10.0
#: Error payload accepted by NotConnected: either a plain message string or a
#: mapping of device/signal names to the exceptions they raised
ErrorText = Union[str, Dict[str, Exception]]


[docs] class CalculateTimeout: """Sentinel class used to implement ``myfunc(timeout=CalculateTimeout)`` This signifies that the function should calculate a suitable non-zero timeout itself """
CalculatableTimeout = float | None | Type[CalculateTimeout]
class NotConnected(Exception):
    """Exception to be raised if a `Device.connect` is cancelled

    Holds either a plain error string or a mapping of device/signal names to
    the exceptions they raised. ``str()`` renders a mapping as one
    ``name: ExceptionType: message`` line per entry, with nested NotConnected
    errors indented one level deeper.
    """

    # Extra indentation prepended for each level of nested NotConnected
    _indent_width = " "

    def __init__(self, errors: ErrorText):
        """
        NotConnected holds a mapping of device/signal names to errors.

        Parameters
        ----------
        errors: ErrorText
            Mapping of device name to Exception or another NotConnected.
            Alternatively a string with the signal error text.
        """
        # Forward to Exception so ``args`` is always populated, even when the
        # caller passes ``errors=`` by keyword; without this, such instances
        # have empty ``args`` and cannot be unpickled.
        super().__init__(errors)
        self._errors = errors

    def _format_sub_errors(self, name: str, error: Exception, indent="") -> str:
        """Format a single ``name: error`` entry, recursing into NotConnected.

        Raises:
            RuntimeError: if ``error`` is not an Exception.
        """
        if isinstance(error, NotConnected):
            # Nested errors are rendered one indent level deeper
            error_txt = ":" + error.format_error_string(indent + self._indent_width)
        elif isinstance(error, Exception):
            err_str = str(error)
            # Only append ": message" when the exception actually has one
            error_txt = f": {err_str}\n" if err_str else "\n"
        else:
            raise RuntimeError(
                f"Unexpected type `{type(error)}`, expected an Exception"
            )
        return f"{indent}{name}: {type(error).__name__}" + error_txt

    def format_error_string(self, indent="") -> str:
        """Render the stored errors as text, prefixing entries with ``indent``.

        A string payload becomes ``" <text>\\n"``; a mapping becomes a leading
        newline followed by one formatted line (or sub-tree) per entry.

        Raises:
            RuntimeError: if the stored errors are neither a str nor a dict,
                or if a mapping value is not an Exception.
        """
        if not isinstance(self._errors, (dict, str)):
            raise RuntimeError(
                f"Unexpected type `{type(self._errors)}` " "expected `str` or `dict`"
            )
        if isinstance(self._errors, str):
            return " " + self._errors + "\n"
        return "\n" + "".join(
            self._format_sub_errors(name, error, indent=indent)
            for name, error in self._errors.items()
        )

    def __str__(self) -> str:
        return self.format_error_string(indent="")
@dataclass(frozen=True)
class WatcherUpdate(Generic[T]):
    """Immutable bundle of values describing one watcher update.

    Each field name matches a keyword argument of a watcher call, so
    expanding an instance yields the kwargs to hand to the watcher.
    """

    # Required values, in the order they are accepted positionally
    current: T
    initial: T
    target: T

    # Optional details, each defaulting to None when not supplied
    name: str | None = None
    unit: str | None = None
    precision: float | None = None
    fraction: float | None = None
    time_elapsed: float | None = None
    time_remaining: float | None = None
async def wait_for_connection(**coros: Awaitable[None]):
    """Run every named awaitable concurrently and collect their failures.

    Each keyword argument maps a name to an awaitable to execute. All of
    them are awaited together; any that raise an Exception are gathered
    into a single NotConnected keyed by name. Failures that are not
    themselves NotConnected are also logged with their traceback.

    Raises:
        NotConnected: if at least one awaitable raised an Exception.
    """
    outcomes = await asyncio.gather(*coros.values(), return_exceptions=True)
    failures = {
        label: outcome
        for label, outcome in zip(coros, outcomes)
        if isinstance(outcome, Exception)
    }
    for label, failure in failures.items():
        if isinstance(failure, NotConnected):
            # Already a structured connection error; no traceback needed
            continue
        logging.exception(
            f"device `{label}` raised unexpected exception "
            f"{type(failure).__name__}",
            exc_info=failure,
        )
    if failures:
        raise NotConnected(failures)
def get_dtype(typ: Type) -> Optional[np.dtype]:
    """Extract the numpy dtype carried by an ``ndarray`` type annotation.

    Anything that is not a parametrised ``np.ndarray`` annotation gives None.

    >>> import numpy.typing as npt
    >>> import numpy as np
    >>> get_dtype(npt.NDArray[np.int8])
    dtype('int8')
    """
    if getattr(typ, "__origin__", None) != np.ndarray:
        return None
    # typ looks like numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]:
    # its second argument is numpy.dtype[...], whose own argument is the
    # scalar type (numpy.float64) that np.dtype() accepts
    dtype_alias = typ.__args__[1]  # type: ignore
    scalar_type = dtype_alias.__args__[0]
    return np.dtype(scalar_type)
def get_unique(values: Dict[str, T], types: str) -> T:
    """Return the one value shared by every entry, else raise TypeError.

    ``types`` names the kind of value being compared and is used only in
    the error message.

    >>> get_unique({"a": 1, "b": 1}, "integers")
    1
    >>> get_unique({"a": 1, "b": 2}, "integers")
    Traceback (most recent call last):
     ...
    TypeError: Differing integers: a has 1, b has 2
    """
    distinct = set(values.values())
    if len(distinct) == 1:
        (only,) = distinct
        return only
    summary = ", ".join(f"{key} has {val}" for key, val in values.items())
    raise TypeError(f"Differing {types}: {summary}")
async def merge_gathered_dicts(
    coros: Iterable[Awaitable[Dict[str, T]]],
) -> Dict[str, T]:
    """Await dict-producing awaitables concurrently and merge their results.

    Results are merged in the order the awaitables were supplied, so a key
    present in several results takes the value from the last one. Useful for
    combining ``read()`` or ``describe`` outputs, e.g.::

        combined_read = await merge_gathered_dicts(s.read() for s in signals)
    """
    gathered = await asyncio.gather(*coros)
    combined: Dict[str, T] = {}
    for partial in gathered:
        combined.update(partial)
    return combined


async def gather_list(coros: Iterable[Awaitable[T]]) -> List[T]:
    """Await ``coros`` concurrently and return their results in input order."""
    results = await asyncio.gather(*coros)
    return results
def in_micros(t: float) -> int:
    """
    Converts between a non-negative number of seconds and an equivalent
    number of microseconds.

    The scaling is done in decimal on the shortest representation of ``t``
    (``repr(float(t))``), so a value such as ``0.000123`` maps to exactly
    123 microseconds instead of being bumped to 124 by binary floating-point
    error in ``t * 1e6``.

    Args:
        t (float): A time in seconds

    Raises:
        ValueError: if t < 0, or if t is NaN

    Returns:
        t (int): A time in microseconds, rounded up to the nearest whole
        microsecond
    """
    if t < 0:
        raise ValueError(f"Expected a positive time in seconds, got {t!r}")
    # Binary floats cannot hold most decimal fractions exactly, and
    # ``t * 1e6`` can land a hair above an integer (0.000123 * 1e6 evaluates
    # to 123.00000000000001), which ceil would turn into an extra
    # microsecond. Decimal arithmetic on the shortest repr is exact here:
    # at most 17 significant digits times 10**6 fits the default precision.
    micros = Decimal(repr(float(t))) * 1_000_000
    # math.ceil on a Decimal returns an int; NaN raises ValueError
    return math.ceil(micros)