"""
The key classes needed to use 0MQ for multiprocess document communication.
`Publisher` : subscribe this to the RE to emit the documents. Expects a server to
have a SUBSCRIBE port open to PUB to.
`RemoteDispatcher` : subscribe callbacks to this class in a remote process. Expects
a server to have a PUB port open to SUBSCRIBE to.
`Proxy` : server that binds ports for Pubslisher to push to and the Dispatcher
to listen to. Typically this is started with the cli tool ``bluesky-zmq-proxy``
"""
import asyncio
import copy
import logging
import pickle
import warnings
from collections.abc import Callable
from pathlib import Path
from typing import Any, NamedTuple
import zmq
import zmq.asyncio as zmq_asyncio
import zmq.auth
from zmq.auth.thread import ThreadAuthenticator
from ..run_engine import Dispatcher, DocumentNames
logger = logging.getLogger(__name__)
class ServerCurve(NamedTuple):
# path to the secret key for the server
secret_path: Path
# path to folder of client's public keys. If None, allow all clients
client_public_keys: Path | None
# set of ip addresses to allow
allow: set[str] | None
class ClientCurve(NamedTuple):
# path to the secret key for the server
secret_path: Path
# path to the servers public key
server_public_key: Path
def _normalize_address(inp: str | tuple | int | None):
if isinstance(inp, str):
if "://" in inp:
protocol, _, rest_str = inp.partition("://")
else:
protocol = "tcp"
rest_str = inp
elif isinstance(inp, tuple):
if inp[0] in ["tcp", "ipc"]:
protocol, *rest = inp
else:
protocol = "tcp"
rest = list(inp)
if protocol == "tcp":
if len(rest) == 2:
rest_str = ":".join(str(r) for r in rest)
else:
(rest_str,) = rest
else:
(rest_str,) = rest
elif isinstance(inp, int):
protocol = "tcp"
rest_str = f"0.0.0.0:{inp}"
elif inp is None:
protocol = "tcp"
rest_str = "*"
else:
raise TypeError(f"Input expected to be int, str, or tuple, not {type(inp).__name__}")
return f"{protocol}://{rest_str}"
class Bluesky0MQDecodeError(Exception):
"""Custom exception class for things that go wrong reading message from wire."""
...
[docs]
class Publisher:
"""
A callback that publishes documents to a 0MQ proxy.
Parameters
----------
address : string or tuple
Address of a running 0MQ proxy, given either as a string like
``'127.0.0.1:5567'`` or as a tuple like ``('127.0.0.1', 5567)``
prefix : bytes, optional
User-defined bytestring used to distinguish between multiple
Publishers. May not contain b' '.
serializer: function, optional
optional function to serialize data. Default is pickle.dumps
curve_config: ClientCurve, optional
CURVE security configuration for client authentication.
Examples
--------
Publish from a RunEngine to a Proxy running on localhost on port 5567.
>>> publisher = Publisher('localhost:5567')
>>> RE = RunEngine({})
>>> RE.subscribe(publisher)
"""
def __init__(
self,
address: str | tuple[str, int],
*,
prefix: bytes = b"",
serializer: Callable = pickle.dumps,
curve_config: ClientCurve | None = None,
):
if isinstance(prefix, str):
raise ValueError("prefix must be bytes, not string")
if b" " in prefix:
raise ValueError(f"prefix {prefix!r} may not contain b' '")
self.address = _normalize_address(address)
self._prefix = bytes(prefix)
self._context = zmq.Context()
self._socket = self._context.socket(zmq.PUB)
if curve_config is not None:
# Load the client cert pair
client_public, client_secret = zmq.auth.load_certificate(curve_config.secret_path)
self._socket.setsockopt(zmq.CURVE_PUBLICKEY, client_public)
if client_secret is None:
raise ValueError("The client secret key could not be found.")
self._socket.setsockopt(zmq.CURVE_SECRETKEY, client_secret)
# Load the server public key and register with the socket
server_key, _ = zmq.auth.load_certificate(curve_config.server_public_key)
self._socket.setsockopt(zmq.CURVE_SERVERKEY, server_key)
self._socket.connect(self.address)
self._serializer = serializer
def __call__(self, name: str, doc: dict[str, Any]):
doc = copy.deepcopy(doc)
message = b" ".join([self._prefix, name.encode(), self._serializer(doc)])
self._socket.send(message)
def close(self):
self._socket.close()
self._context.destroy() # close Socket(s); terminate Context
[docs]
class Proxy:
"""
Start a 0MQ proxy on the local host.
The addresses can be specified flexibly. It is best to use
a domain_socket (available on unix):
- ``'icp:///tmp/domain_socket'``
- ``('ipc', '/tmp/domain_socket')``
tcp sockets are also supported:
- ``'tcp://localhost:6557'``
- ``6657`` (implicitly binds to ``'tcp://localhost:6557'``
- ``('tcp', 'localhost', 6657)``
- ``('localhost', 6657)``
Parameters
----------
in_address : str or tuple or int, optional
Address that RunEngines should broadcast to.
If None, a random tcp port on all interfaces is used.
out_address : str or tuple or int, optional
Address that subscribers should subscribe to.
If None, a random tcp port on all interfaces is used.
in_curve: ServerCurve or ClientCurve or None, optional
CURVE security configuration for the incoming socket. If None, no security is applied.
out_curve: ServerCurve or ClientCurve or None, optional
CURVE security configuration for the outgoing socket. If None, no security is applied.
in_bind: bool, default True
If True, the incoming socket will be bound to the address.
out_bind: bool, default True
If True, the outgoing socket will be bound to the address.
in_port: int or None, optional
DEPRECATED alias for in_address. If specified, must be used instead of in_address
out_port: int or None, optional
DEPRECATED alias for out_address. If specified, must be used instead of out_address
Attributes
----------
in_address: int or str or tuple
Port that RunEngines should broadcast to.
out_address : int or str or tuple
Port that subscribers should subscribe to.
closed : boolean
True if the Proxy has already been started and subsequently
interrupted and is therefore unusable.
Examples
--------
Run on specific ports.
>>> proxy = Proxy(in_address='localhost:5567', out_address='localhost:5568')
>>> proxy
Proxy(in_port=5567, out_port=5568)
>>> proxy.start() # runs until interrupted
Run on random ports, and access those ports before starting.
>>> proxy = Proxy()
>>> proxy
Proxy(in_port=56504, out_port=56505)
>>> proxy.in_port
56504
>>> proxy.out_port
56505
>>> proxy.start() # runs until interrupted
"""
@staticmethod
def configure_server_socket(
ctx: zmq.Context,
sock_type: int,
address: str | tuple | int | None,
curve: ServerCurve | ClientCurve | None,
bind: bool = True,
) -> tuple[zmq.Socket, int | str]:
"""Helper method to create and bind or connect a socket with optional CURVE security.
Parameters
----------
ctx : zmq.Context
The ZMQ context to use for creating the socket.
sock_type : int
The type of socket to create (e.g. zmq.SUB, zmq.PUB).
address : str | tuple | int | None
The address to bind or connect the socket to.
curve : ServerCurve | ClientCurve | None
CURVE security configuration. If None, no security is applied.
bind : bool, default True
If True, the socket will be bound to the address.
Returns
-------
socket : zmq.Socket
The configured ZMQ socket.
address : str
The addresss to which the socket is bound or connected.
"""
socket: zmq.Socket = ctx.socket(sock_type)
norm_address = _normalize_address(address)
logger.debug(f"Creating socket of type {sock_type} for address {norm_address}, bind={bind}")
random_port = False
if norm_address.startswith("tcp"):
if ":" not in norm_address[6:]:
random_port = True
if curve is not None:
if bind:
# Server mode - expect ServerCurve
if not isinstance(curve, ServerCurve):
raise TypeError("When bind=True, curve must be a ServerCurve instance")
logger.debug(f"Configuring CURVE server security with secret_path={curve.secret_path}")
# build authenticator
auth = ThreadAuthenticator(ctx)
auth.start()
logger.debug("Started ZMQ authenticator")
if curve.allow is not None:
auth.allow(*curve.allow)
logger.debug(f"Configured IP address allowlist: {curve.allow}")
# Tell the authenticator how to handle CURVE requests
if curve.client_public_keys is None:
# accept any client that knows the public key
auth.configure_curve(domain="*", location=zmq.auth.CURVE_ALLOW_ANY)
logger.debug("Configured CURVE to allow any client with valid public key")
else:
auth.configure_curve(domain="*", location=curve.client_public_keys)
logger.debug(f"Configured CURVE client public keys from: {curve.client_public_keys}")
# get public and private keys from the certificate
server_public, server_secret = zmq.auth.load_certificate(curve.secret_path)
if server_secret is None:
raise ValueError("The server secret key could not be found.")
# attach them to the socket
socket.setsockopt(zmq.CURVE_PUBLICKEY, server_public)
socket.setsockopt(zmq.CURVE_SECRETKEY, server_secret)
socket.setsockopt(zmq.CURVE_SERVER, True)
logger.debug("Applied CURVE keys and enabled CURVE server mode")
else:
# Client mode - expect ClientCurve
if not isinstance(curve, ClientCurve):
raise TypeError("When bind=False, curve must be a ClientCurve instance")
logger.debug(f"Configuring CURVE client security with secret_path={curve.secret_path}")
# Load the client cert pair
client_public, client_secret = zmq.auth.load_certificate(curve.secret_path)
socket.setsockopt(zmq.CURVE_PUBLICKEY, client_public)
if client_secret is None:
raise ValueError("The client secret key could not be found.")
socket.setsockopt(zmq.CURVE_SECRETKEY, client_secret)
# Load the server public key and register with the socket
server_key, _ = zmq.auth.load_certificate(curve.server_public_key)
socket.setsockopt(zmq.CURVE_SERVERKEY, server_key)
logger.debug("Applied CURVE client keys and server public key")
if bind:
if random_port:
port = socket.bind_to_random_port(norm_address)
final_address = norm_address + ":" + str(port)
logger.debug(f"Bound to random port: {port}")
else:
final_address = socket.bind(norm_address).addr
logger.debug(f"Bound to address: {norm_address}")
else:
final_address = socket.connect(norm_address).addr
logger.debug(f"Connected to address: {norm_address}")
logger.debug(f"Socket configured with final address: {final_address}")
return socket, final_address
def __init__(
self,
in_address: str | tuple[str, int] | None = None,
out_address: str | tuple[str, int] | None = None,
*,
in_curve: ServerCurve | ClientCurve | None = None,
out_curve: ServerCurve | ClientCurve | None = None,
in_bind: bool = True,
out_bind: bool = True,
in_port: int | None = None,
out_port: int | None = None,
):
# Handle backward compatibility for in_port -> in_address
if in_port is not None and in_address is not None:
raise ValueError("Cannot specify both 'in_port' and 'in_address'. Use 'in_address' only.")
if in_port is not None:
warnings.warn(
"The 'in_port' parameter is deprecated and will be removed in a future release. "
"Use 'in_address' instead.",
DeprecationWarning,
stacklevel=2,
)
in_address = f"localhost:{in_port}"
# Handle backward compatibility for out_port -> out_address
if out_port is not None and out_address is not None:
raise ValueError("Cannot specify both 'out_port' and 'out_address'. Use 'out_address' only.")
if out_port is not None:
warnings.warn(
"The 'out_port' parameter is deprecated and will be removed in a future release. "
"Use 'out_address' instead.",
DeprecationWarning,
stacklevel=2,
)
out_address = f"localhost:{out_port}"
# Delete deprecated parameter names
del in_port, out_port
self.closed = False
try:
context = zmq.Context()
frontend, self.in_port = self.configure_server_socket(
context, zmq.SUB, in_address, in_curve, bind=in_bind
)
frontend.setsockopt_string(zmq.SUBSCRIBE, "")
backend, self.out_port = self.configure_server_socket(
context, zmq.PUB, out_address, out_curve, bind=out_bind
)
except BaseException:
# Clean up whichever components we have defined so far.
try:
frontend.close()
except NameError:
...
try:
backend.close()
except NameError:
...
try:
context.destroy()
except NameError:
...
raise
else:
self._frontend = frontend
self._backend = backend
self._context = context
def start(self):
if self.closed:
raise RuntimeError(
f"This Proxy has already been started and interrupted. Create a fresh instance with {repr(self)}"
)
try:
zmq.device(zmq.FORWARDER, self._frontend, self._backend)
finally:
self.closed = True
self._frontend.close()
self._backend.close()
self._context.destroy()
def __repr__(self):
return "{}(in_port={in_port}, out_port={out_port})".format(type(self).__name__, **vars(self))
[docs]
class RemoteDispatcher(Dispatcher):
"""
Dispatch documents received over the network from a 0MQ proxy.
Parameters
----------
address : tuple
Address of a running 0MQ proxy, given either as a string like
``'127.0.0.1:5567'`` or as a tuple like ``('127.0.0.1', 5567)``
prefix : bytes, optional
User-defined bytestring used to distinguish between multiple
Publishers. If set, messages without this prefix will be ignored.
If unset, no mesages will be ignored.
loop : zmq.asyncio.ZMQEventLoop, optional
optional event loop to use. Default is to create a new event loop.
deserializer: function, optional
optional function to deserialize data. Default is pickle.loads
Examples
--------
Print all documents generated by remote RunEngines.
>>> d = RemoteDispatcher(('localhost', 5568))
>>> d.subscribe(print)
>>> d.start() # runs until interrupted
"""
def __init__(
self,
address: str | tuple[str, int],
*,
prefix: bytes = b"",
loop: asyncio.AbstractEventLoop | None = None,
deserializer: Callable = pickle.loads,
strict: bool = False,
curve_config: ServerCurve | ClientCurve | None = None,
):
if isinstance(prefix, str):
raise ValueError("prefix must be bytes, not string")
if b" " in prefix:
raise ValueError(f"prefix {prefix!r} may not contain b' '")
self._prefix = prefix
self._deserializer = deserializer
self.address = _normalize_address(address)
if loop is None:
loop = asyncio.new_event_loop()
self.loop = loop
self._context = None
self._socket = None
def __finish_setup():
asyncio.set_event_loop(self.loop)
self._context = zmq_asyncio.Context()
self._socket = sock = self._context.socket(zmq.SUB)
if curve_config is not None:
# Load the client cert pair
client_public, client_secret = zmq.auth.load_certificate(curve_config.secret_path)
sock.setsockopt(zmq.CURVE_PUBLICKEY, client_public)
if client_secret is None:
raise ValueError("The client secret key could not be found.")
sock.setsockopt(zmq.CURVE_SECRETKEY, client_secret)
# Load the server public key and register with the socket
server_key, _ = zmq.auth.load_certificate(curve_config.server_public_key)
sock.setsockopt(zmq.CURVE_SERVERKEY, server_key)
self._socket.connect(self.address)
self._socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.__factory = __finish_setup
self._task = None
self.closed = False
self._strict = strict
super().__init__()
async def _poll(self):
our_prefix = self._prefix # local var to save an attribute lookup
while True:
message = await self._socket.recv()
try:
prefix, name, doc = message.split(b" ", 2)
except ValueError as e:
if self._strict:
raise Bluesky0MQDecodeError from e
else:
print(
f"The message {message} could not be split into "
"three parts by b' '. Dropping message on floor "
"and continuing"
f"\n\n{e}"
)
continue
try:
name = name.decode()
except UnicodeDecodeError as e:
if self._strict:
raise Bluesky0MQDecodeError from e
else:
print(
f"The name {name} can not be decoded as utf-8. "
"Dropping message on the floor and continuing. "
f"\n\n{e}"
)
continue
if (not our_prefix) or prefix == our_prefix:
try:
doc = self._deserializer(doc)
except Exception as e:
if self._strict:
raise Bluesky0MQDecodeError from e
else:
if len(doc) > 1024:
msg_doc = doc[:1024] + b"--SNIPPED--"
else:
msg_doc = doc
print(
f"Failed to deserialize the {name} document "
f"{msg_doc} using {self._deserializer}. "
"Dropping on floor and continuing"
f"\n\n{e}"
)
continue
self.loop.call_soon(self.process, DocumentNames[name], doc)
def start(self):
if self.closed:
raise RuntimeError(
"This RemoteDispatcher has already been "
"started and interrupted. Create a fresh "
f"instance with {self!r}"
)
try:
self.__factory()
self._task = self.loop.create_task(self._poll())
self.loop.run_until_complete(self._task)
task_exception = self._task.exception()
if task_exception is not None:
raise task_exception
finally:
self.stop()
def stop(self):
if self._task is not None:
self._task.cancel()
if self._socket is not None:
self._socket.close()
if self._context is not None:
self._context.destroy()
self.loop.close()
self.closed = True