Source code for bluesky_adaptive.agents.tsuchinoko
import pickle
import time
import warnings
from abc import ABC
from collections.abc import Sequence
from logging import getLogger
import numpy as np
import zmq
from numpy._typing import ArrayLike
from .base import Agent
# Module-level logger; every log call in this module goes through it.
logger = getLogger("bluesky_adaptive.agents")
# NOTE(review): not referenced anywhere in the code visible here — confirm whether it is still used.
SLEEP_FOR_AGENT_TIME = 0.1
# Seconds `TsuchinokoBase.suggest` sleeps between non-blocking polls while no targets have arrived yet.
SLEEP_FOR_TSUCHINOKO_TIME = 0.1
# Seconds without received targets after which `TsuchinokoBase.suggest` re-sends the kickstart request.
FORCE_KICKSTART_TIME = 5
class TsuchinokoBase:
    """
    Client side of the zmq link to a Tsuchinoko server.

    Measurements are sent with :meth:`ingest` and new targets are received with
    :meth:`suggest`. Every message is a pickled ``dict`` exchanged over a zmq
    ``PAIR`` socket. The constructor forwards its remaining arguments to
    ``super().__init__``, so this class is meant to be combined with another base
    such as `bluesky_adaptive.agents.base.Agent`.
    """

    def __init__(self, *args, host: str = "127.0.0.1", port: int = 5557, **kwargs):
        """
        Parameters
        ----------
        args
            args passed through to `bluesky_adaptive.agents.base.Agent.__init__()`
        host
            A host address target for the zmq socket.
        port
            The port used for the zmq socket.
        kwargs
            kwargs passed through to `bluesky_adaptive.agents.base.Agent.__init__()`
        """
        super().__init__(*args, **kwargs)
        self.host = host
        self.port = port
        self.outbound_measurements = []
        self.context = None
        self.socket = None
        self.setup_socket()
        self.last_targets_received = time.time()
        # Request targets right away so a server that is already running
        # (e.g. after this agent was shut down and restarted) starts sending them.
        self.kickstart()

    def kickstart(self):
        """
        Ask the server to (re)send targets and reset the no-response timer.
        """
        self.send_payload({"send_targets": True})  # kickstart to recover from shutdowns
        self.last_targets_received = time.time()  # forgive lack of response until now

    def setup_socket(self):
        """
        Create a zmq context and PAIR socket, then connect to ``tcp://{host}:{port}``.

        If ``connect`` raises ``zmq.ZMQError``, the attempt is retried every second
        indefinitely. zmq connects asynchronously, so a successful ``connect`` does
        not by itself mean that a server is listening yet.
        """
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PAIR)
        address = f"tcp://{self.host}:{self.port}"
        # Attempt to connect, retry every second if it fails
        while True:
            try:
                self.socket.connect(address)
            except zmq.ZMQError:
                logger.info("Unable to connect to %s. Retrying in 1 second...", address)
                time.sleep(1)
            else:
                logger.info("Connected to %s.", address)
                break

    def ingest(self, x, yv):
        """
        Send measurement to BlueskyAdaptiveEngine

        Parameters
        ----------
        x
            Independent variable of the measurement.
        yv
            Observed value paired with its variance.
        """
        self.send_payload({"target_measured": (x, yv)})

    def suggest(self, batch_size: int = 1) -> dict:
        """
        Wait until at least one target payload is received, then drain any further
        queued payloads so that only the most recent one is kept.

        While nothing has arrived, the socket is polled every
        ``SLEEP_FOR_TSUCHINOKO_TIME`` seconds. If no targets arrive within
        ``FORCE_KICKSTART_TIME`` seconds, the server is kickstarted again.

        Parameters
        ----------
        batch_size
            Currently unused by this implementation.

        Returns
        -------
        dict
            The most recent payload received. It always contains a ``"candidate"`` key.

        Raises
        ------
        ValueError
            If the most recent payload has no ``"candidate"`` key.
        """
        payload = None
        while True:
            try:
                payload = self.recv_payload(flags=zmq.NOBLOCK)
            except zmq.Again:
                # Only "no message queued" is expected here; other zmq errors
                # (e.g. a closed socket or terminated context) propagate.
                if payload is not None:
                    break  # queue drained after at least one payload was received
                time.sleep(SLEEP_FOR_TSUCHINOKO_TIME)
                if time.time() > self.last_targets_received + FORCE_KICKSTART_TIME:
                    self.kickstart()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if "candidate" not in payload:
            raise ValueError(
                f"Expected a 'candidate' key in the payload from Tsuchinoko, got keys: {list(payload)}"
            )
        self.last_targets_received = time.time()
        return payload

    def send_payload(self, payload: dict):
        """
        Pickle ``payload`` and send it over the socket, logging it at INFO level.
        """
        # %-style args defer formatting until the logger actually emits the record.
        logger.info("message: %s", payload)
        self.socket.send(pickle.dumps(payload))

    def recv_payload(self, flags=0) -> dict:
        """
        Receive one message from the socket and unpickle it.

        Parameters
        ----------
        flags
            zmq receive flags. For example, ``zmq.NOBLOCK`` makes the call raise
            ``zmq.Again`` instead of waiting when no message is queued.

        Returns
        -------
        dict
            The unpickled payload.
        """
        # SECURITY: pickle.loads can execute arbitrary code embedded in the message;
        # only connect this agent to a trusted Tsuchinoko server.
        payload_response = pickle.loads(self.socket.recv(flags=flags))
        # Lazy %-formatting: received payloads may carry the serialized optimizer state.
        logger.info("response: %s", payload_response)
        return payload_response
