Source code for bluesky.callbacks.core

"""
Useful callbacks for the Run Engine
"""

import logging
import os
import time as ttime
import warnings
from collections import OrderedDict, deque, namedtuple
from datetime import datetime
from functools import partial as _partial
from functools import wraps as _wraps
from itertools import count

from event_model import DocumentRouter

from ..utils import ensure_uid

MIMETYPE_LOOKUP = {
    "hdf5": "application/x-hdf5",
    "AD_HDF5_SWMR_STREAM": "application/x-hdf5",
    "AD_HDF5_SWMR_SLICE": "application/x-hdf5",
    "AD_TIFF": "multipart/related;type=image/tiff",
    "AD_HDF5_GERM": "application/x-hdf5",
}

logger = logging.getLogger(__name__)


[docs] def make_callback_safe(func=None, *, logger=None): """ If the wrapped func raises any exceptions, log them but continue. This is intended to ensure that any failures in non-critical callbacks do not interrupt data acquisition. It should *not* be applied to any critical callbacks, such as ones that perform data-saving, but is well suited to callbacks that perform non-critical streaming visualization or data processing. To debug the issue causing a failure, it can be convenient to turn this off and let the failures raise. To do this, set the environment variable ``BLUESKY_DEBUG_CALLBACKS=1``. Parameters ---------- func: callable logger: logging.Logger, optional Examples -------- Decorate a callback to make sure it will not interrupt data acquisition if it fails. >>> @make_callback_safe ... def callback(name, doc): ... ... """ if func is None: return _partial(make_callback_safe, logger=logger) @_wraps(func) def inner(*args, **kwargs): try: return func(*args, **kwargs) except Exception: debug_mode = os.environ.get("BLUESKY_DEBUG_CALLBACKS", False) if logger is not None: if debug_mode: msg = f"Exception in {func}" else: msg = ( f"An exception raised in the callback {func} " "is being suppressed to not interrupt plan " "execution. To investigate try setting the " "BLUESKY_DEBUG_CALLBACKS env to '1'" ) logger.exception(msg) if debug_mode: raise return inner
[docs] def make_class_safe(cls=None, *, to_wrap=None, logger=None): """ If the wrapped func raises any exceptions, log them but continue. This is intended to ensure that any failures in non-critical callbacks do not interrupt data acquisition. It should *not* be applied to any critical callbacks, such as ones that perform data-saving, but is well suited to callbacks that perform non-critical streaming visualization or data processing. To debug the issue causing a failure, it can be convenient to turn this off and let the failures raise. To do this, set the environment variable ``BLUESKY_DEBUG_CALLBACKS=1``. Parameters ---------- cls: callable to_wrap: List[String], optional Names of methods of cls to wrap. Default is ``['__call__']``. logger: logging.Logger, optional Examples -------- Decorate a class to make sure it will not interrupt data acquisition if it fails. >>> @make_class_safe ... class Callback(event_model.DocumentRouter): ... ... """ if cls is None: return _partial(make_class_safe, to_wrap=to_wrap, logger=logger) if to_wrap is None: to_wrap = ["__call__"] for f_name in to_wrap: setattr(cls, f_name, make_callback_safe(getattr(cls, f_name), logger=logger)) return cls
class CallbackBase(DocumentRouter): log = None class CallbackCounter: "As simple as it sounds: count how many times a callback is called." # Wrap itertools.count in something we can use as a callback. def __init__(self): self.counter = count() self(None, {}) # Pass a fake doc to prime the counter (start at 1). def __call__(self, name, doc): self.value = next(self.counter) def print_metadata(name, doc): "Print all fields except uid and time." for field, value in sorted(doc.items()): # uid is returned by the RunEngine, and time is self-evident if field not in ["time", "uid"]: print(f"{field}: {value}") def collector(field, output): """ Build a function that appends data to a list. This is useful for testing but not advised for general use. (There is probably a better way to do whatever you want to do!) Parameters ---------- field : str the name of a data field in an Event output : mutable iterable such as a list Returns ------- func : function expects one argument, an Event dictionary """ def f(name, event): output.append(event["data"][field]) return f def format_num(x, max_len=11, pre=5, post=5): if (abs(x) > 10**pre or abs(x) < 10**-post) and x != 0: x = f"%.{post}e" % x else: x = f"%{pre}.{post}f" % x return x def get_obj_fields(fields): """ If fields includes any objects, get their field names using obj.describe() ['det1', det_obj] -> ['det1, 'det_obj_field1, 'det_obj_field2']" """ string_fields = [] for field in fields: if isinstance(field, str): string_fields.append(field) else: try: field_list = sorted(field.describe().keys()) except AttributeError: raise ValueError( # noqa: B904 "Fields must be strings or objects with a 'describe' method that return a dict." ) string_fields.extend(field_list) return string_fields class CollectThenCompute(CallbackBase): def __init__(self): self._start_doc = None self._stop_doc = None self._events = deque() self._descriptors = deque() def start(self, doc): self._start_doc = doc super().start(doc) def descriptor(self, doc): self._descriptors.append(doc) super().descriptor(doc) def event(self, doc): self._events.append(doc) super().event(doc) def stop(self, doc): self._stop_doc = doc self.compute() super().stop(doc) def reset(self): self._start_doc = None self._stop_doc = None self._events.clear() self._descriptors.clear() def compute(self): raise NotImplementedError("This method must be defined by a subclass.")
[docs] @make_class_safe(logger=logger) class LiveTable(CallbackBase): """Live updating table Parameters ---------- fields : list List of fields to add to the table. stream_name : str, optional The event stream to watch for print_header_interval : int, optional Reprint the header every this many lines, defaults to 50 min_width : int, optional The minimum width is spaces of the data columns. Defaults to 12 default_prec : int, optional Precision to use if it can not be found in descriptor, defaults to 3 extra_pad : int, optional Number of extra spaces to put around the printed data, defaults to 1 separator_lines : bool, optional Add empty lines before and after the printed table, default True logbook : callable, optional Must take a sting as the first positional argument def logbook(input_str): pass out : callable, optional Function to call to 'print' a line. Defaults to `print` """ _FMTLOOKUP = { "s": "{pad}{{{k}: >{width}.{prec}{dtype}}}{pad}", "f": "{pad}{{{k}: >{width}.{prec}{dtype}}}{pad}", "g": "{pad}{{{k}: >{width}.{prec}{dtype}}}{pad}", "d": "{pad}{{{k}: >{width}{dtype}}}{pad}", } _FMT_MAP = { "number": "f", "integer": "d", "string": "s", } _fm_sty = namedtuple("fm_sty", ["width", "prec", "dtype"]) # type: ignore water_mark = "{st[plan_type]} {st[plan_name]} ['{st[uid]:.8s}'] (scan num: {st[scan_id]})" ev_time_key = "SUPERLONG_EV_TIMEKEY_THAT_I_REALLY_HOPE_NEVER_CLASHES" def __init__( self, fields, *, stream_name="primary", print_header_interval=50, min_width=12, default_prec=3, extra_pad=1, separator_lines=True, logbook=None, out=print, ): super().__init__() self._header_interval = print_header_interval # expand objects self._fields = get_obj_fields(fields) self._stream = stream_name self._start = None self._stop = None self._descriptors = set() self._pad_len = extra_pad self._extra_pad = " " * extra_pad self._min_width = min_width self._default_prec = default_prec self._separator_lines = separator_lines self._format_info = OrderedDict( [ ("seq_num", self._fm_sty(10 + self._pad_len, "", "d")), (self.ev_time_key, self._fm_sty(10 + 2 * extra_pad, 10, "s")), ] ) self._rows = [] self.logbook = logbook self._sep_format = None self._out = out def descriptor(self, doc): def patch_up_precision(p): try: return int(p) except (TypeError, ValueError): return self._default_prec if doc["name"] != self._stream: return self._descriptors.add(doc["uid"]) dk = doc["data_keys"] for k in self._fields: width = max(self._min_width, len(k) + 2, self._default_prec + 1 + 2 * self._pad_len) try: dk_entry = dk[k] except KeyError: # this descriptor does not know about this key continue if dk_entry["dtype"] not in self._FMT_MAP: warnings.warn( # noqa: B028 "The key {} will be skipped because LiveTable " "does not know how to display the dtype {}" "".format(k, dk_entry["dtype"]) ) continue prec = patch_up_precision(dk_entry.get("precision", self._default_prec)) fmt = self._fm_sty(width=width, prec=prec, dtype=self._FMT_MAP[dk_entry["dtype"]]) self._format_info[k] = fmt self._sep_format = "+" + "+".join("-" * f.width for f in self._format_info.values()) + "+" self._main_fmnt = "|".join( "{{: >{w}}}{pad}".format(w=f.width - self._pad_len, pad=" " * self._pad_len) for f in self._format_info.values() ) headings = [k if k != self.ev_time_key else "time" for k in self._format_info] self._header = "|" + self._main_fmnt.format(*headings) + "|" self._data_formats = OrderedDict( ( k, self._FMTLOOKUP[f.dtype].format( k=f"h{str(hash(k))}", width=f.width - 2 * self._pad_len, prec=f.prec, dtype=f.dtype, pad=self._extra_pad, ), ) for k, f in self._format_info.items() ) self._count = 0 if self._separator_lines: self._print("\n") self._print(self._sep_format) self._print(self._header) self._print(self._sep_format) super().descriptor(doc) def event(self, doc): try: if ensure_uid(doc["descriptor"]) not in self._descriptors: return # shallow copy so we can mutate data = dict(doc["data"]) self._count += 1 if not self._count % self._header_interval: self._print(self._sep_format) self._print(self._header) self._print(self._sep_format) fmt_time = str(datetime.fromtimestamp(doc["time"]).time()) data[self.ev_time_key] = fmt_time data["seq_num"] = doc["seq_num"] cols = [ f.format(**{f"h{str(hash(k))}": data[k]}) # Show data[k] if k exists in this Event and is 'filled'. # (The latter is only applicable if the data is # externally-stored -- hence the fallback to `True`.) if ((k in data) and doc.get("filled", {}).get(k, True)) # Otherwise use a placeholder of whitespace. else " " * self._format_info[k].width for k, f in self._data_formats.items() ] self._print("|" + "|".join(cols) + "|") except Exception as ex: if self.log is not None: self.log.exception(ex) self._print(f"{{k:*^{self._min_width}}}".format(k=" failed to format row ")) if os.environ.get("BLUESKY_DEBUG_CALLBACKS", False): raise ex super().event(doc) def stop(self, doc): if ensure_uid(doc["run_start"]) != self._start["uid"]: return # This sleep is just cosmetic. It improves the odds that the bottom # border is not printed until all the rows from events are printed, # avoiding this ugly scenario: # # | 4 | 22:08:56.7 | 0.000 | # +-----------+------------+------------+ # generator scan ['6d3f71'] (scan num: 1) # Out[2]: | 5 | 22:08:56.8 | 0.000 | ttime.sleep(0.1) if self._sep_format is not None: self._print(self._sep_format) self._stop = doc wm = self.water_mark.format(st=self._start) self._out(wm) if self.logbook: self.logbook("\n".join([wm] + self._rows)) if self._separator_lines: self._print("\n") super().stop(doc) def start(self, doc): self._rows = [] self._start = doc self._stop = None self._sep_format = None super().start(doc) def _print(self, out_str): self._rows.append(out_str) self._out(out_str)