Source code for bluesky.bundlers

import asyncio
import inspect
import time as ttime
from collections import defaultdict, deque
from typing import Any, Callable, Deque, Dict, FrozenSet, Iterable, List, Optional, Tuple, Union

from event_model import (
    ComposeDescriptorBundle,
    DataKey,
    Datum,
    DocumentNames,
    EventModelValueError,
    Resource,
    StreamDatum,
    StreamRange,
    StreamResource,
    compose_run,
    pack_event_page,
)
from event_model.documents.event import Event

from .log import doc_logger
from .protocols import (
    Callback,
    Collectable,
    Configurable,
    EventCollectable,
    EventPageCollectable,
    Flyable,
    HasName,
    Readable,
    Reading,
    Subscribable,
    SyncOrAsync,
    T,
    WritesStreamAssets,
    check_supports,
)
from .utils import (
    IllegalMessageSequence,
    Msg,
    _rearrange_into_parallel_dicts,
    iterate_maybe_async,
    maybe_await,
    maybe_collect_asset_docs,
    maybe_update_hints,
    new_uid,
    short_uid,
)

ObjDict = Dict[Any, Dict[str, T]]
ExternalAssetDoc = Union[Datum, Resource, StreamDatum, StreamResource]


class RunBundler:
    def __init__(self, md, record_interruptions, emit, emit_sync, log, *, strict_pre_declare=False):
        # if create can YOLO implicitly create a stream
        self._strict_pre_declare = strict_pre_declare
        # state stolen from the RE
        self.bundling = False  # if we are in the middle of bundling readings
        self._bundle_name = None  # name given to event descriptor
        self._run_start_uid = None  # The (future) runstart uid
        self._objs_read: Deque[HasName] = deque()  # objects read in one Event
        self._read_cache: Deque[Dict[str, Reading]] = deque()  # cache of obj.read() in one Event
        self._asset_docs_cache = deque()  # cache of obj.collect_asset_docs()
        self._describe_cache: ObjDict[DataKey] = dict()  # cache of all obj.describe() output  # noqa: C408
        self._describe_collect_cache: ObjDict[Dict[str, DataKey]] = dict()  # noqa: C408
        # cache of all obj.describe_collect() output
        self._config_desc_cache: ObjDict[DataKey] = dict()  # " obj.describe_configuration()  # noqa: C408
        self._config_values_cache: ObjDict[Any] = dict()  # " obj.read_configuration() values  # noqa: C408
        self._config_ts_cache: ObjDict[Any] = dict()  # " obj.read_configuration() timestamps  # noqa: C408
        # cache of {name: (doc, compose_event, compose_event_page)}
        self._descriptors: Dict[Any, ComposeDescriptorBundle] = dict()  # noqa: C408
        self._descriptor_objs: Dict[str, Dict[HasName, Dict[str, DataKey]]] = dict()  # noqa: C408
        # cache of {obj: {objs_frozen_set: (doc, compose_event, compose_event_page)}}
        self._local_descriptors: Dict[Any, Dict[FrozenSet[str], ComposeDescriptorBundle]] = dict()  # noqa: C408
        # a seq_num counter per stream
        self._sequence_counters: Dict[Any, int] = dict()  # noqa: C408
        self._sequence_counters_copy: Dict[Any, int] = dict()  # for if we redo data-points  # noqa: C408
        self._monitor_params: Dict[Subscribable, Tuple[Callback, Dict]] = dict()  # noqa: C408
        # cache of {obj: (cb, kwargs)}
        # a cache of stream_resource uid to the data_keys that stream_resource collects for
        self._stream_resource_data_keys: Dict[str, Iterable[str]] = dict()  # noqa: C408
        self.run_is_open = False
        self._uncollected = set()  # objects after kickoff(), before collect()
        # we expect the RE to take care of the composition
        self._md = md
        # this is state on the RE, mirror it here rather than refer to
        # the parent
        self.record_interruptions = record_interruptions
        # this is RE.emit, but lifted to this context
        self.emit = emit
        self.emit_sync = emit_sync
        self.log = log
        # Map of set of collect objects to list of stream names that they can be collected into
        self._declared_stream_names: Dict[FrozenSet, List[str]] = {}
    async def open_run(self, msg):
        self.run_is_open = True
        self._run_start_uid = new_uid()
        self._interruptions_desc_uid = None  # uid for a special Event Desc.
        self._interruptions_counter = 0  # seq_num, special Event stream

        run = compose_run(uid=self._run_start_uid, event_counters=self._sequence_counters, metadata=self._md)
        doc = run.start_doc
        self._compose_descriptor = run.compose_descriptor
        self._compose_resource = run.compose_resource
        self._compose_stop = run.compose_stop
        self._compose_stream_resource = run.compose_stream_resource

        await self.emit(DocumentNames.start, doc)
        doc_logger.debug(
            "[start] document is emitted (run_uid=%r)",
            self._run_start_uid,
            extra={"doc_name": "start", "run_uid": self._run_start_uid},
        )
        await self.reset_checkpoint_state_coro()

        # Emit an Event Descriptor for recording any interruptions as Events.
        if self.record_interruptions:
            # To store the interruptions uid outside of event-model
            self._interruptions_desc_uid = new_uid()
            dk = {"dtype": "string", "shape": [], "source": "RunEngine"}
            descriptor_bundle = self._compose_descriptor(
                uid=self._interruptions_desc_uid,
                name="interruptions",
                data_keys={"interruption": dk},
            )
            self._interruptions_desc = descriptor_bundle.descriptor_doc
            self._interruptions_compose_event = descriptor_bundle.compose_event
            await self.emit(DocumentNames.descriptor, self._interruptions_desc)

        return self._run_start_uid
    async def close_run(self, msg):
        """Instruct the RunEngine to write the RunStop document

        Expected message object is::

            Msg('close_run', None, exit_status=None, reason=None)

        if *exit_status* and *reason* are not provided, use the values stashed
        on the RE.
        """
        if not self.run_is_open:
            raise IllegalMessageSequence(
                "A 'close_run' message was received but there is no run "
                "open. If this occurred after a pause/resume, add "
                "a 'checkpoint' message after the 'close_run' message."
            )
        self.log.debug("Stopping run %r", self._run_start_uid)
        # Clear any uncleared monitoring callbacks.
        for obj, (cb, kwargs) in list(self._monitor_params.items()):  # noqa: B007
            obj.clear_sub(cb)
            del self._monitor_params[obj]
        reason = msg.kwargs.get("reason", None)
        if reason is None:
            reason = ""
        exit_status = msg.kwargs.get("exit_status", "success") or "success"
        doc = self._compose_stop(
            exit_status=exit_status,
            reason=reason,
        )
        await self.emit(DocumentNames.stop, doc)
        doc_logger.debug(
            "[stop] document is emitted (run_uid=%r)",
            self._run_start_uid,
            extra={"doc_name": "stop", "run_uid": self._run_start_uid},
        )
        await self.reset_checkpoint_state_coro()
        self.run_is_open = False
        return doc["run_start"]
    async def _prepare_stream(
        self,
        desc_key: str,
        objs_dks: Dict[HasName, Dict[str, DataKey]],
    ):
        # We do not have an Event Descriptor for this set,
        # so one must be created.
        data_keys = {}
        config = {}
        object_keys = {}
        hints: Dict[str, Any] = {}

        for obj, dks in objs_dks.items():
            maybe_update_hints(hints, obj)
            # dks is an OrderedDict. Record that order as a list.
            object_keys[obj.name] = list(dks)
            for key in dks.keys():
                dks[key]["object_name"] = obj.name
            data_keys.update(dks)
            config[obj.name] = {
                "data": self._config_values_cache[obj],
                "timestamps": self._config_ts_cache[obj],
                "data_keys": self._config_desc_cache[obj],
            }

        self._descriptors[desc_key] = self._compose_descriptor(
            desc_key,
            data_keys,
            configuration=config,
            hints=hints,
            object_keys=object_keys,
        )
        await self.emit(DocumentNames.descriptor, self._descriptors[desc_key].descriptor_doc)
        doc_logger.debug(
            "[descriptor] document emitted with name %r containing data keys %r (run_uid=%r)",
            desc_key,
            data_keys.keys(),
            self._run_start_uid,
            extra={"doc_name": "descriptor", "run_uid": self._run_start_uid, "data_keys": data_keys.keys()},
        )
        self._descriptor_objs[desc_key] = objs_dks
        if desc_key not in self._sequence_counters:
            self._sequence_counters[desc_key] = 1
            self._sequence_counters_copy[desc_key] = 1

        return (
            self._descriptors[desc_key].descriptor_doc,
            self._descriptors[desc_key].compose_event,
            list(objs_dks),
        )

    async def _ensure_cached(self, obj, collect=False):
        coros = []
        if not collect and obj not in self._describe_cache:
            coros.append(self._cache_describe(obj))
        elif collect and obj not in self._describe_collect_cache:
            coros.append(self._cache_describe_collect(obj))
        if obj not in self._config_desc_cache:
            coros.append(self._cache_describe_config(obj))
            coros.append(self._cache_read_config(obj))
        await asyncio.gather(*coros)

    async def declare_stream(self, msg):
        """Generate and emit an EventDescriptor."""
        command, no_obj, objs, kwargs, _ = msg
        stream_name = kwargs.get("name")
        assert stream_name is not None, "A stream name that is not None is required for pre-declare"
        collect = kwargs.get("collect", False)
        assert no_obj is None
        objs = frozenset(objs)
        objs_dks = {}  # {collect_object: stream_data_keys}
        await asyncio.gather(*[self._ensure_cached(obj, collect=collect) for obj in objs])

        for obj in objs:
            if collect:
                data_keys = self._describe_collect_cache[obj]
                streams_and_data_keys: List[Tuple[str, Dict[str, Any]]] = (
                    self._maybe_format_datakeys_with_stream_name(data_keys, message_stream_name=stream_name)
                )
                # Ensure that there is only one stream and it is the stream we have provided.
                assert len(streams_and_data_keys) == 1 and streams_and_data_keys[0][0] == stream_name, (
                    "`declare_stream` contained `collect=True` but `describe_collect` did "
                    f"not return a single Dict[str, DataKey] for the passed in {stream_name}"
                )
            else:
                data_keys = self._describe_cache[obj]
            objs_dks[obj] = data_keys

        existing_stream_names = self._declared_stream_names.setdefault(objs, [])
        existing_stream_names.append(stream_name)

        return await self._prepare_stream(stream_name, objs_dks)
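    # Example (illustrative sketch, not part of this module): pre-declaring a stream so
    # the descriptor is emitted before the first Event, which is required when
    # ``strict_pre_declare`` is set. Assumes the ``bluesky.plan_stubs`` helpers and a
    # Readable detector ``det``:
    #
    #     import bluesky.plan_stubs as bps
    #
    #     def predeclared_point(det):
    #         yield from bps.open_run()
    #         yield from bps.declare_stream(det, name="primary")  # -> RunBundler.declare_stream
    #         yield from bps.create(name="primary")
    #         yield from bps.read(det)
    #         yield from bps.save()
    #         yield from bps.close_run()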
    async def create(self, msg):
        """
        Start bundling future obj.read() calls for an Event document.

        Expected message object is::

            Msg('create', None, name='primary')
            Msg('create', name='primary')

        Note that the `name` kwarg will be the 'name' field of the resulting
        descriptor. So descriptor['name'] = msg.kwargs['name'].

        Also note that changing the 'name' of the Event will create a new
        Descriptor document.
        """
        if self.bundling:
            raise IllegalMessageSequence(
                "A second 'create' message is not "
                "allowed until the current event "
                "bundle is closed with a 'save' or "
                "'drop' message."
            )
        self._read_cache.clear()
        self._asset_docs_cache.clear()
        self._objs_read.clear()
        self.bundling = True
        command, obj, args, kwargs, _ = msg
        try:
            self._bundle_name = kwargs["name"]
        except KeyError:
            try:
                (self._bundle_name,) = args
            except ValueError:
                raise ValueError(
                    "Msg('create') now requires a stream name, given as "
                    "Msg('create', name) or Msg('create', name=name)"
                ) from None
        if self._strict_pre_declare:
            if self._bundle_name not in self._descriptors:
                raise IllegalMessageSequence("In strict mode you must pre-declare streams.")
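    # Example (illustrative sketch, not part of this module): a minimal plan whose
    # messages drive create()/read()/save() in this class, assuming the
    # ``bluesky.plan_stubs`` helpers and a Readable detector ``det``:
    #
    #     import bluesky.plan_stubs as bps
    #
    #     def one_point(det):
    #         yield from bps.open_run()
    #         yield from bps.create(name="primary")   # Msg('create', name='primary')
    #         reading = yield from bps.read(det)      # Msg('read', det)
    #         yield from bps.save()                   # Msg('save')
    #         yield from bps.close_run()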
    async def read(self, msg, reading):
        """
        Add a reading to the open event bundle.

        Expected message object is::

            Msg('read', obj)
        """
        if self.bundling:
            obj = msg.obj
            # if the object is not in the _describe_cache, cache it
            # Note: there is a race condition between the code here
            # and in monitor() and collect(), so if you do them concurrently
            # on the same device you make obj.describe() calls multiple times.
            # As this is harmless and not an expected use case, we don't guard
            # against it. Reading multiple devices concurrently works fine.
            await self._ensure_cached(obj)

            # check that the current read collides with nothing else in
            # the current event
            cur_keys = set(self._describe_cache[obj].keys())
            for read_obj in self._objs_read:
                # that is, field names
                known_keys = self._describe_cache[read_obj].keys()
                if set(known_keys) & cur_keys:
                    raise ValueError(
                        f"Data keys (field names) from {obj!r} "
                        f"collide with those from {read_obj!r}. "
                        f"The colliding keys are {set(known_keys) & cur_keys}"
                    )

            # add this object to the cache of things we have read
            self._objs_read.append(obj)

            # Stash the results, which will be emitted the next time _save is
            # called --- or never emitted if _drop is called instead.
            self._read_cache.append(reading)
            # Ask the object for any resource or datum documents it has cached
            # and cache them as well. Likewise, these will be emitted if and
            # when _save is called.
            asset_docs_collected = [x async for x in maybe_collect_asset_docs(msg, obj, *msg.args, **msg.kwargs)]
            self._asset_docs_cache.extend(asset_docs_collected)

        return reading
    async def _cache_describe(self, obj):
        "Read the object's describe and cache it."
        obj = check_supports(obj, Readable)
        self._describe_cache[obj] = await maybe_await(obj.describe())

    async def _cache_describe_config(self, obj):
        "Read the object's describe_configuration and cache it."
        if isinstance(obj, Configurable):
            conf_keys = await maybe_await(obj.describe_configuration())
        else:
            conf_keys = {}
        self._config_desc_cache[obj] = conf_keys

    async def _cache_read_config(self, obj):
        "Read the object's configuration and cache it."
        if isinstance(obj, Configurable):
            conf = await maybe_await(obj.read_configuration())
        else:
            conf = {}
        config_values = {}
        config_ts = {}
        for key, val in conf.items():
            config_values[key] = val["value"]
            config_ts[key] = val["timestamp"]
        self._config_values_cache[obj] = config_values
        self._config_ts_cache[obj] = config_ts
    async def monitor(self, msg):
        """
        Monitor a signal. Emit event documents asynchronously.

        A descriptor document is emitted immediately. Then, a closure is
        defined that emits Event documents associated with that descriptor
        from a separate thread. This process is not related to the main
        bundling process (create/read/save).

        Expected message object is::

            Msg('monitor', obj, **kwargs)
            Msg('monitor', obj, name='event-stream-name', **kwargs)

        where kwargs are passed through to ``obj.subscribe()``
        """
        obj = check_supports(msg.obj, Subscribable)
        if msg.args:
            raise ValueError("The 'monitor' Msg does not accept positional arguments.")
        kwargs = dict(msg.kwargs)
        name = kwargs.pop("name", short_uid("monitor"))
        if obj in self._monitor_params:
            raise IllegalMessageSequence(f"A 'monitor' message was sent for {obj} which is already monitored")

        await self._ensure_cached(obj)
        stream_bundle = await self._prepare_stream(name, {obj: self._describe_cache[obj]})
        compose_event = stream_bundle[1]

        def emit_event(readings: Optional[Dict[str, Reading]] = None, *args, **kwargs):
            if readings is not None:
                # We were passed something we can use, but check no args or kwargs
                assert (
                    not args and not kwargs
                ), "If subscribe callback called with readings, args and kwargs are not supported."
            else:
                # Ignore the inputs. Use this call as a signal to call read on the
                # object, a crude way to be sure we get all the info we need.
                readable_obj = check_supports(obj, Readable)  # type: ignore
                readings = readable_obj.read()  # type: ignore
                assert not inspect.isawaitable(readings), (
                    f"{readable_obj} has async read() method and the callback "
                    "passed to subscribe() was not called with Dict[str, Reading]"
                )
            data, timestamps = _rearrange_into_parallel_dicts(readings)
            doc = compose_event(
                data=data,
                timestamps=timestamps,
            )
            self.emit_sync(DocumentNames.event, doc)

        self._monitor_params[obj] = emit_event, kwargs
        # TODO: deprecate **kwargs when Ophyd.v2 is available
        obj.subscribe(emit_event, **kwargs)
    def record_interruption(self, content):
        """
        Emit an event in the 'interruptions' event stream.

        If we are not inside a run or if self.record_interruptions is False,
        nothing is done.
        """
        if self._interruptions_desc_uid is not None:
            # We are inside a run and self.record_interruptions is True.
            doc = self._interruptions_compose_event(
                data={"interruption": content},
                timestamps={"interruption": ttime.time()},
            )
            self._interruptions_counter += 1
            self.emit_sync(DocumentNames.event, doc)
    def rewind(self):
        self._sequence_counters.clear()
        self._sequence_counters.update(self._sequence_counters_copy)
        # make sure we do not forget about streams we roll back to the
        # very beginning of
        for desc_key in self._descriptor_objs:
            if desc_key not in self._sequence_counters:
                self._sequence_counters[desc_key] = 1
                self._sequence_counters_copy[desc_key] = 1
        # This is needed to 'cancel' an open bundling (e.g. create) if
        # the pause happens after a 'checkpoint', after a 'create', but
        # before the paired 'save'.
        self.bundling = False
    async def unmonitor(self, msg):
        """
        Stop monitoring; i.e., remove the callback emitting event documents.

        Expected message object is::

            Msg('unmonitor', obj)
        """
        obj = check_supports(msg.obj, Subscribable)
        if obj not in self._monitor_params:
            raise IllegalMessageSequence(f"Cannot 'unmonitor' {obj}; it is not being monitored.")
        cb, kwargs = self._monitor_params[obj]
        obj.clear_sub(cb)
        del self._monitor_params[obj]
        await self.reset_checkpoint_state_coro()
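    # Example (illustrative sketch, not part of this module): monitoring a signal for
    # the duration of a slow move, so monitor()/unmonitor() above bracket the motion.
    # Assumes ``bluesky.plan_stubs`` and a Subscribable signal ``sig``:
    #
    #     import bluesky.plan_stubs as bps
    #
    #     def monitor_during_move(motor, sig, target):
    #         yield from bps.open_run()
    #         yield from bps.monitor(sig, name="sig_monitor")  # descriptor emitted now;
    #                                                          # Events emitted as sig updates
    #         yield from bps.mv(motor, target)
    #         yield from bps.unmonitor(sig)
    #         yield from bps.close_run()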
    async def save(self, msg):
        """Save the event that is currently being bundled

        Create and emit an Event document containing the data read from
        devices in self._objs_read. Emit any Resource and Datum documents
        cached by those devices before emitting the Event document.

        If this is the first Event of its stream then create and emit the
        Event Descriptor document before emitting Resource, Datum, and Event
        documents.

        Expected message object is::

            Msg('save')
        """
        if not self.bundling:
            raise IllegalMessageSequence(
                "A 'create' message must be sent, to "
                "open an event bundle, before that "
                "bundle can be saved with 'save'."
            )

        # Short-circuit if nothing has been read. (Do not create empty Events.)
        if not self._objs_read:
            self.bundling = False
            self._bundle_name = None
            return

        # The Event Descriptor is uniquely defined by the set of objects
        # read in this Event grouping.
        objs_read = frozenset(self._objs_read)

        # Event Descriptor key
        desc_key = self._bundle_name

        # This is a separate check because it can be reset on resume.
        self.bundling = False
        self._bundle_name = None

        (
            descriptor_doc,
            compose_event,
            _,
        ) = self._descriptors.get(desc_key, (None, None, None))
        d_objs = self._descriptor_objs.get(desc_key, None)

        objs_dks = {}
        # we do not have the descriptor cached, make it
        if descriptor_doc is None or d_objs is None:
            for obj in objs_read:
                await self._ensure_cached(obj, collect=isinstance(obj, Collectable))
                objs_dks[obj] = self._describe_cache[obj]

            descriptor_doc, compose_event, d_objs = await self._prepare_stream(desc_key, objs_dks)
        # do have the descriptor cached
        elif frozenset(d_objs) != objs_read:
            raise RuntimeError(
                f"Mismatched objects read, expected {frozenset(d_objs)!s}, got {objs_read!s}"
            )

        # Resource and Datum documents
        indices_generated = await self._pack_external_assets(self._asset_docs_cache, message_stream_name=desc_key)
        if indices_generated > 1:
            raise RuntimeError(
                "Received multiple indices in a `stream_datum` document for one event, "
                ' during a `read()` `save()`. `stream_datum` should have indices {"start": n, "stop": n+1} '
                "in a `read()` `save()`."
            )

        # Merge list of readings into single dict.
        readings = {k: v for d in self._read_cache for k, v in d.items()}
        data, timestamps = _rearrange_into_parallel_dicts(readings)

        # Mark all externally-stored data as not filled so that consumers
        # know that the corresponding data are identifiers, not dereferenced
        # data.
        filled = {
            k: False
            for k, v in self._descriptors[desc_key].descriptor_doc["data_keys"].items()
            if "external" in v and v["external"] != "STREAM:"
        }
        event_doc = compose_event(
            data=data,
            timestamps=timestamps,
            filled=filled,
        )
        await self.emit(DocumentNames.event, event_doc)
        doc_logger.debug(
            "[event] document emitted with data keys %r (run_uid=%r)",
            data.keys(),
            self._run_start_uid,
            extra={"doc_name": "event", "run_uid": self._run_start_uid, "data_keys": data.keys()},
        )
    def clear_monitors(self):
        for obj, (cb, kwargs) in list(self._monitor_params.items()):  # noqa: B007
            try:
                obj.clear_sub(cb)
            except Exception:
                self.log.exception("Failed to stop monitoring %r.", obj)
            else:
                del self._monitor_params[obj]
    def reset_checkpoint_state(self):
        # Keep a safe separate copy of the sequence counters to use if we
        # rewind and retake some data points.
        for key, counter in list(self._sequence_counters.items()):
            self._sequence_counters_copy[key] = counter
    async def reset_checkpoint_state_coro(self):
        self.reset_checkpoint_state()
    async def suspend_monitors(self):
        for obj, (cb, kwargs) in self._monitor_params.items():  # noqa: B007
            obj.clear_sub(cb)
    async def restore_monitors(self):
        for obj, (cb, kwargs) in self._monitor_params.items():
            obj.subscribe(cb, **kwargs)
    async def clear_checkpoint(self, msg):
        self._sequence_counters_copy.clear()
    async def drop(self, msg):
        """Drop the event that is currently being bundled

        Expected message object is::

            Msg('drop')
        """
        if not self.bundling:
            raise IllegalMessageSequence(
                "A 'create' message must be sent, to "
                "open an event bundle, before that "
                "bundle can be dropped with 'drop'."
            )
        self.bundling = False
        self._bundle_name = None
        self.log.debug("Dropped open event bundle")
    async def kickoff(self, msg):
        """Start a flyscan object.

        Expected message object is:

        If `flyer_object` has a `kickoff` function that takes no arguments::

            Msg('kickoff', flyer_object)
            Msg('kickoff', flyer_object, group=<name>)

        If *flyer_object* has a ``kickoff`` function that takes
        ``(start, stop, steps)`` as its function arguments::

            Msg('kickoff', flyer_object, start, stop, step)
            Msg('kickoff', flyer_object, start, stop, step, group=<name>)
        """
        self._uncollected.add(msg.obj)
    # Could we have a look at changing this now? We use one of each use case in
    # separate places, so it could be two separate methods, one for each dictionary type.
    def _maybe_format_datakeys_with_stream_name(
        self,
        describe_collect_dict: Union[Dict[str, DataKey], Dict[str, Dict[str, DataKey]]],
        message_stream_name: Optional[str] = None,
    ) -> List[Tuple[str, Dict[str, DataKey]]]:
        """
        Check if the dictionary returned by describe_collect is a dict
        `{str: DataKey}` or a `{str: {str: DataKey}}`.

        If a `message_stream_name` is passed then return a singleton list of the
        form of `{message_stream_name: describe_collect_dict}.items()`.
        If the `message_stream_name` is None then return the
        `describe_collect_dict.items()`.
        """

        def has_str_source(d: dict):
            return isinstance(d, dict) and isinstance(d.get("source", None), str)

        if describe_collect_dict:
            first_value = list(describe_collect_dict.values())[0]
            if has_str_source(first_value):
                # We have Dict[str, DataKey], so return just this.
                # If stream name not given then default to "primary"
                return [(message_stream_name or "primary", describe_collect_dict)]
            elif all(has_str_source(v) for v in first_value.values()):
                # We have Dict[str, Dict[str, DataKey]] so return its items
                if message_stream_name and list(describe_collect_dict) != [message_stream_name]:
                    # The collect contained a name and describe_collect returned a Dict[str, Dict[str, DataKey]];
                    # this is only acceptable if the only key in the parent dict is message_stream_name
                    raise RuntimeError(
                        f"Expected a single stream {message_stream_name!r}, got {describe_collect_dict}"
                    )
                return list(describe_collect_dict.items())
            else:
                raise RuntimeError(
                    f"Invalid describe_collect return: {describe_collect_dict} when collect "
                    f"was called on {message_stream_name}"
                )
        else:
            # Empty dict, could be either but we don't care
            return []

    async def _cache_describe_collect(self, obj: Collectable):
        "Read the object's describe_collect and cache it."
        obj = check_supports(obj, Collectable)
        c = await maybe_await(obj.describe_collect())
        self._describe_collect_cache[obj] = c

    async def _describe_collect(self, collect_object: Flyable):
        """Read an object's describe_collect and cache it.

        Read describe_collect for a collect_object and ensure it is cached in
        the _describe_collect_cache. This is required for scans of a single
        collect object where the data structure is doubly nested. In this case
        calling describe_collect on the object returns a data structure like so:

            {
                "stream1": {"stream1-pv1": {data_keys}, "stream1-pv2": {data_keys}},
                "stream2": {"stream2-pv1": {data_keys}, "stream2-pv2": {data_keys}},
            }

        Singly nested data keys should be rejected since they are new style,
        and are collected under one stream. They should be pre-declared with
        declare_stream, prior to collecting.
        The describe_collect on this object returns a data structure like so:

            {
                "stream1-pv1": {data_keys}, "stream1-pv2": {data_keys}
            }
        """
        await self._ensure_cached(collect_object, collect=True)
        describe_collect = self._describe_collect_cache[collect_object]
        describe_collect_items = list(self._maybe_format_datakeys_with_stream_name(describe_collect))

        local_descriptors: Dict[Any, Dict[FrozenSet[str], ComposeDescriptorBundle]] = {}

        # Check that singly nested stuff should have been pre-declared
        def is_data_key(obj: Any) -> bool:
            return isinstance(obj, dict) and {"dtype", "shape", "source"}.issubset(frozenset(obj.keys()))

        assert all(
            not is_data_key(value) for value in describe_collect.values()
        ), "Singly nested data keys should be pre-declared"

        # Make sure you can't use identical data keys in multiple streams
        duplicates: Dict[str, DataKey] = defaultdict(dict)
        for stream, data_keys in describe_collect.items():
            for key, stuff in data_keys.items():
                for other_stream, other_data_keys in describe_collect.items():
                    for other_key, other_stuff in other_data_keys.items():
                        if stream != other_stream and key == other_key and stuff == other_stuff:
                            duplicates[stream][key] = stuff

        if len(duplicates) > 0:
            raise RuntimeError(
                f"Can't use identical data keys in multiple streams: {duplicates}",
                f"Data keys: {list(duplicates.values())}",
                f"streams: {duplicates.keys()}",
            )

        for stream_name, stream_data_keys in describe_collect_items:
            if stream_name not in self._descriptor_objs or (
                collect_object not in self._descriptor_objs[stream_name]
            ):
                await self._prepare_stream(stream_name, {collect_object: stream_data_keys})
            else:
                objs_read = self._descriptor_objs[stream_name]
                if stream_data_keys != objs_read[collect_object]:
                    raise RuntimeError(
                        "Mismatched objects read, "
                        f"expected {stream_data_keys!s}, "
                        f"got {objs_read!s}"
                    )
            local_descriptors[frozenset(stream_data_keys)] = self._descriptors[stream_name]

        self._local_descriptors[collect_object] = local_descriptors

    async def _pack_seq_nums_into_stream_datum(
        self, doc: StreamDatum, message_stream_name: str, stream_datum_previous_indices_difference: int
    ) -> int:
        if doc["seq_nums"] != StreamRange(start=0, stop=0):
            raise EventModelValueError(
                f"Received `seq_nums` {doc['seq_nums']} in stream {doc['stream_resource']} "
                "during `collect()` or `describe_collect()`. `seq_nums` should be None or "
                "`StreamRange(start=0, stop=0)` on `ComposeStreamDatum` when used with the "
                "run engine."
            )
        indices_difference = doc["indices"]["stop"] - doc["indices"]["start"]
        if (
            stream_datum_previous_indices_difference
            and stream_datum_previous_indices_difference != indices_difference
        ):
            raise EventModelValueError(
                f"Received `indices` {doc['indices']} during `collect()`; these are of a different "
                f"width `{indices_difference}` than other detectors in the same collect() or save()."
            )
        current_seq_counter = self._sequence_counters[message_stream_name]
        doc["seq_nums"] = StreamRange(start=current_seq_counter, stop=current_seq_counter + indices_difference)
        if doc["stream_resource"] not in self._stream_resource_data_keys:
            raise RuntimeError(
                f"Received a `stream_datum` referring to an unknown stream resource {doc['stream_resource']}"
            )
        return indices_difference  # message stream name here?
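    # Example (illustrative sketch): the two shapes of ``describe_collect`` return value
    # distinguished by ``_maybe_format_datakeys_with_stream_name`` above. The device and
    # data-key names are made up for illustration:
    #
    #     # Old style, doubly nested -- stream names map to Dict[str, DataKey]; each
    #     # stream is expanded into its own descriptor by _describe_collect():
    #     {
    #         "primary": {"det-sum": {"source": "PV:DET:SUM", "dtype": "number", "shape": []}},
    #         "baseline": {"det-temp": {"source": "PV:DET:TEMP", "dtype": "number", "shape": []}},
    #     }
    #
    #     # New style, singly nested -- a flat Dict[str, DataKey]; the stream must be
    #     # pre-declared with declare_stream(..., collect=True) before collecting:
    #     {
    #         "det-sum": {"source": "PV:DET:SUM", "dtype": "number", "shape": []},
    #     }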
    async def _pack_external_assets(
        self, asset_docs: Iterable[Tuple[str, ExternalAssetDoc]], message_stream_name: Optional[str]
    ):
        """Packs some external asset documents with relevant information from the run."""

        stream_datum_previous_indices_difference = 0
        data_keys_received = set()

        descriptor_doc = None
        external_data_keys = None
        if message_stream_name:
            descriptor_doc = self._descriptors[message_stream_name].descriptor_doc
            external_data_keys = self.get_external_data_keys(descriptor_doc["data_keys"])

        for name, doc in asset_docs:
            if name == DocumentNames.resource.value:
                doc["run_start"] = self._run_start_uid
            elif name == DocumentNames.stream_resource.value:
                doc["run_start"] = self._run_start_uid
                if doc["uid"] in self._stream_resource_data_keys:
                    raise RuntimeError(f"Received `stream_resource` with uid {doc['uid']} twice.")
                self._stream_resource_data_keys[doc["uid"]] = doc["data_key"]
                if not external_data_keys or doc["data_key"] not in external_data_keys:
                    raise RuntimeError(
                        f"Received a `stream_resource` with data_key {doc['data_key']} that is not in the "
                        f"descriptor 'STREAM:' data_keys {external_data_keys}"
                    )
            elif name == DocumentNames.stream_datum.value:
                if doc["descriptor"]:
                    raise RuntimeError(
                        f"Received a `stream_datum` {doc['uid']} with a `descriptor` uid already "
                        f"filled in, with the value {doc['descriptor']}; this should be an empty string."
                    )
                if not descriptor_doc:
                    raise RuntimeError(f"`descriptor` not made for stream {message_stream_name}.")
                data_keys_received.add(self._stream_resource_data_keys[doc["stream_resource"]])
                doc["descriptor"] = descriptor_doc["uid"]
                stream_datum_previous_indices_difference = await self._pack_seq_nums_into_stream_datum(
                    doc,
                    message_stream_name,  # type: ignore
                    stream_datum_previous_indices_difference,  # type: ignore
                )
            elif name == DocumentNames.datum.value:
                ...
            else:
                raise RuntimeError(
                    f"Tried to emit an external asset {name}, acceptable external assets are "
                    "`resource`, `stream_resource`, `datum`, or `stream_datum`"
                )

            await self.emit(DocumentNames(name), doc)
            doc_logger.debug(
                "[%s] document emitted %r",
                name,
                doc,
                extra={"doc_name": name, "run_uid": self._run_start_uid, "doc": doc},
            )

        # Check we have a stream_datum for each external data_key in the descriptor
        if descriptor_doc and data_keys_received and set(external_data_keys) != data_keys_received:  # type: ignore
            raise RuntimeError(
                f"Received `stream_datum` for each of the data_keys {data_keys_received}, "  # type: ignore
                f"expected `stream_datum` for each of the data_keys {set(external_data_keys)}."  # type: ignore
            )

        return stream_datum_previous_indices_difference

    def get_external_data_keys(self, data_keys: Dict[str, DataKey]) -> List[DataKey]:
        """Get the external data keys from the descriptor data_keys dictionary"""
        return [x for x in data_keys if ("external" in data_keys[x] and data_keys[x]["external"] == "STREAM:")]

    async def _collect_events(
        self,
        collect_obj: EventCollectable,
        local_descriptors,
        return_payload: bool,
        message_stream_name: Optional[str],
    ):
        payload = []
        pages: Dict[FrozenSet[str], List[Event]] = defaultdict(list)

        if message_stream_name:
            compose_event = self._descriptors[message_stream_name].compose_event
            data_keys = self._descriptors[message_stream_name].descriptor_doc["data_keys"]
            objs_read = frozenset(data_keys.keys())

        async for partial_event in iterate_maybe_async(collect_obj.collect()):
            if return_payload:
                payload.append(partial_event)

            if not message_stream_name:
                objs_read = frozenset(partial_event["data"])
                compose_event = local_descriptors[objs_read].compose_event
                data_keys = local_descriptors[objs_read].descriptor_doc["data_keys"]
                assert frozenset(data_keys.keys()) == objs_read

            if [x for x in self.get_external_data_keys(data_keys) if x in partial_event["data"]]:
                raise RuntimeError("Received an event containing data for external data keys.")

            # is there a way to generalise the keys?
            if "filled" in partial_event.keys():
                event = compose_event(
                    data=partial_event["data"],
                    timestamps=partial_event["timestamps"],
                    filled=partial_event["filled"],
                )
            else:
                event = compose_event(data=partial_event["data"], timestamps=partial_event["timestamps"])

            pages[objs_read].append(event)

        for event_list in pages.values():
            await self.emit(DocumentNames.event_page, pack_event_page(*event_list))
            doc_logger.debug(
                "[event_page] document is emitted for descriptors (run_uid=%r)",
                self._run_start_uid,
                extra={"doc_name": "event_page", "run_uid": self._run_start_uid},
            )

        return payload

    async def _collect_event_pages(
        self, collect_obj: EventPageCollectable, local_descriptors, return_payload: bool, message_stream_name: str
    ):
        payload = []

        if message_stream_name:
            compose_event_page = self._descriptors[message_stream_name].compose_event_page
            data_keys = self._descriptors[message_stream_name].descriptor_doc["data_keys"]

        async for ev_page in iterate_maybe_async(collect_obj.collect_pages()):
            if return_payload:
                payload.append(ev_page)

            if not message_stream_name:
                objs_read = frozenset(ev_page["data"])
                compose_event_page = local_descriptors[objs_read].compose_event_page
                data_keys = local_descriptors[objs_read].descriptor_doc["data_keys"]

            if [x for x in self.get_external_data_keys(data_keys) if x in ev_page["data"]]:
                raise RuntimeError("Received an event_page containing data for external data keys.")

            ev_page = compose_event_page(data=ev_page["data"], timestamps=ev_page["timestamps"])
            doc_logger.debug(
                "[event_page] document is emitted with data keys %r (run_uid=%r)",
                ev_page["data"].keys(),
                ev_page["uid"],
                extra={
                    "doc_name": "event_page",
                    "run_uid": self._run_start_uid,
                    "data_keys": ev_page["data"].keys(),
                },
            )
            await self.emit(DocumentNames.event_page, ev_page)

        return payload
    async def collect(self, msg):
        """
        Collect data cached by a flyer and emit documents.

        Expected message object is::

            Msg('collect', collect_obj, collect_obj_2, ..., stream=True, return_payload=True, name='stream_name')

        There must be at least one collect object. If multiple are used they
        must obey the WritesStreamAssets protocol.
        """
        stream_name = None
        if not self.run_is_open:
            # sanity check -- 'kickoff' should catch this and make this
            # code path impossible
            raise IllegalMessageSequence("A 'collect' message was sent but no run is open.")

        # If stream is True, run 'event' subscription per document.
        # If stream is False, run 'event_page' subscription once.
        # stream=True is no longer supported.
        stream = msg.kwargs.get("stream", False)
        if stream is True:
            raise RuntimeError(
                "Collect now emits EventPages (stream=False), "
                "so emitting Events (stream=True) is no longer supported"
            )

        # If True, accumulate all the Events in memory and return them at the
        # end, providing the plan access to the Events. If False, do not
        # accumulate, and return None.
        return_payload = msg.kwargs.get("return_payload", True)

        # Get a list of the collectable objects from the message obj and args
        collect_objects = [check_supports(obj, Collectable) for obj in (msg.obj,) + msg.args]

        # Get references to get_index methods if we have more than one collect object;
        # raise an error if collect_objects don't obey the WritesStreamAssets protocol
        indices: List[Callable[[None], SyncOrAsync[int]]] = []
        if len(collect_objects) > 1:
            indices = [check_supports(obj, WritesStreamAssets).get_index for obj in collect_objects]

        # Warn for page collectable support
        for obj in collect_objects:
            if isinstance(obj, EventCollectable) and isinstance(obj, EventPageCollectable):
                doc_logger.warn(
                    "collect() was called for a device %r which is both EventCollectable "
                    "and EventPageCollectable. Using device.collect_pages().",
                    obj.name,
                )
            self._uncollected.discard(obj)

        # Get the provided message stream name for singly nested scans
        message_stream_name = msg.kwargs.get("name", None)

        # Retrieve the stream names from pre-declared streams
        declared_stream_names = self._declared_stream_names.get(frozenset(collect_objects), [])

        # If a stream name was provided in the message, check the stream has been declared.
        # If one was not provided, but a single stream has been declared, then use that stream.
        if message_stream_name:
            assert (
                message_stream_name in declared_stream_names
            ), "If a message stream name is provided, declare_stream needs to be called first."
            stream_name = message_stream_name
        elif declared_stream_names:
            assert len(frozenset(declared_stream_names)) == 1  # Allow duplicate declarations
            stream_name = declared_stream_names[0]

        # If there is not a stream then we should be using an old-style doubly nested
        # describe_collect, and we need to describe_collect and prepare the nested streams.
        if not stream_name:
            if frozenset(collect_objects) not in self._local_descriptors or (
                collect_objects[0] not in self._local_descriptors
            ):
                if len(collect_objects) > 1:
                    raise IllegalMessageSequence(
                        "If collecting multiple objects you must predeclare a stream for all "
                        "the objects first and provide the stream name"
                    )
                else:
                    await self._describe_collect(collect_objects[0])

        # Get the indices from the collect objects
        coros = [maybe_await(get_index()) for get_index in indices]
        if coros:
            # There is more than one collect object, so collect up to a minimum index
            min_index = min(await asyncio.gather(*coros))
        else:
            # There is only one collect object, so don't pass an index down
            min_index = None

        collected_asset_docs = [
            x
            for obj in collect_objects
            async for x in maybe_collect_asset_docs(
                msg,
                obj,
                index=min_index,
            )
        ]

        indices_difference = await self._pack_external_assets(
            collected_asset_docs, message_stream_name=stream_name
        )

        # Make event pages for an object which is EventCollectable or EventPageCollectable;
        # objects that are EventCollectable will now group the Events and emit an Event Page
        if len(collect_objects) == 1 and not isinstance(collect_objects[0], WritesStreamAssets):
            local_descriptors: Dict[Any, Dict[FrozenSet[str], ComposeDescriptorBundle]] = {}
            collect_obj = collect_objects[0]

            # If the single collect object is singly nested, gather descriptors
            if collect_obj not in self._local_descriptors:
                objs = self._descriptor_objs[stream_name]
                data_keys = objs[collect_obj]
                local_descriptors[frozenset(data_keys)] = self._descriptors[stream_name]
                self._local_descriptors[collect_obj] = local_descriptors

            local_descriptors = self._local_descriptors[collect_obj]

            if isinstance(collect_obj, EventPageCollectable):
                payload = await self._collect_event_pages(
                    collect_obj, local_descriptors, return_payload, stream_name
                )
                # TODO: check that event pages have same length as indices_difference
            elif isinstance(collect_obj, EventCollectable):
                payload = await self._collect_events(collect_obj, local_descriptors, return_payload, stream_name)
                # TODO: check that events have same length as indices_difference
            else:
                return_payload = False
                if not stream_name:
                    raise RuntimeError(
                        "A `collect` message on a device that isn't EventCollectable or EventPageCollectable "
                        "requires a `name=stream_name` argument"
                    )
                # Since there are no events or event_pages incrementing the sequence counter, we do it ourselves.
                self._sequence_counters[stream_name] += indices_difference

            if return_payload:
                return payload
        else:
            # Since there are no events or event_pages incrementing the sequence counter, we do it ourselves.
            self._sequence_counters[stream_name] += indices_difference
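    # Example (illustrative sketch, not part of this module): a fly-scan plan whose
    # messages end up in kickoff()/collect() above. Assumes a recent bluesky whose
    # ``bluesky.plan_stubs`` helpers accept multiple objects and a stream name, and
    # detectors ``det1``/``det2`` implementing WritesStreamAssets:
    #
    #     import bluesky.plan_stubs as bps
    #
    #     def fly_two_detectors(det1, det2):
    #         yield from bps.open_run()
    #         # Multiple collect objects require a pre-declared, named stream
    #         yield from bps.declare_stream(det1, det2, name="primary", collect=True)
    #         for det in (det1, det2):
    #             yield from bps.kickoff(det, wait=True)    # -> RunBundler.kickoff
    #         for det in (det1, det2):
    #             yield from bps.complete(det, wait=True)
    #         yield from bps.collect(det1, det2, name="primary")  # -> RunBundler.collect
    #         yield from bps.close_run()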
    async def backstop_collect(self):
        for obj in list(self._uncollected):
            try:
                await self.collect(Msg("collect", obj))
            except Exception:
                self.log.exception("Failed to collect %r.", obj)
    async def configure(self, msg):
        """Configure an object

        Expected message object is::

            Msg('configure', object, *args, **kwargs)

        which results in this call::

            object.configure(*args, **kwargs)
        """
        obj = msg.obj

        # Invalidate any event descriptors that include this object.
        # New event descriptors, with this new configuration, will
        # be created for any future event documents.
        for name in list(self._descriptors):
            obj_set = self._descriptor_objs[name]
            if obj in obj_set:
                del self._descriptors[name]
                await self._prepare_stream(name, obj_set)
                continue

        await self._cache_read_config(obj)
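# Example (illustrative sketch, not part of this module): reconfiguring a device
# mid-run, which makes configure() above invalidate and re-emit any descriptors that
# include it. Assumes ``bluesky.plan_stubs`` and a Configurable device ``det``; the
# ``exposure_time`` kwarg is hypothetical and device-specific:
#
#     import bluesky.plan_stubs as bps
#
#     def reconfigure_between_points(det):
#         yield from bps.open_run()
#         yield from bps.trigger_and_read([det])            # first descriptor emitted
#         yield from bps.configure(det, exposure_time=0.5)  # descriptor invalidated and re-emitted
#         yield from bps.trigger_and_read([det])            # new descriptor with new configuration
#         yield from bps.close_run()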