from __future__ import annotations
import asyncio
import logging
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
from dataclasses import dataclass
from enum import Enum, EnumMeta, StrEnum
from typing import (
Any,
Generic,
Literal,
ParamSpec,
TypeVar,
get_args,
get_origin,
)
import numpy as np
from pydantic import BaseModel, ConfigDict
T = TypeVar("T")
V = TypeVar("V")
P = ParamSpec("P")
Callback = Callable[[T], None]
DEFAULT_TIMEOUT = 10.0
logger = logging.getLogger("ophyd_async")
class UppercaseNameEnumMeta(EnumMeta):
def __new__(cls, *args, **kwargs):
ret = super().__new__(cls, *args, **kwargs)
lowercase_names = [x.name for x in ret if not x.name.isupper()] # type: ignore
if lowercase_names:
raise TypeError(f"Names {lowercase_names} should be uppercase")
return ret
class AnyStringUppercaseNameEnumMeta(UppercaseNameEnumMeta):
    """Uppercase-name enum metaclass whose call lets plain strings through."""

    def __call__(cls, value, *args, **kwargs):  # type: ignore
        """Look up an enum member, or pass a plain string straight through.

        Any ``str`` that is not already an instance of this enum is returned
        unchanged. That includes a plain string equal to a member's value:
        it is NOT converted to the member. Everything else, including member
        instances, goes through the normal enum lookup. Type checking will
        complain if provided an arbitrary string.

        Returns:
            Union[str, SubsetEnum]: The string itself for plain strings,
            otherwise the result of the standard enum lookup.

        Raises:
            ValueError: If the value is not a string and the standard enum
                lookup cannot match it to a member.
        """
        passes_through = isinstance(value, str) and not isinstance(value, cls)
        if not passes_through:
            return super().__call__(value, *args, **kwargs)
        return value
[docs]
class StrictEnum(StrEnum, metaclass=UppercaseNameEnumMeta):
    """All members should exist in the Backend, and there will be no extras.

    Member names must be uppercase: `UppercaseNameEnumMeta` raises
    ``TypeError`` at class creation otherwise.
    """
[docs]
class SubsetEnum(StrEnum, metaclass=AnyStringUppercaseNameEnumMeta):
    """All members should exist in the Backend, but there may be extras.

    Member names must be uppercase. Calling the class with a ``str`` that is
    not already a member instance returns that string unchanged instead of
    raising, so values outside the declared members can pass through.
    """
[docs]
class SupersetEnum(StrEnum, metaclass=UppercaseNameEnumMeta):
    """Some members should exist in the Backend, and there should be no extras.

    Member names must be uppercase: `UppercaseNameEnumMeta` raises
    ``TypeError`` at class creation otherwise.
    """
# Union of the three enum bases defined in this module. `get_enum_cls`
# uses it in ``issubclass(datatype, EnumTypes)`` to accept an enum datatype.
EnumTypes = StrictEnum | SubsetEnum | SupersetEnum
CALCULATE_TIMEOUT = "CALCULATE_TIMEOUT"
"""Sentinel used to implement ``myfunc(timeout=CalculateTimeout)``
This signifies that the function should calculate a suitable non-zero
timeout itself
"""
CalculatableTimeout = float | None | Literal["CALCULATE_TIMEOUT"]
[docs]
class NotConnectedError(Exception):
    """Exception to be raised if a `Device.connect` is cancelled.

    Errors from child devices may be nested. ``str()`` renders them as an
    indented tree of ``name: ExceptionType[: message]`` lines, with each
    nesting level indented by ``_indent_width``.

    :param errors:
        Mapping of device name to Exception or another NotConnectedError.
        Alternatively a string with the signal error text.
    """

    # Extra prefix added for each level of nested NotConnectedError.
    _indent_width = " "

    def __init__(self, errors: str | Mapping[str, Exception]):
        super().__init__(errors)
        self._errors = errors

    @property
    def sub_errors(self) -> Mapping[str, Exception]:
        """Return a copy of the name -> exception mapping.

        Returns ``{}`` when the error was constructed from a string.
        """
        if isinstance(self._errors, Mapping):
            return dict(self._errors)
        return {}

    def _format_sub_errors(self, name: str, error: Exception, indent="") -> str:
        """Format a single ``name: ExceptionType...`` entry at ``indent``.

        :raises RuntimeError: if ``error`` is not an Exception instance.
        """
        if isinstance(error, NotConnectedError):
            # A nested error expands onto the following lines, indented one
            # level deeper than this entry.
            error_txt = ":" + error.format_error_string(indent + self._indent_width)
        elif isinstance(error, Exception):
            err_str = str(error)
            # Exceptions with an empty message show only their type name.
            error_txt = f": {err_str}\n" if err_str else "\n"
        else:
            raise RuntimeError(
                f"Unexpected type `{type(error)}`, expected an Exception"
            )
        return f"{indent}{name}: {type(error).__name__}{error_txt}"

    def format_error_string(self, indent: str = "") -> str:
        """Render this error, including any nested errors, as text.

        A string error renders as a leading space, the text, and a newline.
        A mapping renders as a newline followed by one entry line per item;
        nested NotConnectedErrors expand recursively.

        :param indent: Prefix applied to each top-level entry line.
        :raises RuntimeError: if the stored errors are neither a ``str`` nor
            a mapping.
        """
        if isinstance(self._errors, str):
            return " " + self._errors + "\n"
        if not isinstance(self._errors, Mapping):
            raise RuntimeError(
                f"Unexpected type `{type(self._errors)}` expected `str` or `dict`"
            )
        return "\n" + "".join(
            self._format_sub_errors(name, error, indent=indent)
            for name, error in self._errors.items()
        )

    def __str__(self) -> str:
        return self.format_error_string(indent="")

    @classmethod
    def with_other_exceptions_logged(
        cls, exceptions: Mapping[str, Exception]
    ) -> NotConnectedError:
        """Build a NotConnectedError from ``exceptions``, logging surprises.

        Nested NotConnectedErrors are not logged. Every other exception is
        logged with ``exc_info`` so that its traceback is recorded before it
        is folded into the combined error.
        """
        for name, exception in exceptions.items():
            if not isinstance(exception, NotConnectedError):
                logger.exception(
                    "device `%s` raised unexpected exception %s",
                    name,
                    type(exception).__name__,
                    exc_info=exception,
                )
        return cls(exceptions)