[docs]
class TsuchinokoAgent(TsuchinokoBase, Agent, ABC):
    """
    A Bluesky-Adaptive 'Agent'. This Agent communicates with Tsuchinoko over zmq
    to request new targets and report back measurements. This is an abstract
    class that must be subclassed.
    A `tsuchinoko.execution.bluesky_adaptive.BlueskyAdaptiveEngine` is required
    for the Tsuchinoko server to complement one of these `TsuchinokoAgent`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Length of the targets payload seen most recently by `get_suggest_documents`.
        # A change in this length triggers a new databroker run. None means "not seen yet".
        self._targets_shape = None

    def ingest(self, x, yv) -> dict[str, ArrayLike]:
        """
        Send the measurement to Tsuchinoko and return its ingest document.
        """
        super().ingest(x, yv)
        return self.get_ingest_document(x, yv)

    def suggest(self, batch_size: int = 1) -> tuple[Sequence[dict[str, ArrayLike]], dict]:
        """
        Wait for targets from Tsuchinoko and build the matching suggest documents.

        Returns
        -------
        tuple
            ``(docs, targets)``. ``targets`` is the received payload with its
            ``"optimizer"`` entry removed. ``docs`` is produced by
            `get_suggest_documents`.
        """
        targets = super().suggest(batch_size)
        # Split the optimizer state off the payload; it is only merged into the documents.
        optimizer_state = targets.pop("optimizer")
        return self.get_suggest_documents(targets, optimizer_state), targets

    def get_ingest_document(self, x, yv) -> dict[str, ArrayLike]:
        """
        Return any single document corresponding to 'tell'-ing Tsuchinoko about the newly measured `x`, `y` data
        Parameters
        ----------
        x :
            Independent variable for data observed
        yv :
            Dependent variable for data observed, concatenated with variance
        Returns
        -------
        dict
            Dictionary to be unpacked or added to a document
        """
        y, v = yv
        return {"independent": np.asarray(x), "observable": np.asarray(y), "variance": np.asarray(v)}

    def get_suggest_documents(self, targets: dict, optimizer_state: dict) -> Sequence[dict[str, ArrayLike]]:
        """
        Build the suggest documents for one batch of targets received from Tsuchinoko.

        The first call records ``len(targets)``. If a later call sees a different
        length, a warning is issued and ``close_and_restart()`` is called so that a
        new databroker run is generated. The new length then becomes the reference
        for later calls.

        Parameters
        ----------
        targets : dict
            The targets payload received during this `suggest`. `suggest` removes
            the ``"optimizer"`` entry before passing it here.
        optimizer_state : dict
            The serialized state of a GPOptimizer instance
        Returns
        -------
        docs : list[dict]
            A single document holding ``targets`` merged with ``optimizer_state``.
            ``optimizer_state`` wins on duplicate keys.
        """
        # NOTE(review): `targets` is a mapping, so len() counts its keys. If the intent
        # is the number of candidate positions, this should probably be
        # len(targets["candidate"]) — confirm against the Tsuchinoko payload format.
        n_targets = len(targets)
        if self._targets_shape is None:
            # First payload: remember its length as the baseline.
            self._targets_shape = n_targets
        elif self._targets_shape != n_targets:
            warnings.warn(
                "The length of the target queue has changed. A new databroker run will be generated", stacklevel=2
            )
            self.close_and_restart()
            # Adopt the new length so later calls with the same length do not restart again.
            self._targets_shape = n_targets
        return [targets | optimizer_state]