Source code for bluesky.callbacks.zmq

"""
The key classes needed to use 0MQ for multiprocess document communication.

`Publisher` : subscribe this to the RE to emit the documents.  Expects a server to
have a SUBSCRIBE port open to PUB to.

`RemoteDispatcher` : subscribe callbacks to this class in a remote process.  Expects
a server to have a PUB port open to SUBSCRIBE to.

`Proxy` : server that binds ports for the Publisher to push to and the Dispatcher
to listen to.  Typically this is started with the cli tool ``bluesky-zmq-proxy``

"""

import asyncio
import copy
import logging
import pickle
import threading
import warnings
from collections.abc import Callable
from pathlib import Path
from typing import Any, NamedTuple

import zmq
import zmq.asyncio as zmq_asyncio
import zmq.auth
from zmq.auth.thread import ThreadAuthenticator

from ..run_engine import Dispatcher, DocumentNames

# Module-level logger; Proxy uses it for DEBUG-level reports on socket and CURVE setup.
logger = logging.getLogger(__name__)


class ServerCurve(NamedTuple):
    """CURVE settings for a socket acting as the server (binding) side.

    Consumed by ``Proxy.configure_server_socket`` when ``bind=True``.
    """

    # path to the server certificate file; it is read with
    # ``zmq.auth.load_certificate`` and must contain the secret key
    secret_path: Path
    # path to folder of client's public keys.  If None, allow all clients
    # (the authenticator is configured with ``zmq.auth.CURVE_ALLOW_ANY``)
    client_public_keys: Path | None
    # set of ip addresses to allow; if None, no address allowlist is configured
    allow: set[str] | None


class ClientCurve(NamedTuple):
    """CURVE settings for a socket acting as the client (connecting) side."""

    # path to the client's own certificate file; it is read with
    # ``zmq.auth.load_certificate`` and must contain the client's secret key
    secret_path: Path
    # path to the server's public key certificate
    server_public_key: Path


def _normalize_address(inp: str | tuple | int | None):
    if isinstance(inp, str):
        if "://" in inp:
            protocol, _, rest_str = inp.partition("://")
        else:
            protocol = "tcp"
            rest_str = inp
    elif isinstance(inp, tuple):
        if inp[0] in ["tcp", "ipc"]:
            protocol, *rest = inp
        else:
            protocol = "tcp"
            rest = list(inp)
        if protocol == "tcp":
            if len(rest) == 2:
                rest_str = ":".join(str(r) for r in rest)
            else:
                (rest_str,) = rest
        else:
            (rest_str,) = rest
    elif isinstance(inp, int):
        protocol = "tcp"
        rest_str = f"0.0.0.0:{inp}"
    elif inp is None:
        protocol = "tcp"
        rest_str = "*"

    else:
        raise TypeError(f"Input expected to be int, str, or tuple, not {type(inp).__name__}")

    return f"{protocol}://{rest_str}"


class Bluesky0MQDecodeError(Exception):
    """Raised when a message read off the wire cannot be decoded."""


