import copy
import pprint
import warnings
from collections import namedtuple
import numpy as np
from .core import CallbackBase, CollectThenCompute
[docs]
class LiveFit(CallbackBase):
"""
Fit a model to data using nonlinear least-squares minimization.
Parameters
----------
model : lmfit.Model
y : string
name of the field in the Event document that is the dependent variable
independent_vars : dict
map the independent variable name(s) in the model to the field(s)
in the Event document; e.g., ``{'x': 'motor'}``
init_guess : dict, optional
initial guesses for other values, if expected by model;
e.g., ``{'sigma': 1}``
update_every : int or None, optional
How often to recompute the fit. If `None`, do not compute until the
end. Default is 1 (recompute after each new point).
Attributes
----------
result : lmfit.ModelResult
"""
def __init__(self, model, y, independent_vars, init_guess=None, *, update_every=1):
self.ydata = []
self.independent_vars_data = {}
self.__stale = False
self.result = None
self._model = model
self.y = y
self.independent_vars = independent_vars
if init_guess is None:
init_guess = {}
self.init_guess = init_guess
self.update_every = update_every
@property
def model(self):
# Make this a property so it can't be updated.
return self._model
@property
def independent_vars(self):
return self._independent_vars
@independent_vars.setter
def independent_vars(self, val):
if set(val) != set(self.model.independent_vars):
raise ValueError(
"keys {} must match the independent variables in the model {}".format( # noqa: UP032
set(val), set(self.model.independent_vars)
)
)
self._independent_vars = val
self.independent_vars_data.clear()
self.independent_vars_data.update({k: [] for k in val})
self._reset()
def _reset(self):
self.result = None
self.__stale = False
self.ydata.clear()
for v in self.independent_vars_data.values():
v.clear()
def start(self, doc):
self._reset()
super().start(doc)
def event(self, doc):
if self.y not in doc["data"]:
return
y = doc["data"][self.y]
idv = {k: doc["data"][v] for k, v in self.independent_vars.items()}
# Always stash the data for the next time the fit is updated.
self.update_caches(y, idv)
self.__stale = True
# Maybe update the fit or maybe wait.
if self.update_every is not None:
i = len(self.ydata)
N = len(self.model.param_names)
if i < N:
# not enough points to fit yet
pass
elif (i == N) or ((i - 1) % self.update_every == 0):
self.update_fit()
super().event(doc)
def stop(self, doc):
# Update the fit if it was not updated by the last event.
if self.__stale:
self.update_fit()
super().stop(doc)
def update_caches(self, y, independent_vars):
self.ydata.append(y)
for k, v in self.independent_vars_data.items():
v.append(independent_vars[k])
def update_fit(self):
N = len(self.model.param_names)
if len(self.ydata) < N:
warnings.warn(
f"LiveFitPlot cannot update fit until there are at least {N} data points",
stacklevel=1,
)
else:
kwargs = {}
kwargs.update(self.independent_vars_data)
kwargs.update(self.init_guess)
self.result = self.model.fit(self.ydata, **kwargs)
self.__stale = False
# This function is vendored from scipy v0.16.1 to avoid adding a scipy
# dependency just for one Python function
def center_of_mass(input, labels=None, index=None):
"""
Calculate the center of mass of the values of an array at labels.
Parameters
----------
input : ndarray
Data from which to calculate center-of-mass.
labels : ndarray, optional
Labels for objects in `input`, as generated by `ndimage.label`.
Only used with `index`. Dimensions must be the same as `input`.
index : int or sequence of ints, optional
Labels for which to calculate centers-of-mass. If not specified,
all labels greater than zero are used. Only used with `labels`.
Returns
-------
center_of_mass : tuple, or list of tuples
Coordinates of centers-of-mass.
Examples
--------
>>> a = np.array(([0,0,0,0],
[0,1,1,0],
[0,1,1,0],
[0,1,1,0]))
>>> from scipy import ndimage
>>> ndimage.measurements.center_of_mass(a)
(2.0, 1.5)
Calculation of multiple objects in an image
>>> b = np.array(([0,1,1,0],
[0,1,0,0],
[0,0,0,0],
[0,0,1,1],
[0,0,1,1]))
>>> lbl = ndimage.label(b)[0]
>>> ndimage.measurements.center_of_mass(b, lbl, [1,2])
[(0.33333333333333331, 1.3333333333333333), (3.5, 2.5)]
"""
normalizer = np.sum(input, labels, index)
grids = np.ogrid[[slice(0, i) for i in input.shape]]
results = [np.sum(input * grids[dir].astype(float), labels, index) / normalizer for dir in range(input.ndim)]
if np.isscalar(results[0]):
return tuple(results)
return [tuple(v) for v in np.array(results).T]
[docs]
class PeakStats(CollectThenCompute):
"""
Compute peak statsitics after a run finishes.
Results are stored in the attributes.
Parameters
----------
x : string
field name for the x variable (e.g., a motor)
y : string
field name for the y variable (e.g., a detector)
calc_derivative_and_stats : bool, optional
calculate derivative of the readings and its stats. False by default.
edge_count : int or None, optional
If not None, number of points at beginning and end to use
for quick and dirty background subtraction.
Notes
-----
It is assumed that the two fields, x and y, are recorded in the same
Event stream.
Attributes
----------
com : center of mass
cen : mid-point between half-max points on each side of the peak
max : x location of y maximum
min : x location of y minimum
crossings : crosses between y and middle line, which is
((np.max(y) + np.min(y)) / 2). Users can estimate FWHM based
on those info.
fwhm : the computed full width half maximum (fwhm) of a peak.
The distance between the first and last crossing is taken to
be the fwhm.
"""
__slots__ = (
"x",
"y",
"x_data",
"y_data",
"stats",
"derivative_stats",
"min",
"max",
"com",
"cen",
"crossings",
"fwhm",
"lin_bkg",
)
def __init__(self, x, y, *, edge_count=None, calc_derivative_and_stats=False):
self.x = x
self.y = y
self._edge_count = edge_count
self._calc_derivative_and_stats = calc_derivative_and_stats
self.stats = None
self.derivative_stats = None
self._stats_fields = {
"min": None,
"max": None,
"com": None,
"cen": None,
"crossings": None,
"fwhm": None,
"lin_bkg": None,
}
for field, value in self._stats_fields.items():
setattr(self, field, value)
super().__init__()
def __getitem__(self, key):
if key in ["x", "y", "stats", "derivative_stats"] + list(self._stats_fields.keys()):
return getattr(self, key)
else:
raise KeyError
def __dict__(self):
return_dict = {}
if self.stats is not None:
return_dict["stats"] = self.stats._asdict()
if self.derivative_stats is not None:
return_dict["derivative_stats"] = self.derivative_stats._asdict()
return return_dict
def __repr__(self):
return pprint.pformat(self.__dict__())
@staticmethod
def _calc_stats(x, y, fields, edge_count=None):
y_orig = np.copy(y)
if edge_count is not None:
left_x = np.mean(x[:edge_count])
left_y = np.mean(y[:edge_count])
right_x = np.mean(x[-edge_count:])
right_y = np.mean(y[-edge_count:])
m = (right_y - left_y) / (right_x - left_x)
b = left_y - m * left_x
y = y - (m * x + b)
fields["lin_bkg"] = {"m": m, "b": b}
argmin_y = np.argmin(y)
argmax_y = np.argmax(y)
fields["min"] = (x[argmin_y], y_orig[argmin_y])
fields["max"] = (x[argmax_y], y_orig[argmax_y])
(fields["com"],) = np.interp(center_of_mass(y), np.arange(len(x)), x)
mid = (np.max(y) + np.min(y)) / 2
crossings = np.where(np.diff((y > mid).astype(int)))[0]
_cen_list = []
for cr in crossings.ravel():
_x = x[cr : cr + 2]
_y = y[cr : cr + 2] - mid
dx = np.diff(_x)[0]
dy = np.diff(_y)[0]
m = dy / dx
_cen_list.append((-_y[0] / m) + _x[0])
if _cen_list:
fields["cen"] = np.mean(_cen_list)
fields["crossings"] = np.array(_cen_list)
if len(_cen_list) >= 2:
fields["fwhm"] = np.abs(fields["crossings"][-1] - fields["crossings"][0], dtype=float)
Stats = namedtuple("Stats", field_names=fields.keys())
stats = Stats(**fields)
return stats
def compute(self):
"This method is called at run-stop time by the superclass."
# clear all results
for field, value in self._stats_fields.items():
setattr(self, field, value)
x = []
y = []
for event in self._events:
try:
_x = event["data"][self.x]
_y = event["data"][self.y]
except KeyError:
pass
else:
x.append(_x)
y.append(_y)
x = np.array(x)
y = np.array(y)
if not len(x):
# nothing to do
return
self.x_data = x
self.y_data = y
stats_fields = copy.deepcopy(self._stats_fields)
self.stats = self._calc_stats(x, y, stats_fields, edge_count=self._edge_count)
for field in self._stats_fields:
setattr(self, field, getattr(self.stats, field))
if self._calc_derivative_and_stats:
# Calculate the derivative stats of the data
x_der = x[1:]
y_der = np.diff(y)
stats_fields = copy.deepcopy(self._stats_fields)
stats_fields.update({"x": x_der, "y": y_der})
self.derivative_stats = self._calc_stats(x_der, y_der, stats_fields, edge_count=self._edge_count)
# reset y data
y = self.y_data