Source code for ophyd_async.core.signal

from __future__ import annotations

import asyncio
import functools
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    Generic,
    Mapping,
    Optional,
    Tuple,
    Type,
    Union,
)

from bluesky.protocols import (
    DataKey,
    Locatable,
    Location,
    Movable,
    Reading,
    Status,
    Subscribable,
)

from ophyd_async.core.mock_signal_backend import MockSignalBackend
from ophyd_async.protocols import AsyncConfigurable, AsyncReadable, AsyncStageable

from .async_status import AsyncStatus
from .device import Device
from .signal_backend import SignalBackend
from .soft_signal_backend import SignalMetadata, SoftSignalBackend
from .utils import DEFAULT_TIMEOUT, CalculatableTimeout, CalculateTimeout, Callback, T


def _add_timeout(func):
    """Decorate an async method so each call is bounded by ``self._timeout``.

    The returned coroutine function calls ``func`` with the same arguments and
    awaits the result through ``asyncio.wait_for``, using the ``_timeout``
    attribute of the instance it is called on as the limit.
    """

    @functools.wraps(func)
    async def bounded(self: Signal, *args, **kwargs):
        # Create the coroutine first, then look up the limit on the instance
        pending = func(self, *args, **kwargs)
        limit = self._timeout
        return await asyncio.wait_for(pending, limit)

    return bounded


def _fail(self, other, *args, **kwargs):
    """Comparison hook that refuses to compare one Signal against another.

    Returns ``NotImplemented`` when ``other`` is not a Signal, and raises
    ``TypeError`` when it is.
    """
    if not isinstance(other, Signal):
        return NotImplemented
    raise TypeError(
        "Can't compare two Signals, did you mean await signal.get_value() instead?"
    )


class Signal(Device, Generic[T]):
    """A Device with the concept of a value, with R, RW, W and X flavours"""

    def __init__(
        self,
        backend: Optional[SignalBackend[T]] = None,
        timeout: Optional[float] = DEFAULT_TIMEOUT,
        name: str = "",
    ) -> None:
        """Store the backend and timeout, then initialise the Device with ``name``.

        Parameters
        ----------
        backend:
            The backend to talk through; may be supplied later via `connect`.
        timeout:
            Stored as ``self._timeout``.
        name:
            Passed to the ``Device`` initialiser.
        """
        self._timeout = timeout
        self._backend = backend
        super().__init__(name)

    async def connect(
        self,
        mock: bool = False,
        timeout: float = DEFAULT_TIMEOUT,
        force_reconnect: bool = False,
        backend: Optional[SignalBackend[T]] = None,
    ):
        """Connect the backend, reusing an earlier connection where possible.

        Parameters
        ----------
        mock:
            If true and the current backend is not a ``MockSignalBackend``, the
            current backend is wrapped in one before connecting.
        timeout:
            Passed to ``backend.connect``.
        force_reconnect:
            Start a new connection even if a previous one could be reused.
        backend:
            Backend to use; must be the same object as any backend already set.

        Raises
        ------
        ValueError
            If ``backend`` differs from the backend already stored.
        RuntimeError
            If ``mock`` differs from the value used in the previous connect, or
            if there is no backend to connect.
        """
        if backend:
            if self._backend and backend is not self._backend:
                raise ValueError("Backend at connection different from previous one.")
            self._backend = backend

        # NOTE(review): _previous_connect_was_mock, _connect_task and log are not
        # defined in this class; presumably they are provided by Device.
        if (
            self._previous_connect_was_mock is not None
            and self._previous_connect_was_mock != mock
        ):
            raise RuntimeError(
                f"`connect(mock={mock})` called on a `Signal` where the previous "
                f"connect was `mock={self._previous_connect_was_mock}`. Changing mock "
                "value between connects is not permitted."
            )
        self._previous_connect_was_mock = mock

        if mock and not issubclass(type(self._backend), MockSignalBackend):
            # Using a soft backend, look to the initial value
            self._backend = MockSignalBackend(initial_backend=self._backend)

        if self._backend is None:
            raise RuntimeError("`connect` called on signal without backend")

        previous_task = self._connect_task
        # A finished task is only reusable if it succeeded. Check cancelled()
        # before exception(): Task.exception() raises CancelledError on a
        # cancelled task, which would otherwise escape from this method.
        can_use_previous_connection: bool = previous_task is not None and not (
            previous_task.done()
            and (previous_task.cancelled() or previous_task.exception() is not None)
        )
        if force_reconnect or not can_use_previous_connection:
            self.log.debug(f"Connecting to {self.source}")
            self._connect_task = asyncio.create_task(
                self._backend.connect(timeout=timeout)
            )
        else:
            self.log.debug(f"Reusing previous connection to {self.source}")
        assert (
            self._connect_task
        ), "this assert is for type analysis and will never fail"
        await self._connect_task

    @property
    def source(self) -> str:
        """Like ca://PV_PREFIX:SIGNAL, or "" if not set"""
        if self._backend is None:
            # No backend yet, so there is no source to report
            return ""
        return self._backend.source(self.name)

    # Comparing Signals is almost always a missing ``await sig.get_value()``
    __lt__ = __le__ = __eq__ = __ge__ = __gt__ = __ne__ = _fail

    def __hash__(self):
        # Restore the default implementation so we can use in a set or dict
        return hash(id(self))