class Publisher:
    """
    A callback that publishes documents to a 0MQ proxy.

    Parameters
    ----------
    address : string or tuple
        Address of a running 0MQ proxy, given either as a string like
        ``'127.0.0.1:5567'`` or as a tuple like ``('127.0.0.1', 5567)``
    prefix : bytes, optional
        User-defined bytestring used to distinguish between multiple
        Publishers. May not contain b' '.
    serializer: function, optional
        optional function to serialize data. Default is pickle.dumps
    curve_config: ClientCurve, optional
        CURVE security configuration for client authentication.

    Raises
    ------
    ValueError
        If ``prefix`` is a ``str`` or contains ``b' '``, or if the certificate
        at ``curve_config.secret_path`` does not contain a secret key.

    Examples
    --------

    Publish from a RunEngine to a Proxy running on localhost on port 5567.

    >>> publisher = Publisher('localhost:5567')
    >>> RE = RunEngine({})
    >>> RE.subscribe(publisher)

    The publisher may also be used as a context manager, which closes the
    socket and context on exit.

    >>> with Publisher('localhost:5567') as publisher:
    ...     RE.subscribe(publisher)
    """

    def __init__(
        self,
        address: str | tuple[str, int],
        *,
        prefix: bytes = b"",
        serializer: Callable = pickle.dumps,
        curve_config: ClientCurve | None = None,
    ):
        if isinstance(prefix, str):
            raise ValueError("prefix must be bytes, not string")
        if b" " in prefix:
            raise ValueError(f"prefix {prefix!r} may not contain b' '")
        self.address = _normalize_address(address)
        self._prefix = bytes(prefix)
        self._serializer = serializer
        self._context = zmq.Context()
        self._socket = None
        try:
            self._socket = self._context.socket(zmq.PUB)
            if curve_config is not None:
                # Load the client cert pair
                client_public, client_secret = zmq.auth.load_certificate(curve_config.secret_path)
                if client_secret is None:
                    raise ValueError("The client secret key could not be found.")
                self._socket.setsockopt(zmq.CURVE_PUBLICKEY, client_public)
                self._socket.setsockopt(zmq.CURVE_SECRETKEY, client_secret)

                # Load the server public key and register with the socket
                server_key, _ = zmq.auth.load_certificate(curve_config.server_public_key)
                self._socket.setsockopt(zmq.CURVE_SERVERKEY, server_key)

            self._socket.connect(self.address)
        except BaseException:
            # A half-built Publisher is never handed to the caller, so nobody
            # else could release the socket and context; do it here.
            self.close()
            raise

    def __call__(self, name: str, doc: dict[str, Any]):
        """Serialize ``doc`` and send it as ``<prefix> <name> <payload>``."""
        # Work on a copy so the serializer cannot affect the caller's document.
        doc = copy.deepcopy(doc)
        message = b" ".join([self._prefix, name.encode(), self._serializer(doc)])
        self._socket.send(message)

    def close(self):
        """Close the socket and terminate the 0MQ context."""
        if self._socket is not None:
            self._socket.close()
        self._context.destroy()  # close Socket(s); terminate Context

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
class Proxy:
    """
    Start a 0MQ proxy on the local host.

    The addresses can be specified flexibly.  It is best to use
    a domain_socket (available on unix):

    - ``'ipc:///tmp/domain_socket'``
    - ``('ipc', '/tmp/domain_socket')``

    tcp sockets are also supported:

    - ``'tcp://localhost:6657'``
    - ``6657`` (implicitly binds to ``'tcp://0.0.0.0:6657'``)
    - ``('tcp', 'localhost', 6657)``
    - ``('localhost', 6657)``

    Parameters
    ----------
    in_address : str or tuple or int, optional
        Address that RunEngines should broadcast to. If None, a random tcp port
        on all interfaces is used.
    out_address : str or tuple or int, optional
        Address that subscribers should subscribe to. If None, a random tcp port
        on all interfaces is used.
    in_curve: ServerCurve or ClientCurve or None, optional
        CURVE security configuration for the incoming socket.  Must be a
        ServerCurve when ``in_bind`` is True and a ClientCurve otherwise.
        If None, no security is applied.
    out_curve: ServerCurve or ClientCurve or None, optional
        CURVE security configuration for the outgoing socket.  Must be a
        ServerCurve when ``out_bind`` is True and a ClientCurve otherwise.
        If None, no security is applied.
    in_bind: bool, default True
        If True, the incoming socket will be bound to the address; otherwise
        it connects to it.
    out_bind: bool, default True
        If True, the outgoing socket will be bound to the address; otherwise
        it connects to it.
    in_port: int or None, optional
        DEPRECATED alias for in_address.  If specified, must be used instead of in_address
    out_port: int or None, optional
        DEPRECATED alias for out_address.  If specified, must be used instead of out_address

    Attributes
    ----------
    in_port : str
        Resolved address of the incoming socket (the name is historical).
    out_port : str
        Resolved address of the outgoing socket (the name is historical).
    closed : boolean
        True if the Proxy has already been started and subsequently
        interrupted and is therefore unusable.

    Examples
    --------

    Run on specific ports.

    >>> proxy = Proxy(in_address='localhost:5567', out_address='localhost:5568')
    >>> proxy
    Proxy(in_port=tcp://localhost:5567, out_port=tcp://localhost:5568)
    >>> proxy.start()  # runs until interrupted

    Run on random ports, and access those ports before starting.

    >>> proxy = Proxy()
    >>> proxy
    Proxy(in_port=tcp://*:56504, out_port=tcp://*:56505)
    >>> proxy.in_port
    'tcp://*:56504'
    >>> proxy.out_port
    'tcp://*:56505'
    >>> proxy.start()  # runs until interrupted
    """

    @staticmethod
    def _stop_authenticator(auth) -> None:
        """Stop a ThreadAuthenticator, logging (not raising) any failure."""
        try:
            auth.stop()
        except Exception:
            # Best effort: teardown must continue and must not mask the
            # exception that triggered it.
            logger.exception("Failed to stop ZMQ authenticator")

    @staticmethod
    def _configure_socket(
        ctx: zmq.Context,
        sock_type: int,
        address: str | tuple | int | None,
        curve: ServerCurve | ClientCurve | None,
        bind: bool = True,
    ):
        """Create, secure and bind/connect a socket.

        Returns ``(socket, final_address, authenticator)`` where
        ``authenticator`` is the ThreadAuthenticator started for a ServerCurve
        socket, or None.  On failure the socket is closed and any authenticator
        that was created is stopped before the exception propagates.
        """
        # Normalize first so a bad address cannot leave a socket behind.
        norm_address = _normalize_address(address)
        socket: zmq.Socket = ctx.socket(sock_type)
        auth = None
        logger.debug(f"Creating socket of type {sock_type} for address {norm_address}, bind={bind}")

        try:
            # A tcp address with nothing after the host (6 == len("tcp://"))
            # carries no port, so a random one is picked at bind time.
            random_port = norm_address.startswith("tcp") and ":" not in norm_address[6:]

            if curve is not None:
                if bind:
                    # Server mode - expect ServerCurve
                    if not isinstance(curve, ServerCurve):
                        raise TypeError("When bind=True, curve must be a ServerCurve instance")

                    logger.debug(f"Configuring CURVE server security with secret_path={curve.secret_path}")
                    # build authenticator
                    # NOTE(review): every ServerCurve socket starts its own
                    # authenticator on the shared context; confirm pyzmq allows
                    # more than one per context when both in_curve and
                    # out_curve are ServerCurve.
                    auth = ThreadAuthenticator(ctx)
                    auth.start()
                    logger.debug("Started ZMQ authenticator")

                    if curve.allow is not None:
                        auth.allow(*curve.allow)
                        logger.debug(f"Configured IP address allowlist: {curve.allow}")

                    # Tell the authenticator how to handle CURVE requests
                    if curve.client_public_keys is None:
                        # accept any client that knows the public key
                        auth.configure_curve(domain="*", location=zmq.auth.CURVE_ALLOW_ANY)
                        logger.debug("Configured CURVE to allow any client with valid public key")
                    else:
                        auth.configure_curve(domain="*", location=curve.client_public_keys)
                        logger.debug(f"Configured CURVE client public keys from: {curve.client_public_keys}")

                    # get public and private keys from the certificate
                    server_public, server_secret = zmq.auth.load_certificate(curve.secret_path)
                    if server_secret is None:
                        raise ValueError("The server secret key could not be found.")
                    # attach them to the socket
                    socket.setsockopt(zmq.CURVE_PUBLICKEY, server_public)
                    socket.setsockopt(zmq.CURVE_SECRETKEY, server_secret)
                    socket.setsockopt(zmq.CURVE_SERVER, True)
                    logger.debug("Applied CURVE keys and enabled CURVE server mode")
                else:
                    # Client mode - expect ClientCurve
                    if not isinstance(curve, ClientCurve):
                        raise TypeError("When bind=False, curve must be a ClientCurve instance")

                    logger.debug(f"Configuring CURVE client security with secret_path={curve.secret_path}")
                    # Load the client cert pair
                    client_public, client_secret = zmq.auth.load_certificate(curve.secret_path)
                    if client_secret is None:
                        raise ValueError("The client secret key could not be found.")
                    socket.setsockopt(zmq.CURVE_PUBLICKEY, client_public)
                    socket.setsockopt(zmq.CURVE_SECRETKEY, client_secret)

                    # Load the server public key and register with the socket
                    server_key, _ = zmq.auth.load_certificate(curve.server_public_key)
                    socket.setsockopt(zmq.CURVE_SERVERKEY, server_key)
                    logger.debug("Applied CURVE client keys and server public key")

            if bind:
                if random_port:
                    port = socket.bind_to_random_port(norm_address)
                    final_address = norm_address + ":" + str(port)
                    logger.debug(f"Bound to random port: {port}")
                else:
                    final_address = socket.bind(norm_address).addr
                    logger.debug(f"Bound to address: {norm_address}")
            else:
                final_address = socket.connect(norm_address).addr
                logger.debug(f"Connected to address: {norm_address}")
        except BaseException:
            socket.close()
            if auth is not None:
                Proxy._stop_authenticator(auth)
            raise

        logger.debug(f"Socket configured with final address: {final_address}")
        return socket, final_address, auth

    @staticmethod
    def configure_server_socket(
        ctx: zmq.Context,
        sock_type: int,
        address: str | tuple | int | None,
        curve: ServerCurve | ClientCurve | None,
        bind: bool = True,
    ) -> tuple[zmq.Socket, int | str]:
        """Helper method to create and bind or connect a socket with optional CURVE security.

        Parameters
        ----------
        ctx : zmq.Context
            The ZMQ context to use for creating the socket.
        sock_type : int
            The type of socket to create (e.g. zmq.SUB, zmq.PUB).
        address : str | tuple | int | None
            The address to bind or connect the socket to.
        curve : ServerCurve | ClientCurve | None
            CURVE security configuration. If None, no security is applied.
        bind : bool, default True
            If True, the socket will be bound to the address.

        Returns
        -------
        socket : zmq.Socket
            The configured ZMQ socket.
        address : str
            The address to which the socket is bound or connected.

        Raises
        ------
        TypeError
            If ``curve`` is not a ServerCurve when binding, or not a
            ClientCurve when connecting.
        ValueError
            If the certificate at ``curve.secret_path`` has no secret key.

        Notes
        -----
        Any authenticator started for a ServerCurve socket is not returned
        through this public helper, so the caller cannot stop it explicitly.
        """
        socket, final_address, _auth = Proxy._configure_socket(ctx, sock_type, address, curve, bind=bind)
        return socket, final_address

    def __init__(
        self,
        in_address: str | tuple | int | None = None,
        out_address: str | tuple | int | None = None,
        *,
        in_curve: ServerCurve | ClientCurve | None = None,
        out_curve: ServerCurve | ClientCurve | None = None,
        in_bind: bool = True,
        out_bind: bool = True,
        in_port: int | None = None,
        out_port: int | None = None,
    ):
        # Handle backward compatibility for in_port -> in_address
        if in_port is not None and in_address is not None:
            raise ValueError("Cannot specify both 'in_port' and 'in_address'. Use 'in_address' only.")
        if in_port is not None:
            warnings.warn(
                "The 'in_port' parameter is deprecated and will be removed in a future release. "
                "Use 'in_address' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            in_address = f"localhost:{in_port}"

        # Handle backward compatibility for out_port -> out_address
        if out_port is not None and out_address is not None:
            raise ValueError("Cannot specify both 'out_port' and 'out_address'. Use 'out_address' only.")
        if out_port is not None:
            warnings.warn(
                "The 'out_port' parameter is deprecated and will be removed in a future release. "
                "Use 'out_address' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            out_address = f"localhost:{out_port}"

        # Delete deprecated parameter names
        del in_port, out_port

        self.closed = False
        # Authenticators started for ServerCurve sockets; stopped on teardown.
        self._authenticators = []
        context = zmq.Context()
        frontend = backend = None
        try:
            frontend, self.in_port, in_auth = self._configure_socket(
                context, zmq.SUB, in_address, in_curve, bind=in_bind
            )
            if in_auth is not None:
                self._authenticators.append(in_auth)
            frontend.setsockopt_string(zmq.SUBSCRIBE, "")

            backend, self.out_port, out_auth = self._configure_socket(
                context, zmq.PUB, out_address, out_curve, bind=out_bind
            )
            if out_auth is not None:
                self._authenticators.append(out_auth)
        except BaseException:
            # Clean up whichever components we have defined so far.
            for sock in (frontend, backend):
                if sock is not None:
                    sock.close()
            for auth in self._authenticators:
                self._stop_authenticator(auth)
            self._authenticators = []
            context.destroy()
            raise

        self._frontend = frontend
        self._backend = backend
        self._context = context

    def start(self):
        """Forward messages from the incoming to the outgoing socket until interrupted.

        Blocks the calling thread.  When it exits, for any reason, the sockets,
        authenticators and context are released and the Proxy is marked closed.
        """
        if self.closed:
            raise RuntimeError(
                f"This Proxy has already been started and interrupted. Create a fresh instance with {repr(self)}"
            )
        try:
            zmq.device(zmq.FORWARDER, self._frontend, self._backend)
        finally:
            self.closed = True
            self._frontend.close()
            self._backend.close()
            # Stop authenticator threads before destroying the context they
            # were created on.
            for auth in self._authenticators:
                self._stop_authenticator(auth)
            self._authenticators = []
            self._context.destroy()

    def __repr__(self):
        return f"{type(self).__name__}(in_port={self.in_port}, out_port={self.out_port})"
class RemoteDispatcher(Dispatcher):
    """
    Dispatch documents received over the network from a 0MQ proxy.

    Parameters
    ----------
    address : str or tuple
        Address of a running 0MQ proxy, given either as a string like
        ``'127.0.0.1:5567'`` or as a tuple like ``('127.0.0.1', 5567)``
    prefix : bytes, optional
        User-defined bytestring used to distinguish between multiple
        Publishers. If set, messages without this prefix will be ignored.
        If unset, no messages will be ignored.
    loop : asyncio.AbstractEventLoop, optional
        optional event loop to use. Default is to create a new event loop.
        The loop is closed when the dispatcher is torn down.
    deserializer: function, optional
        optional function to deserialize data. Default is pickle.loads.
        SECURITY: unpickling can execute arbitrary code, so use the default
        only with publishers you trust, or supply a safer deserializer.
    strict : bool, optional
        If True, a message that cannot be decoded raises
        `Bluesky0MQDecodeError` and stops the dispatcher.  If False (default),
        a notice is printed and the message is dropped.
    curve_config : ClientCurve, optional
        CURVE security configuration.  The dispatcher always connects as a
        client, so the object must provide ``secret_path`` and
        ``server_public_key``.

    Examples
    --------

    Print all documents generated by remote RunEngines.

    >>> d = RemoteDispatcher(('localhost', 5568))
    >>> d.subscribe(print)
    >>> d.start()  # runs until interrupted
    """

    def __init__(
        self,
        address: str | tuple[str, int],
        *,
        prefix: bytes = b"",
        loop: asyncio.AbstractEventLoop | None = None,
        deserializer: Callable = pickle.loads,
        strict: bool = False,
        curve_config: ServerCurve | ClientCurve | None = None,
    ):
        if isinstance(prefix, str):
            raise ValueError("prefix must be bytes, not string")
        if b" " in prefix:
            raise ValueError(f"prefix {prefix!r} may not contain b' '")
        self._prefix = prefix
        self._deserializer = deserializer
        self.address = _normalize_address(address)

        if loop is None:
            loop = asyncio.new_event_loop()
        self.loop = loop

        self._context = None
        self._socket = None

        def __finish_setup():
            # Deferred until ``start`` so the asyncio context and socket are
            # created in the thread that will run the loop.
            asyncio.set_event_loop(self.loop)
            self._context = zmq_asyncio.Context()
            self._socket = sock = self._context.socket(zmq.SUB)
            if curve_config is not None:
                # Load the client cert pair
                client_public, client_secret = zmq.auth.load_certificate(curve_config.secret_path)
                if client_secret is None:
                    raise ValueError("The client secret key could not be found.")
                sock.setsockopt(zmq.CURVE_PUBLICKEY, client_public)
                sock.setsockopt(zmq.CURVE_SECRETKEY, client_secret)

                # Load the server public key and register with the socket
                server_key, _ = zmq.auth.load_certificate(curve_config.server_public_key)
                sock.setsockopt(zmq.CURVE_SERVERKEY, server_key)

            sock.connect(self.address)
            sock.setsockopt_string(zmq.SUBSCRIBE, "")

        self.__factory = __finish_setup
        self._task = None
        self._stopped = threading.Event()
        self.closed = False
        self._strict = strict

        super().__init__()

    async def _poll(self):
        """Receive messages forever, decode them and schedule ``process``."""
        our_prefix = self._prefix  # local var to save an attribute lookup
        while True:
            message = await self._socket.recv()
            try:
                prefix, name, doc = message.split(b" ", 2)
            except ValueError as e:
                if self._strict:
                    raise Bluesky0MQDecodeError from e
                print(
                    f"The message {message} could not be split into "
                    "three parts by b' '. Dropping message on floor "
                    "and continuing"
                    f"\n\n{e}"
                )
                continue
            try:
                name = name.decode()
            except UnicodeDecodeError as e:
                if self._strict:
                    raise Bluesky0MQDecodeError from e
                print(
                    f"The name {name} can not be decoded as utf-8. "
                    "Dropping message on the floor and continuing. "
                    f"\n\n{e}"
                )
                continue

            if our_prefix and prefix != our_prefix:
                # Addressed to a different prefix: ignore silently.
                continue

            try:
                doc = self._deserializer(doc)
            except Exception as e:
                if self._strict:
                    raise Bluesky0MQDecodeError from e
                msg_doc = doc[:1024] + b"--SNIPPED--" if len(doc) > 1024 else doc
                print(
                    f"Failed to deserialize the {name} document "
                    f"{msg_doc} using {self._deserializer}. "
                    "Dropping on floor and continuing"
                    f"\n\n{e}"
                )
                continue

            try:
                doc_name = DocumentNames[name]
            except KeyError as e:
                # An unrecognised document name is one more kind of bad
                # message; treat it like the other decode failures instead of
                # letting a bare KeyError end the poll loop.
                if self._strict:
                    raise Bluesky0MQDecodeError from e
                print(
                    f"The name {name!r} is not a known document name. "
                    "Dropping message on the floor and continuing. "
                    f"\n\n{e}"
                )
                continue
            self.loop.call_soon(self.process, doc_name, doc)

    def start(self):
        """Connect and dispatch documents until stopped or an error occurs.

        Blocks the calling thread while running the event loop.  Whatever ends
        the polling task (its exception, or ``asyncio.CancelledError`` after
        :meth:`stop`) propagates to the caller once teardown has finished.
        """
        if self.closed:
            raise RuntimeError(
                "This RemoteDispatcher has already been "
                "started and interrupted. Create a fresh "
                f"instance with {self!r}"
            )
        try:
            self.__factory()
            self._task = self.loop.create_task(self._poll())
            # ``run_until_complete`` re-raises the task's exception (or the
            # cancellation), so no separate ``task.exception()`` check is needed.
            self.loop.run_until_complete(self._task)
        finally:
            # The loop has stopped, so tear everything down here and signal
            # any thread blocked in ``stop`` that cleanup is complete.
            self._cleanup()

    def _cleanup(self):
        # Release the task, socket, context and loop.  Safe to call only once
        # the loop has stopped; closing a running loop raises RuntimeError.
        if self._task is not None:
            self._task.cancel()
            self._task = None
        if self._socket is not None:
            self._socket.close()
            self._socket = None
        if self._context is not None:
            self._context.destroy()
            self._context = None
        if not self.loop.is_closed():
            self.loop.close()
        self.closed = True
        self._stopped.set()

    def stop(self):
        """Stop the dispatcher.

        When called from a thread other than the one running the event loop,
        task cancellation is scheduled on the loop and this method blocks
        until :meth:`start` has torn the dispatcher down.

        When called on the loop's own thread (from a subscribed callback or
        ``loop.call_later``), the task is cancelled and the method returns
        immediately; teardown completes once control returns to
        :meth:`start`.  Blocking there would deadlock, because the loop that
        has to finish is the one executing this call.

        Does nothing if the loop is not running.
        """
        if not self.loop.is_running():
            return

        try:
            on_loop_thread = asyncio.get_running_loop() is self.loop
        except RuntimeError:
            # No loop is running in this thread, so ours runs elsewhere.
            on_loop_thread = False

        if on_loop_thread:
            if self._task is not None:
                self._task.cancel()
            return

        if self._task is not None:
            self.loop.call_soon_threadsafe(self._task.cancel)
        self._stopped.wait()