Source code for bluesky_adaptive.agents.botorch

"""
Basic BoTorch functionality, primarily for examples.

These mixins act to fulfill the abstract methods of bluesky_adaptive.agents.Agent that are relevant to
the decision making, and not the experimental specifics.

Children will need to implement the following:
Experiment specific:
    - measurement_plan_name
    - measurement_plan_args
    - measurement_plan_kwargs
    - unpack_run
"""

import importlib
from abc import ABC
from logging import getLogger
from typing import Callable, Optional, Tuple

import torch
from botorch import fit_gpytorch_mll
from botorch.acquisition import AcquisitionFunction, UpperConfidenceBound
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from databroker.client import BlueskyRun
from gpytorch.mlls import ExactMarginalLogLikelihood

from bluesky_adaptive.agents.base import Agent

logger = getLogger("bluesky_adaptive.agents")


class SingleTaskGPAgentBase(Agent, ABC):
    """Agent mixin that makes decisions with a BoTorch ``SingleTaskGP`` surrogate.

    The class accumulates observations passed to ``tell``, refits the GP on
    ``report``/``ask``, and optimizes an acquisition function within ``bounds``
    to propose the next point. Experiment-specific pieces are left to children.
    """

    # Prefix marking flattened acquisition-function state-dict entries in documents.
    _STATEDICT_PREFIX = "STATEDICT-"

    def __init__(
        self,
        *,
        bounds: torch.Tensor,
        gp: Optional[SingleTaskGP] = None,
        device: Optional[torch.device] = None,
        out_dim=1,
        partial_acq_function: Optional[Callable] = None,
        num_restarts: int = 10,
        raw_samples: int = 20,
        **kwargs,
    ):
        """Single Task GP based Bayesian Optimization

        Parameters
        ----------
        bounds : torch.Tensor
            A `2 x d` tensor of lower and upper bounds for each column of `X`
        gp : SingleTaskGP, optional
            GP surrogate model to use, by default uses BoTorch default
        device : torch.device, optional
            Device, by default cuda if avail
        out_dim : int, optional
            Dimension of output predictions by surrogate model, by default 1
        partial_acq_function : Optional[Callable], optional
            Partial acquisition function that will take a single argument of a conditioned surrogate model.
            By default UCB with beta at 0.1
        num_restarts : int, optional
            Number of restarts for optimizing the acquisition function, by default 10
        raw_samples : int, optional
            Number of samples used to instantiate the initial conditions of the acquisition function optimizer.
            For a discussion of num_restarts vs raw_samples, see:
            https://github.com/pytorch/botorch/issues/366
            Defaults to 20.
        """
        super().__init__(**kwargs)
        self.inputs = None
        self.targets = None

        self.device = (
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if device is None
            else torch.device(device)
        )
        # Copy the bounds onto the device; reshape rather than view so that a
        # non-contiguous input tensor is also accepted.
        self.bounds = self._copy_to_device(bounds).reshape(2, -1)

        if gp is None:
            # SingleTaskGP needs training data at construction time; these random
            # placeholders are replaced by set_train_data() on the first tell().
            dummy_x = torch.randn(2, self.bounds.shape[-1], device=self.device)
            dummy_y = torch.randn(2, out_dim, device=self.device)
            gp = SingleTaskGP(dummy_x, dummy_y)

        self.surrogate_model = gp
        self.mll = ExactMarginalLogLikelihood(self.surrogate_model.likelihood, self.surrogate_model)
        self.surrogate_model.to(self.device)
        self.mll.to(self.device)

        if partial_acq_function is None:
            self._partial_acqf = lambda gp: UpperConfidenceBound(gp, beta=0.1)
            self.acqf_name = "UpperConfidenceBound"
        else:
            self._partial_acqf = partial_acq_function
            self.acqf_name = "custom"

        self.num_restarts = num_restarts
        self.raw_samples = raw_samples

    def _copy_to_device(self, values) -> torch.Tensor:
        """Return a copy of ``values`` as a tensor on ``self.device``.

        Existing tensors are copied with ``detach().clone()`` instead of being
        passed to ``torch.tensor``, which emits a copy-construct UserWarning.
        """
        if isinstance(values, torch.Tensor):
            return values.detach().clone().to(self.device)
        return torch.tensor(values, device=self.device)

    def _flat_state_dict(self, acqf) -> dict:
        """Flatten ``acqf.state_dict()`` into prefixed keys holding numpy arrays.

        Dots in the state-dict keys are replaced by colons and the result is
        prefixed with ``STATEDICT-``; ``remodel_from_report`` reverses this.
        """
        return {
            self._STATEDICT_PREFIX + key.replace(".", ":"): val.detach().cpu().numpy()
            for key, val in acqf.state_dict().items()
        }

    def server_registrations(self) -> None:
        """Register ``update_acquisition_function`` in addition to the parent's registrations."""
        super().server_registrations()
        self._register_method("update_acquisition_function")

    def update_acquisition_function(self, acqf_name, **kwargs):
        """Swap the acquisition function for one named in ``botorch.acquisition``.

        Parameters
        ----------
        acqf_name : str
            Attribute name looked up on the ``botorch.acquisition`` module.
        **kwargs
            Extra keyword arguments passed to the acquisition function after the model.

        Raises
        ------
        AttributeError
            If ``botorch.acquisition`` has no attribute ``acqf_name``. The lookup
            happens before any agent state is changed, so a bad name leaves the
            current acquisition function in place.
        """
        module = importlib.import_module("botorch.acquisition")
        # Resolve eagerly so an invalid name fails here, not at the next ask().
        acqf_cls = getattr(module, acqf_name)
        self.acqf_name = acqf_name
        self._partial_acqf = lambda gp: acqf_cls(gp, **kwargs)
        self.close_and_restart()

    def start(self, *args, **kwargs):
        """Record the acquisition function name in the metadata, then defer to the parent ``start``."""
        self.metadata.update(dict(acqf_name=self.acqf_name))
        super().start(*args, **kwargs)

    def tell(self, x, y):
        """Append one observation to the training data and hand it to the surrogate model.

        Parameters
        ----------
        x
            Independent variable(s); promoted to at least 2 dimensions.
        y
            Observable(s); promoted to at least 1 dimension.

        Returns
        -------
        dict
            The raw ``x`` and ``y`` plus the current number of stored targets.
        """
        # Both tensors are created on self.device, so no later transfer is needed.
        new_inputs = torch.atleast_2d(self._copy_to_device(x))
        new_targets = torch.atleast_1d(self._copy_to_device(y))
        if self.inputs is None:
            self.inputs = new_inputs
            self.targets = new_targets
        else:
            self.inputs = torch.cat([self.inputs, new_inputs], dim=0)
            self.targets = torch.cat([self.targets, new_targets], dim=0)
        self.surrogate_model.set_train_data(self.inputs, self.targets, strict=False)
        return dict(independent_variable=x, observable=y, cache_len=len(self.targets))

    def report(self):
        """Fit GP, and construct acquisition function.
        Document retains state dictionary.
        """
        fit_gpytorch_mll(self.mll)
        acqf = self._partial_acqf(self.surrogate_model)
        return dict(
            latest_data=self.tell_cache[-1],
            cache_len=self.inputs.shape[0],
            **self._flat_state_dict(acqf),
        )

    def ask(self, batch_size=1):
        """Fit GP, optimize acquisition function, and return next points.
        Document retains candidate, acquisition values, and state dictionary.

        Returns
        -------
        tuple
            A one-element list holding the document dict, and the candidate as a numpy array.
        """
        if batch_size > 1:
            logger.warning(f"Batch size greater than 1 is not implemented. Reducing {batch_size} to 1.")
            batch_size = 1

        fit_gpytorch_mll(self.mll)
        acqf = self._partial_acqf(self.surrogate_model)
        acqf.to(self.device)
        candidate, acq_value = optimize_acqf(
            acq_function=acqf,
            bounds=self.bounds,
            q=batch_size,
            num_restarts=self.num_restarts,
            raw_samples=self.raw_samples,
        )
        doc = dict(
            candidate=candidate.detach().cpu().numpy(),
            acquisition_value=acq_value.detach().cpu().numpy(),
            latest_data=self.tell_cache[-1],
            cache_len=self.inputs.shape[0],
            **self._flat_state_dict(acqf),
        )
        return [doc], torch.atleast_1d(candidate).detach().cpu().numpy()

    def remodel_from_report(
        self, run: BlueskyRun, idx: Optional[int] = None
    ) -> Tuple[AcquisitionFunction, SingleTaskGP]:
        """Rebuild the acquisition function and model from a stored report.

        Parameters
        ----------
        run : BlueskyRun
            Run whose ``report`` stream holds the ``STATEDICT-`` prefixed entries.
        idx : int, optional
            Index of the report event to load, by default the last one (-1).

        Returns
        -------
        Tuple[AcquisitionFunction, SingleTaskGP]
            The acquisition function with the loaded state, and its model.

        Notes
        -----
        The acquisition function is built around ``self.surrogate_model``, so loading
        the state dictionary also updates this agent's surrogate model.
        """
        idx = -1 if idx is None else idx
        data = run.report["data"]
        prefix = self._STATEDICT_PREFIX
        # Undo the flattening applied when the document was written: strip the
        # prefix and turn the colons back into dots.
        state_dict = {
            key[len(prefix):].replace(":", "."): torch.tensor(data[key][idx])
            for key in data.keys()
            if key.startswith(prefix)
        }
        acqf = self._partial_acqf(self.surrogate_model)
        acqf.load_state_dict(state_dict)
        acqf.to(self.device)
        return acqf, acqf.model