class _SignalCache(Generic[T]):
    """Keep the latest reading and value delivered through a backend callback.

    On construction the cache registers ``self._callback`` with the backend via
    ``set_callback``; `close` unregisters it by passing ``None``.
    """

    def __init__(self, backend: SignalBackend[T], signal: Signal):
        self._signal = signal
        self._staged = False
        # Maps each listener to True if it wants the bare value, or False if it
        # wants the ``{signal.name: reading}`` dict.
        self._listeners: Dict[Callback, bool] = {}
        # Set once the first update has been delivered to _callback
        self._valid = asyncio.Event()
        self._reading: Optional[Reading] = None
        self._value: Optional[T] = None

        self.backend = backend
        signal.log.debug(f"Making subscription on source {signal.source}")
        backend.set_callback(self._callback)

    def close(self):
        """Unregister the backend callback."""
        self.backend.set_callback(None)
        self._signal.log.debug(f"Closing subscription on source {self._signal.source}")

    async def get_reading(self) -> Reading:
        """Wait for the first update, then return the most recent reading."""
        await self._valid.wait()
        assert self._reading is not None, "Monitor not working"
        return self._reading

    async def get_value(self) -> T:
        """Wait for the first update, then return the most recent value."""
        await self._valid.wait()
        assert self._value is not None, "Monitor not working"
        return self._value

    def _callback(self, reading: Reading, value: T):
        """Record an update from the backend and pass it on to every listener."""
        self._signal.log.debug(
            f"Updated subscription: reading of source {self._signal.source} changed "
            f"from {self._reading} to {reading}"
        )
        self._reading = reading
        self._value = value
        self._valid.set()
        # Iterate over a snapshot: a listener may subscribe or unsubscribe while
        # being notified, and mutating the dict mid-iteration raises RuntimeError.
        for function, want_value in list(self._listeners.items()):
            self._notify(function, want_value)

    def _notify(self, function: Callback, want_value: bool):
        """Call ``function`` with the value, or with ``{signal.name: reading}``."""
        if want_value:
            function(self._value)
        else:
            function({self._signal.name: self._reading})

    def subscribe(self, function: Callback, want_value: bool) -> None:
        """Register ``function``; notify it immediately if an update was received."""
        self._listeners[function] = want_value
        if self._valid.is_set():
            self._notify(function, want_value)

    def unsubscribe(self, function: Callback) -> bool:
        """Remove ``function`` and return whether the cache is still needed.

        Raises ``KeyError`` if ``function`` was not subscribed.
        """
        self._listeners.pop(function)
        return self._staged or bool(self._listeners)

    def set_staged(self, staged: bool) -> bool:
        """Record the staged flag and return whether the cache is still needed."""
        self._staged = staged
        return self._staged or bool(self._listeners)
