from __future__ import annotations
import asyncio
import logging
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
from dataclasses import dataclass
from enum import Enum, EnumMeta
from typing import Any, Generic, Literal, ParamSpec, TypeVar, get_args, get_origin
from unittest.mock import Mock
import numpy as np
T = TypeVar("T")
V = TypeVar("V")
P = ParamSpec("P")
Callback = Callable[[T], None]
DEFAULT_TIMEOUT = 10.0
logger = logging.getLogger("ophyd_async")
class StrictEnumMeta(EnumMeta):
def __new__(metacls, *args, **kwargs):
ret = super().__new__(metacls, *args, **kwargs)
lowercase_names = [x.name for x in ret if not x.name.isupper()] # type: ignore
if lowercase_names:
raise TypeError(f"Names {lowercase_names} should be uppercase")
return ret
[docs]
class StrictEnum(str, Enum, metaclass=StrictEnumMeta):
"""All members should exist in the Backend, and there will be no extras."""
class SubsetEnumMeta(StrictEnumMeta):
def __call__(self, value, *args, **kwargs): # type: ignore
"""Return given value if it is a string and not a member of the enum.
If the value is not a string or is an enum member, default enum behavior
is applied. Type checking will complain if provided arbitrary string.
Returns:
Union[str, SubsetEnum]: If the value is a string and not a member of the
enum, the string is returned as is. Otherwise, the corresponding enum
member is returned.
Raises:
ValueError: If the value is not a string and cannot be converted to an enum
member.
"""
if isinstance(value, str) and not isinstance(value, self):
return value
return super().__call__(value, *args, **kwargs)
[docs]
class SubsetEnum(StrictEnum, metaclass=SubsetEnumMeta):
"""All members should exist in the Backend, but there may be extras."""
CALCULATE_TIMEOUT = "CALCULATE_TIMEOUT"
"""Sentinel used to implement ``myfunc(timeout=CalculateTimeout)``
This signifies that the function should calculate a suitable non-zero
timeout itself
"""
CalculatableTimeout = float | None | Literal["CALCULATE_TIMEOUT"]
[docs]
class NotConnected(Exception):
"""Exception to be raised if a `Device.connect` is cancelled.
:param errors:
Mapping of device name to Exception or another NotConnected.
Alternatively a string with the signal error text.
"""
_indent_width = " "
def __init__(self, errors: str | Mapping[str, Exception]):
self._errors = errors
@property
def sub_errors(self) -> Mapping[str, Exception]:
if isinstance(self._errors, dict):
return self._errors.copy()
else:
return {}
def _format_sub_errors(self, name: str, error: Exception, indent="") -> str:
if isinstance(error, NotConnected):
error_txt = ":" + error.format_error_string(indent + self._indent_width)
elif isinstance(error, Exception):
error_txt = ": " + err_str + "\n" if (err_str := str(error)) else "\n"
else:
raise RuntimeError(
f"Unexpected type `{type(error)}`, expected an Exception"
)
string = f"{indent}{name}: {type(error).__name__}" + error_txt
return string
def __str__(self) -> str:
return self.format_error_string(indent="")
[docs]
@classmethod
def with_other_exceptions_logged(
cls, exceptions: Mapping[str, Exception]
) -> NotConnected:
for name, exception in exceptions.items():
if not isinstance(exception, NotConnected):
logger.exception(
f"device `{name}` raised unexpected exception "
f"{type(exception).__name__}",
exc_info=exception,
)
return NotConnected(exceptions)
[docs]
@dataclass(frozen=True)
class WatcherUpdate(Generic[T]):
"""A dataclass such that, when expanded, it provides the kwargs for a watcher."""
current: T
"""The current value, where it currently is."""
initial: T
"""The initial value, where it was when it started."""
target: T
"""The target value, where it will be when it finishes."""
name: str | None = None
"""An optional name for the device, if available."""
unit: str | None = None
"""Units of the value, if applicable."""
precision: float | None = None
"""How many decimal places the value should be displayed to."""
fraction: float | None = None
"""The fraction of the way between initial and target."""
time_elapsed: float | None = None
"""The time elapsed since the start of the operation."""
time_remaining: float | None = None
"""The time remaining until the operation completes."""
[docs]
async def wait_for_connection(**coros: Awaitable[None]):
"""Call many underlying signals, accumulating exceptions and returning them.
Expected kwargs should be a mapping of names to coroutine tasks to execute.
"""
exceptions: dict[str, Exception] = {}
if len(coros) == 1:
# Single device optimization
name, coro = coros.popitem()
try:
await coro
except Exception as e:
exceptions[name] = e
else:
# Use gather to connect in parallel
results = await asyncio.gather(*coros.values(), return_exceptions=True)
for name, result in zip(coros, results, strict=False):
if isinstance(result, Exception):
exceptions[name] = result
if exceptions:
raise NotConnected.with_other_exceptions_logged(exceptions)
[docs]
def get_dtype(datatype: type) -> np.dtype:
"""Get the runtime dtype from a numpy ndarray type annotation.
```python
>>> from ophyd_async.core import Array1D
>>> import numpy as np
>>> get_dtype(Array1D[np.int8])
dtype('int8')
```
"""
if not get_origin(datatype) == np.ndarray:
raise TypeError(f"Expected Array1D[dtype], got {datatype}")
# datatype = numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]
# so extract numpy.float64 from it
return np.dtype(get_args(get_args(datatype)[1])[0])
[docs]
def get_enum_cls(datatype: type | None) -> type[StrictEnum] | None:
"""Get the enum class from a datatype.
:raises TypeError: if type is not a [](#StrictEnum) or [](#SubsetEnum) subclass
```python
>>> from ophyd_async.core import StrictEnum
>>> from collections.abc import Sequence
>>> class MyEnum(StrictEnum):
... A = "A value"
>>> get_enum_cls(str)
>>> get_enum_cls(MyEnum)
<enum 'MyEnum'>
>>> get_enum_cls(Sequence[MyEnum])
<enum 'MyEnum'>
```
"""
if get_origin(datatype) is Sequence:
datatype = get_args(datatype)[0]
if datatype and issubclass(datatype, Enum):
if not issubclass(datatype, StrictEnum):
raise TypeError(
f"{datatype} should inherit from ophyd_async.core.SubsetEnum "
"or ophyd_async.core.StrictEnum"
)
return datatype
[docs]
def get_unique(values: dict[str, T], types: str) -> T:
"""If all values are the same, return that value, otherwise raise TypeError.
```python
>>> get_unique({"a": 1, "b": 1}, "integers")
1
>>> get_unique({"a": 1, "b": 2}, "integers")
Traceback (most recent call last):
...
TypeError: Differing integers: a has 1, b has 2
```
"""
set_values = set(values.values())
if len(set_values) != 1:
diffs = ", ".join(f"{k} has {v}" for k, v in values.items())
raise TypeError(f"Differing {types}: {diffs}")
return set_values.pop()
async def merge_gathered_dicts(
coros: Iterable[Awaitable[dict[str, T]]],
) -> dict[str, T]:
"""Merge dictionaries produced by a sequence of coroutines.
Can be used for merging `read()` or `describe()`.
:example:
```python
combined_read = await merge_gathered_dicts(s.read() for s in signals)
```
"""
ret: dict[str, T] = {}
for result in await asyncio.gather(*coros):
ret.update(result)
return ret
[docs]
async def gather_dict(coros: dict[T, Awaitable[V]]) -> dict[T, V]:
"""Take named coros and return a dict of their name to their return value."""
values = await asyncio.gather(*coros.values())
return dict(zip(coros, values, strict=True))
[docs]
def in_micros(t: float) -> int:
"""Convert between a seconds and microseconds.
:param t: A time in seconds
:return: A time in microseconds, rounded up to the nearest whole microsecond
:raises ValueError: if t < 0
"""
if t < 0:
raise ValueError(f"Expected a positive time in seconds, got {t!r}")
return int(np.ceil(t * 1e6))
def get_origin_class(annotatation: Any) -> type | None:
origin = get_origin(annotatation) or annotatation
if isinstance(origin, type):
return origin
[docs]
class Reference(Generic[T]):
"""Hide an object behind a reference.
Used to opt out of the naming/parent-child relationship of `Device`.
:example:
```python
class DeviceWithRefToSignal(Device):
def __init__(self, signal: SignalRW[int]):
self.signal_ref = Reference(signal)
super().__init__()
def set(self, value) -> AsyncStatus:
return self.signal_ref().set(value + 1)
```
"""
def __init__(self, obj: T):
self._obj = obj
def __call__(self) -> T:
return self._obj
[docs]
class LazyMock:
"""A lazily created Mock to be used when connecting in mock mode.
Creating Mocks is reasonably expensive when each Device (and Signal)
requires its own, and the tree is only used when ``Signal.set()`` is
called. This class allows a tree of lazily connected Mocks to be
constructed so that when the leaf is created, so are its parents.
Any calls to the child are then accessible from the parent mock.
```python
>>> parent = LazyMock()
>>> child = parent.child("child")
>>> child_mock = child()
>>> child_mock() # doctest: +ELLIPSIS
<Mock name='mock.child()' id='...'>
>>> parent_mock = parent()
>>> parent_mock.mock_calls
[call.child()]
```
"""
def __init__(self, name: str = "", parent: LazyMock | None = None) -> None:
self.parent = parent
self.name = name
self._mock: Mock | None = None
[docs]
def child(self, name: str) -> LazyMock:
"""Return a child of this LazyMock with the given name."""
return LazyMock(name, self)
def __call__(self) -> Mock:
if self._mock is None:
self._mock = Mock(spec=object)
if self.parent is not None:
self.parent().attach_mock(self._mock, self.name)
return self._mock