from collections.abc import Awaitable, Callable, Iterable, Iterator
from contextlib import contextmanager
from unittest.mock import AsyncMock, Mock
from ophyd_async.core import (
Device,
LazyMock,
MockSignalBackend,
Signal,
SignalConnector,
SignalDatatypeT,
SignalR,
)
[docs]
def get_mock(device: Device | Signal) -> Mock:
    """Fetch the mock (possibly with child mocks attached) of a Device or Signal.

    :param device: A Device or Signal that has been connected in mock mode.
    :raises AssertionError:
        If ``device._mock`` is not a `LazyMock`, i.e. the device was not
        connected in mock mode.
    """
    lazy_mock = device._mock  # noqa: SLF001
    assert isinstance(lazy_mock, LazyMock), (
        f"Device {device} not connected in mock mode"
    )
    # The LazyMock is called to obtain the Mock that is handed back.
    return lazy_mock()
def _get_mock_signal_backend(signal: Signal) -> MockSignalBackend:
    """Return the `MockSignalBackend` sitting behind ``signal``.

    :param signal: A signal that has been connected in mock mode.
    :raises AssertionError:
        If the signal's connector is not a `SignalConnector`, or if the
        connector's backend is not a `MockSignalBackend`.
    """
    signal_connector = signal._connector  # noqa: SLF001
    assert isinstance(signal_connector, SignalConnector), (
        f"Expected Signal, got {signal}"
    )
    mock_backend = signal_connector.backend
    assert isinstance(mock_backend, MockSignalBackend), (
        f"Signal {signal} not connected in mock mode"
    )
    return mock_backend
[docs]
def set_mock_value(signal: Signal[SignalDatatypeT], value: SignalDatatypeT):
    """Push ``value`` into the mock backend of ``signal``.

    :param signal: A signal connected in mock mode.
    :param value: The value passed to the mock backend's ``set_value``.
    """
    _get_mock_signal_backend(signal).set_value(value)
class _SetValuesIterator(Iterator[SignalDatatypeT]):
    """Iterator that sets a mock signal to each value as it is consumed.

    Every ``next()`` takes the next item of ``values``, sets it on ``signal``
    with `set_mock_value` and returns it. If ``require_all_consumed`` is True
    and the iterator is deleted before every value has been consumed,
    ``__del__`` raises an AssertionError naming the consumed and unconsumed
    values.
    """

    # Class-level fallback for the flag read in __del__. Per the original
    # note, the instance attribute may otherwise be unavailable by the time
    # __del__ is called, so a default is kept on the class itself.
    require_all_consumed: bool = False

    def __init__(
        self,
        signal: SignalR[SignalDatatypeT],
        values: Iterable[SignalDatatypeT],
        require_all_consumed: bool = False,
    ):
        """Store the signal and prepare to iterate over ``values``.

        :param signal: A signal connected in mock mode.
        :param values: The values to set, one per ``next()`` call.
        :param require_all_consumed:
            If True, ``values`` must be finite; it is copied into a list so
            that ``__del__`` can check that every value was consumed.
        """
        if require_all_consumed:
            # Snapshot the values up front. __del__ has to count and report
            # them, and a one-shot iterable (e.g. a generator) would already
            # be drained by iteration, so re-listing it later would come back
            # short and trigger a spurious failure.
            # require_all_consumed=False skips this, so an infinite iterable
            # is still permitted in that case.
            values = list(values)
        self.signal = signal
        self.values = values
        self.require_all_consumed = require_all_consumed
        # Number of values consumed so far (enumerate counts from 1).
        self.index = 0
        self.iterator = enumerate(values, start=1)

    def __next__(self) -> SignalDatatypeT:
        """Set the signal to the next value and return that value."""
        # Will propagate StopIteration once the values run out
        self.index, next_value = next(self.iterator)
        set_mock_value(self.signal, next_value)
        return next_value

    def __del__(self):
        """Raise AssertionError if values were required but left unconsumed.

        NOTE: Python reports exceptions raised from ``__del__`` (on stderr)
        instead of propagating them to the code that dropped the object.
        """
        if not self.require_all_consumed:
            return
        # self.values was materialised as a list in __init__ because
        # require_all_consumed=True, so this is only a cheap copy.
        values = list(self.values)
        if self.index != len(values):
            # Report the values consumed and the values yet to be consumed
            consumed = values[0 : self.index]
            to_be_consumed = values[self.index :]
            raise AssertionError(
                f"{self.signal.name}: {consumed} were consumed "
                f"but {to_be_consumed} were not consumed"
            )