class SignalR(Signal[T], AsyncReadable, AsyncStageable, Subscribable):
    """Signal that can be read from and monitored"""

    # No cache exists until something subscribes to or stages the signal
    _cache: Optional[_SignalCache] = None

    def _backend_or_cache(
        self, cached: Optional[bool]
    ) -> Union[_SignalCache, SignalBackend]:
        """Pick the object to read from: the cache or the backend itself."""
        # With no explicit preference, use the cache exactly when one exists
        use_cache = (self._cache is not None) if cached is None else cached
        if not use_cache:
            return self._backend
        assert self._cache, f"{self.source} not being monitored"
        return self._cache

    def _get_cache(self) -> _SignalCache:
        """Return the cache, creating it on first use."""
        if self._cache:
            return self._cache
        self._cache = _SignalCache(self._backend, self)
        return self._cache

    def _del_cache(self, needed: bool):
        """Close and drop the cache unless it is still ``needed``."""
        if not self._cache or needed:
            return
        self._cache.close()
        self._cache = None

    @_add_timeout
    async def read(self, cached: Optional[bool] = None) -> Dict[str, Reading]:
        """Return a single item dict with the reading in it"""
        provider = self._backend_or_cache(cached)
        return {self.name: await provider.get_reading()}

    @_add_timeout
    async def describe(self) -> Dict[str, DataKey]:
        """Return a single item dict with the descriptor in it"""
        datakey = await self._backend.get_datakey(self.source)
        return {self.name: datakey}

    @_add_timeout
    async def get_value(self, cached: Optional[bool] = None) -> T:
        """The current value"""
        provider = self._backend_or_cache(cached)
        current = await provider.get_value()
        self.log.debug(f"get_value() on source {self.source} returned {current}")
        return current

    def subscribe_value(self, function: Callback[T]):
        """Subscribe to updates in value of a device"""
        cache = self._get_cache()
        cache.subscribe(function, want_value=True)

    def subscribe(self, function: Callback[Dict[str, Reading]]) -> None:
        """Subscribe to updates in the reading"""
        cache = self._get_cache()
        cache.subscribe(function, want_value=False)

    def clear_sub(self, function: Callback) -> None:
        """Remove a subscription."""
        still_needed = self._get_cache().unsubscribe(function)
        self._del_cache(still_needed)

    @AsyncStatus.wrap
    async def stage(self) -> None:
        """Start caching this signal"""
        cache = self._get_cache()
        cache.set_staged(True)

    @AsyncStatus.wrap
    async def unstage(self) -> None:
        """Stop caching this signal"""
        still_needed = self._get_cache().set_staged(False)
        self._del_cache(still_needed)
class SignalW(Signal[T], Movable):
    """Signal that can be set"""

    def set(
        self, value: T, wait=True, timeout: CalculatableTimeout = CalculateTimeout
    ) -> AsyncStatus:
        """Set the value and return a status saying when it's done"""
        # The CalculateTimeout sentinel means "use this signal's own timeout"
        timeout = self._timeout if timeout is CalculateTimeout else timeout

        async def do_set():
            # Log either side of the put so a hanging put is visible in the log
            self.log.debug(f"Putting value {value} to backend at source {self.source}")
            backend = self._backend
            await backend.put(value, wait=wait, timeout=timeout)
            self.log.debug(
                f"Successfully put value {value} to backend at source {self.source}"
            )

        pending = do_set()
        return AsyncStatus(pending)
class SignalRW(SignalR[T], SignalW[T], Locatable):
    """Signal that can be both read and set"""

    async def locate(self) -> Location:
        """Return the backend's setpoint together with the current value."""
        # Fetch the setpoint before the readback, one after the other
        setpoint = await self._backend.get_setpoint()
        readback = await self.get_value()
        return {"setpoint": setpoint, "readback": readback}
class SignalX(Signal):
    """Signal that puts the default value"""

    def trigger(
        self, wait=True, timeout: CalculatableTimeout = CalculateTimeout
    ) -> AsyncStatus:
        """Trigger the action and return a status saying when it's done"""
        # The CalculateTimeout sentinel means "use this signal's own timeout"
        effective_timeout = self._timeout if timeout is CalculateTimeout else timeout
        # Putting None is how the trigger is sent to the backend
        return AsyncStatus(
            self._backend.put(None, wait=wait, timeout=effective_timeout)
        )
def soft_signal_rw(
    datatype: Optional[Type[T]] = None,
    initial_value: Optional[T] = None,
    name: str = "",
    units: str | None = None,
    precision: int | None = None,
) -> SignalRW[T]:
    """Create a read-writable Signal backed by a SoftSignalBackend.

    ``units`` and ``precision`` are packed into a ``SignalMetadata`` that is
    handed to the backend.
    """
    backend = SoftSignalBackend(
        datatype,
        initial_value,
        metadata=SignalMetadata(units=units, precision=precision),
    )
    return SignalRW(backend, name=name)