[docs]
@dataclass(frozen=True)
class WatcherUpdate(Generic[T]):
    """A dataclass such that, when expanded, it provides the kwargs for a watcher.

    Instances are frozen (immutable). ``current``, ``initial`` and ``target``
    are required; every other field is optional and defaults to ``None``.
    Field names and order form this class's interface.
    """

    current: T
    """The current value, where it currently is."""
    initial: T
    """The initial value, where it was when it started."""
    target: T
    """The target value, where it will be when it finishes."""
    name: str | None = None
    """An optional name for the device, if available."""
    unit: str | None = None
    """Units of the value, if applicable."""
    # NOTE(review): a count of decimal places is annotated as float here;
    # confirm whether int is intended.
    precision: float | None = None
    """How many decimal places the value should be displayed to."""
    fraction: float | None = None
    """The fraction of the way between initial and target."""
    time_elapsed: float | None = None
    """The time elapsed since the start of the operation."""
    time_remaining: float | None = None
    """The time remaining until the operation completes."""
[docs]
async def wait_for_connection(**coros: Awaitable[None]):
"""Call many underlying signals, accumulating exceptions and returning them.
Expected kwargs should be a mapping of names to coroutine tasks to execute.
"""
exceptions: dict[str, Exception] = {}
if len(coros) == 1:
# Single device optimization
name, coro = coros.popitem()
try:
await coro
except Exception as exc:
exceptions[name] = exc
else:
# Use gather to connect in parallel
results = await asyncio.gather(*coros.values(), return_exceptions=True)
for name, result in zip(coros, results, strict=False):
if isinstance(result, Exception):
exceptions[name] = result
if exceptions:
raise NotConnectedError.with_other_exceptions_logged(exceptions)
[docs]
def get_dtype(datatype: type) -> np.dtype:
    """Return the numpy dtype carried by an ndarray type annotation.

    The annotation must be a parametrized ``np.ndarray`` whose second type
    argument is itself a parametrized ``np.dtype``, e.g.
    ``numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]``.

    ```python
    >>> from ophyd_async.core import Array1D
    >>> import numpy as np
    >>> get_dtype(Array1D[np.int8])
    dtype('int8')
    ```

    :raises TypeError: if ``datatype`` is not a parametrized ``np.ndarray``.
    """
    if get_origin(datatype) is not np.ndarray:
        raise TypeError(f"Expected Array1D[dtype], got {datatype}")
    dtype_alias = get_args(datatype)[1]  # e.g. numpy.dtype[numpy.float64]
    scalar_type = get_args(dtype_alias)[0]  # e.g. numpy.float64
    return np.dtype(scalar_type)
[docs]
def get_enum_cls(datatype: type | None) -> type[EnumTypes] | None:
    """Return the enum class a datatype refers to, if any.

    ``Sequence[X]`` is unwrapped to ``X``, and generic aliases are reduced to
    their origin class before checking.

    :returns: The enum class, or ``None`` if the datatype is not an Enum
        subclass.
    :raises TypeError: if the datatype is an Enum subclass that is not a
        [](#StrictEnum), [](#SubsetEnum) or [](#SupersetEnum) subclass.

    ```python
    >>> from ophyd_async.core import StrictEnum
    >>> from collections.abc import Sequence
    >>> class MyEnum(StrictEnum):
    ...     A = "A value"
    >>> get_enum_cls(str)
    >>> get_enum_cls(MyEnum)
    <enum 'MyEnum'>
    >>> get_enum_cls(Sequence[MyEnum])
    <enum 'MyEnum'>
    ```
    """
    if get_origin(datatype) is Sequence:
        # Sequence[MyEnum] -> MyEnum
        datatype = get_args(datatype)[0]
    candidate = get_origin_class(datatype)
    # NOTE: this is a truthiness test, and Enum classes with no members are
    # falsy (EnumMeta defines __len__), so such classes return None here.
    if not candidate or not issubclass(candidate, Enum):
        return None
    if issubclass(candidate, EnumTypes):
        return candidate
    raise TypeError(
        f"{candidate} should inherit from ophyd_async.core.SubsetEnum "
        "or ophyd_async.core.StrictEnum "
        "or ophyd_async.core.SupersetEnum."
    )