[docs]
def set_mock_values(
    signal: SignalR[SignalDatatypeT],
    values: Iterable[SignalDatatypeT],
    require_all_consumed: bool = False,
) -> Iterator[SignalDatatypeT]:
    """Build an iterator that sets a signal to successive values.

    Nothing is set by the call itself: each ``next()`` on the returned
    iterator sets the signal to the next item of ``values`` and returns it.

    :param signal: A signal connected in mock mode.
    :param values:
        An iterable of the values to set the signal to; one value is set per
        iteration.
    :param require_all_consumed:
        If True, an AssertionError will be raised if the iterator is deleted
        before all values have been consumed.

    :example:
    ```python
    for value_set in set_mock_values(signal, range(3)):
        ...  # do something

    cm = set_mock_values(signal, [1, 3, 8], require_all_consumed=True)
    next(cm)  # do something
    ```
    """
    values_iterator = _SetValuesIterator(
        signal, values, require_all_consumed=require_all_consumed
    )
    return values_iterator
@contextmanager
def _unset_side_effect_cm(put_mock: AsyncMock):
yield
put_mock.side_effect = None
[docs]
def callback_on_mock_put(
    signal: Signal[SignalDatatypeT],
    callback: Callable[[SignalDatatypeT, bool], SignalDatatypeT | None]
    | Callable[[SignalDatatypeT, bool], Awaitable[SignalDatatypeT | None]],
):
    """Install a callback that runs whenever the signal's backend is put to.

    The callback is installed as soon as this function is called, so it can
    be used as an ordinary function. The return value can also be used as a
    context manager, in which case the callback is unset on exit.

    The value that the callback returns (if not None) will be set to the
    signal readback. If None is returned then the readback will be set to the
    setpoint.

    :param signal: A signal with a `MockSignalBackend` backend.
    :param callback:
        The callback to call when the backend is put to; it may be a plain
        or an awaitable-returning callable.
    :returns: A context manager that resets the put mock's side effect.
    """
    put_mock = _get_mock_signal_backend(signal).put_mock
    # Takes effect immediately, not only once the context is entered.
    put_mock.side_effect = callback
    return _unset_side_effect_cm(put_mock)
[docs]
def set_mock_put_proceeds(signal: Signal, proceeds: bool):
    """Allow or block a put with wait=True from proceeding.

    :param signal: A signal connected in mock mode.
    :param proceeds:
        If truthy, the backend's ``put_proceeds`` is set; otherwise it is
        cleared.
    """
    proceed_flag = _get_mock_signal_backend(signal).put_proceeds
    toggle = proceed_flag.set if proceeds else proceed_flag.clear
    toggle()
[docs]
@contextmanager
def mock_puts_blocked(*signals: Signal):
    """Context manager to block puts at the start and unblock at the end.

    On entry every signal is blocked with `set_mock_put_proceeds`. On exit
    every signal is unblocked again, including when the body of the ``with``
    block raises; the exception itself still propagates.

    :param signals: The signals, connected in mock mode, whose puts to block.
    """
    for signal in signals:
        set_mock_put_proceeds(signal, False)
    try:
        yield
    finally:
        # Always unblock so a failing body cannot leave puts blocked.
        for signal in signals:
            set_mock_put_proceeds(signal, True)
[docs]
def get_mock_put(signal: Signal) -> AsyncMock:
    """Return the mock associated with the put call on the signal.

    :param signal: A signal connected in mock mode.
    """
    mock_backend = _get_mock_signal_backend(signal)
    return mock_backend.put_mock