def soft_signal_r_and_setter(
    datatype: Optional[Type[T]] = None,
    initial_value: Optional[T] = None,
    name: str = "",
    units: str | None = None,
    precision: int | None = None,
) -> Tuple[SignalR[T], Callable[[T], None]]:
    """Create a read-only Signal plus the callable that updates its value.

    The returned tuple holds a ``SignalR`` on a ``SoftSignalBackend`` and that
    backend's ``set_value`` method, so the owning device can change the value
    internally. ``units`` and ``precision`` are packed into a ``SignalMetadata``
    that is handed to the backend.

    Use soft_signal_rw if you want a device that is externally modifiable
    """
    soft_backend = SoftSignalBackend(
        datatype,
        initial_value,
        metadata=SignalMetadata(units=units, precision=precision),
    )
    read_only = SignalR(soft_backend, name=name)
    return read_only, soft_backend.set_value
def _generate_assert_error_msg(
    name: str, expected_result: str, actual_result: str
) -> str:
    """Build a colourised, multi-line message for a failed assertion helper.

    ``name`` is wrapped in the yellow ANSI escape and both results in the red
    one; each coloured span is closed with the reset escape.
    """
    yellow, red, reset = "\033[93m", "\033[91m", "\033[0m"
    rows = (
        f"Expected {yellow}{name}{reset} to produce",
        f"{red}{expected_result}{reset}",
        "but actually got ",
        f"{red}{actual_result}{reset}",
    )
    return "\n".join(rows)
async def assert_value(signal: SignalR[T], value: Any) -> None:
    """Assert that a signal's current value equals an expected value.

    Parameters
    ----------
    signal:
        Signal whose ``get_value`` is awaited.
    value:
        The value the signal is expected to hold.

    Notes
    -----
    Example usage::

        await assert_value(signal, value)

    """
    observed = await signal.get_value()
    assert observed == value, _generate_assert_error_msg(
        signal.name, value, observed
    )
async def assert_reading(
    readable: AsyncReadable, expected_reading: Mapping[str, Reading]
) -> None:
    """Assert that a readable produces the expected readings.

    Parameters
    ----------
    readable:
        Object whose ``read`` coroutine is awaited to obtain the readings.
    expected_reading:
        The readings ``readable.read()`` is expected to return.

    Notes
    -----
    Example usage::

        await assert_reading(readable, reading)

    """
    observed = await readable.read()
    assert expected_reading == observed, _generate_assert_error_msg(
        readable.name, expected_reading, observed
    )
async def assert_configuration(
    configurable: AsyncConfigurable,
    configuration: Mapping[str, Reading],
) -> None:
    """Assert that a configurable produces the expected configuration readings.

    Parameters
    ----------
    configurable:
        Object whose ``read_configuration`` coroutine is awaited.
    configuration:
        The readings ``configurable.read_configuration()`` is expected to return.

    Notes
    -----
    Example usage::

        await assert_configuration(configurable, configuration)

    """
    observed = await configurable.read_configuration()
    assert configuration == observed, _generate_assert_error_msg(
        configurable.name, configuration, observed
    )
def assert_emitted(docs: Mapping[str, list[dict]], **numbers: int):
    """Assert which documents were emitted, and how many of each.

    Parameters
    ----------
    docs:
        Mapping of document name to the list of documents emitted under it.
    numbers:
        Expected number of documents per name, given as keyword arguments in
        the same order as the keys of ``docs``.

    Notes
    -----
    Example usage::

        assert_emitted(docs, start=1, descriptor=1,
        resource=1, datum=1, event=1, stop=1)
    """
    emitted_names = list(docs)
    expected_names = list(numbers)
    # The names must match, order included, before the counts are compared
    assert emitted_names == expected_names, _generate_assert_error_msg(
        "documents", expected_names, emitted_names
    )

    actual_numbers = {}
    for doc_name, emitted in docs.items():
        actual_numbers[doc_name] = len(emitted)
    assert actual_numbers == numbers, _generate_assert_error_msg(
        "emitted", numbers, actual_numbers
    )