[docs]
def get_unique(values: dict[str, T], types: str) -> T:
    """If all values are the same, return that value, otherwise raise TypeError.

    Values are compared via a ``set``, so they must be hashable.

    ```python
    >>> get_unique({"a": 1, "b": 1}, "integers")
    1
    >>> get_unique({"a": 1, "b": 2}, "integers")
    Traceback (most recent call last):
     ...
    TypeError: Differing integers: a has 1, b has 2
    ```

    :param values: Mapping of source name to its value.
    :param types: Plural description of the values, used in error messages.
    :raises TypeError: if ``values`` is empty or holds more than one
        distinct value.
    """
    if not values:
        # Without this guard the message would read "Differing <types>: "
        # with nothing listed, which is misleading.
        raise TypeError(f"No {types} given, expected at least one")
    distinct = set(values.values())
    if len(distinct) != 1:
        diffs = ", ".join(f"{k} has {v}" for k, v in values.items())
        raise TypeError(f"Differing {types}: {diffs}")
    return distinct.pop()
async def merge_gathered_dicts(
    coros: Iterable[Awaitable[dict[str, T]]],
) -> dict[str, T]:
    """Run awaitables concurrently and merge the dicts they produce.

    Results are merged in the order the awaitables were given, so a key
    produced by a later awaitable overwrites the same key from an earlier
    one. Can be used for merging `read()` or `describe()`.

    :example:
    ```python
    combined_read = await merge_gathered_dicts(s.read() for s in signals)
    ```
    """
    partials = await asyncio.gather(*coros)
    merged: dict[str, T] = {}
    for partial in partials:
        merged |= partial
    return merged
[docs]
async def gather_dict(coros: Mapping[T, Awaitable[V]]) -> dict[T, V]:
"""Take named coros and return a dict of their name to their return value."""
values = await asyncio.gather(*coros.values())
return dict(zip(coros, values, strict=True))
[docs]
def in_micros(t: float) -> int:
    """Convert a duration in seconds to whole microseconds.

    Fractional microseconds are rounded up (ceiling of ``t * 1e6``, subject
    to floating-point rounding of that product).

    :param t: A non-negative time in seconds; zero is accepted.
    :return: The time in microseconds, rounded up to the nearest integer.
    :raises ValueError: if ``t`` is negative.
    """
    if t < 0:
        raise ValueError(f"Expected a positive time in seconds, got {t!r}")
    microseconds = t * 1e6
    return int(np.ceil(microseconds))
def get_origin_class(annotatation: Any) -> type | None:
    """Return the runtime class behind an annotation, or ``None``.

    Generic aliases such as ``list[int]`` are reduced to their origin
    (``list``). Plain classes are returned as-is. Anything whose result is
    not a class, e.g. ``Literal[...]`` or ``None``, gives ``None``.
    """
    candidate = get_origin(annotatation) or annotatation
    return candidate if isinstance(candidate, type) else None
[docs]
class Reference(Generic[T]):
    """Hide an object behind a reference.

    Used to opt out of the naming/parent-child relationship of `Device`.
    Calling the reference returns the wrapped object unchanged.

    :example:
    ```python
    class DeviceWithRefToSignal(Device):
        def __init__(self, signal: SignalRW[int]):
            self.signal_ref = Reference(signal)
            super().__init__()

        def set(self, value) -> AsyncStatus:
            return self.signal_ref().set(value + 1)
    ```
    """

    def __init__(self, obj: T):
        # Keep the object on a private attribute rather than a public one.
        self._obj = obj

    def __call__(self) -> T:
        """Return the referenced object."""
        return self._obj
[docs]
class ConfinedModel(BaseModel):
    """Pydantic model that rejects any field not declared in its schema.

    ``extra="forbid"`` makes validation fail on unexpected keys instead of
    silently ignoring them.
    """

    model_config = ConfigDict(extra="forbid")
[docs]
def error_if_none(value: T | None, msg: str) -> T:
    """Return ``value`` unchanged, raising if it is ``None``.

    Supports the pattern where an attribute starts as ``None`` in
    ``__init__``, is filled in by one method, and is required by a later one.
    Only ``None`` is rejected; other falsy values such as ``0`` or ``""``
    are returned.

    :param value: The value to check.
    :param msg: The `RuntimeError` message to raise if it is None.
    :raises RuntimeError: If the value is None.
    :returns: The value, if it is not None.
    """
    if value is not None:
        return value
    raise RuntimeError(msg)