async def observe_value(
    signal: SignalR[T], timeout: float | None = None, done_status: Status | None = None
) -> AsyncGenerator[T, None]:
    """Subscribe to the value of a signal so it can be iterated from.

    Parameters
    ----------
    signal:
        Call subscribe_value on this at the start, and clear_sub on it at the
        end
    timeout:
        If given, how long to wait for each updated value in seconds. If an
        update is not produced in this time then raise asyncio.TimeoutError
    done_status:
        If this status is complete, stop observing and make the iterator
        return. If it raises an exception then this exception will be raised
        by the iterator.

    Notes
    -----
    Example usage::

        async for value in observe_value(sig):
            do_something_with(value)
    """
    # Value updates and the completed status arrive through the same queue
    queue: asyncio.Queue[T | Status] = asyncio.Queue()
    enqueue = queue.put_nowait

    async def next_item():
        if timeout is None:
            return await queue.get()
        return await asyncio.wait_for(queue.get(), timeout)

    if done_status is not None:
        done_status.add_callback(enqueue)
    signal.subscribe_value(enqueue)
    try:
        while True:
            item = await next_item()
            if not (done_status and item is done_status):
                yield item
                continue
            # The status itself came through the queue, so it has finished
            error = done_status.exception()
            if error:
                raise error
            break
    finally:
        signal.clear_sub(enqueue)
class _ValueChecker(Generic[T]):
    """Watch a signal until a matcher accepts one of its values.

    The most recent value seen is kept so it can be reported on timeout.
    """

    def __init__(self, matcher: Callable[[T], bool], matcher_name: str):
        self._last_value: Optional[T] = None
        self._matcher = matcher
        self._matcher_name = matcher_name

    async def _wait_for_value(self, signal: SignalR[T]):
        """Consume values from ``observe_value`` until the matcher accepts one."""
        async for candidate in observe_value(signal):
            self._last_value = candidate
            if self._matcher(candidate):
                break

    async def wait_for_value(self, signal: SignalR[T], timeout: Optional[float]):
        """Wait up to ``timeout`` for a match, raising ``TimeoutError`` otherwise."""
        watcher = self._wait_for_value(signal)
        try:
            await asyncio.wait_for(watcher, timeout)
        except asyncio.TimeoutError as e:
            # Re-raise with the last value seen so the failure can be diagnosed
            message = (
                f"{signal.name} didn't match {self._matcher_name} in {timeout}s, "
                f"last value {self._last_value!r}"
            )
            raise TimeoutError(message) from e
async def wait_for_value(
    signal: SignalR[T], match: Union[T, Callable[[T], bool]], timeout: Optional[float]
):
    """Wait for a signal to have a matching value.

    Parameters
    ----------
    signal:
        Call subscribe_value on this at the start, and clear_sub on it at the
        end
    match:
        If a callable, it should return True if the value matches. If not
        callable then value will be checked for equality with match.
    timeout:
        How long to wait for the value to match

    Raises
    ------
    TimeoutError
        If no matching value is seen within ``timeout``.

    Notes
    -----
    Example usage::

        await wait_for_value(device.acquiring, 1, timeout=1)

    Or::

        await wait_for_value(device.num_captured, lambda v: v > 45, timeout=1)
    """
    if callable(match):
        # Not every callable has __name__ (e.g. functools.partial objects or
        # instances defining __call__); fall back to repr() for the message.
        try:
            matcher_name = match.__name__
        except AttributeError:
            matcher_name = repr(match)
        checker = _ValueChecker(match, matcher_name)
    else:
        checker = _ValueChecker(lambda v: v == match, repr(match))
    await checker.wait_for_value(signal, timeout)
async def set_and_wait_for_value(
    signal: SignalRW[T],
    value: T,
    timeout: float = DEFAULT_TIMEOUT,
    status_timeout: Optional[float] = None,
) -> AsyncStatus:
    """Set a signal and monitor it until it has that value.

    Useful for busy record, or other Signals with pattern:

    - Set Signal and stash the Status
    - Read the same Signal to check the operation has started
    - Return the Status so calling code can wait for operation to complete

    The set is started first; this coroutine then waits for the signal's value
    to equal ``value`` and hands back the Status of the set without awaiting it.

    Parameters
    ----------
    signal:
        The signal to set and monitor
    value:
        The value to set it to
    timeout:
        How long to wait for the signal to have the value
    status_timeout:
        Passed to ``signal.set`` as its ``timeout``, i.e. how long the returned
        Status will wait for the set to complete

    Notes
    -----
    Example usage::

        await set_and_wait_for_value(device.acquire, 1)
    """
    set_status = signal.set(value, timeout=status_timeout)
    # Only the readback is awaited here; the caller decides when to await the set
    await wait_for_value(signal, value, timeout)
    return set_status