Skip to content

Commit 37bec28

Browse files
committed
Some tweaks to curve_fitting reporting
1 parent e0c79b4 commit 37bec28

3 files changed

Lines changed: 81 additions & 56 deletions

File tree

Stoner/analysis/fitting/classes.py

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ class _Curve_Fit_Result:
172172
as a class to make handling easier.
173173
"""
174174

175-
def __init__(self, popt=None, pcov=None, infodict=None, mesg=None, ier=None):
175+
def __init__(self, infodict=None, mesg=None, ier=None, results=None):
176176
"""Store the results of the curve fit full_output fit.
177177
178178
Args:
@@ -187,30 +187,21 @@ def __init__(self, popt=None, pcov=None, infodict=None, mesg=None, ier=None):
187187
ier (int):
188188
Numerical error message.
189189
"""
190-
self.popt = popt
191-
self.pcov = pcov
192-
if pcov is not None:
193-
self.perr = np.sqrt(np.diag(pcov))
194-
else:
195-
self.perr = None
196-
self.mesg = mesg
197-
self.ier = ier
198-
self.nfev = None
199-
self.fvec = None
200-
self.fjac = None
201-
self.ipvt = None
202-
self.qtf = None
203-
self.func = None
190+
self._mapping = {}
191+
self.func = lambda *args: None
192+
self.f_name = None
193+
self.args = []
194+
self.kwargs = {}
204195
self.p0 = None
205196
self._residual_vals = None
206-
self.chisq = None
207-
self.nfree = None
208197
self.f_name = None
209-
self._infodict = infodict
210-
if infodict:
211-
for k in infodict:
212-
setattr(self, k, infodict[k])
213-
self.xdata = None
198+
self._infodict = {}
199+
self._results = {}
200+
self.infodict = {k: None for k in ["mfev", "fvec", "fjac", "ipvt", "qtf"]} if infodict is None else infodict
201+
self.results = (
202+
{k: None for k in ["pop", "perr", "pcov", "mesg", "ier", "chisq", "nfree"]} if results is None else results
203+
)
204+
self.data = None
214205
self.labels = None
215206
self.units = None
216207

@@ -250,14 +241,28 @@ def row(self):
250241
@property
251242
def infodict(self):
252243
"""Wrapper infodict."""
253-
return getattr(self, "_infodict", None)
244+
return self._infodict
254245

255246
@infodict.setter
256247
def infodict(self, value):
257248
"""Wrapper for setting infodict and subkeys."""
249+
self._infodict = value
258250
if isinstance(value, dict):
259251
for k in value:
260-
setattr(self, k, value[k])
252+
self._mapping[k] = "_infodict"
253+
254+
@property
255+
def results(self):
256+
"""Wrapper for a results dictionary."""
257+
return self._results
258+
259+
@results.setter
260+
def results(self, value):
261+
"""Wrapper for setting results dictionary and nting keys."""
262+
self._results = value
263+
if isinstance(value, dict):
264+
for k in value:
265+
self._mapping[k] = "_results"
261266

262267
@property
263268
def fit(self):
@@ -267,22 +272,18 @@ def fit(self):
267272
@property
268273
def fit_values(self):
269274
"""Return the fit values if x data is set."""
270-
if self.xdata is None or self.func is None or self.popt is None:
275+
if self.data is None or self.func is None or self.popt is None:
271276
raise ValueError(
272277
"Need to have some x-data, the fitting functions and optimal parameters before calculating fit"
273278
)
274-
return self.func(self.xdata, *self.popt)
279+
return self.func(self.data.data[:, self.settings.columns.xcol], *self.popt)
275280

276281
@property
277-
def data(self):
278-
"""Return the data that was fitted."""
279-
self._data = getattr(self, "_data", np.array([]))
280-
return self._data
281-
282-
@data.setter
283-
def data(self, data):
284-
"""Return the data that was fitted."""
285-
self._data = data
282+
def perr(self):
283+
"""Return standar error from covariance matrix."""
284+
if "perr" in self.results and self.results["perr"] is not None:
285+
return self.results["perr"]
286+
return np.sqrt(np.diag(self.pcov))
286287

287288
@property
288289
def report(self):
@@ -336,6 +337,27 @@ def params(self):
336337
"""List the parameter class objects."""
337338
return get_func_params(self.func)
338339

340+
def __dir__(self):
341+
"""Extend the attribute directory."""
342+
return super().__dir__() + list(self._mapping.keys())
343+
344+
def __getattr__(self, name):
345+
"""Defer to using things we're tracking in the mapping."""
346+
try:
347+
return super().__getattr__(name)
348+
except AttributeError:
349+
pass
350+
if name in self._mapping:
351+
return getattr(self, self._mapping[name]).get(name, None)
352+
raise AttributeError(f"{name} is not an attribute or {self.__class__.__name__}.")
353+
354+
def __setattr__(self, name, value):
355+
"""Pass through if an atrbute is already set."""
356+
if name != "_mapping" and name in self._mapping:
357+
getattr(self, self._mapping[name])[name] = value
358+
return
359+
super().__setattr__(name, value)
360+
339361
def fit_report(self):
340362
"""Create a Fit report like lmfit does."""
341363
template = f"""[[ Model ]]

Stoner/analysis/fitting/functions.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -241,12 +241,12 @@ def _func(x, *beta):
241241
return result_obj
242242

243243

244-
def _normalise_fit_result(datafile, xcol, ycol, fit, result_obj):
244+
def _normalise_fit_result(datafile, settings, fit, result_obj):
245245
"""Normalise the fit results based on the fit instance."""
246246
func = result_obj.func
247247
args = result_obj.args
248-
result_obj.data = datafile.data[:, ycol]
249-
result_obj.xdata = datafile.data[:, xcol]
248+
result_obj.data = datafile
249+
result_obj.settings = settings
250250

251251
match fit:
252252
case _Curve_Fit_Result():
@@ -273,23 +273,19 @@ def _normalise_fit_result(datafile, xcol, ycol, fit, result_obj):
273273
perr = fit.perr
274274
nfev = fit.nfev
275275
nfree = len(datafile) - len(popt)
276-
fit_data = func(result_obj.xdata, *popt)
277-
chisq = np.sum((result_obj.data - fit_data) ** 2) / nfree
276+
fit_data = func(datafile // settings.columns.xcol, *popt)
277+
chisq = np.sum((datafile.data[:, settings.columns.ycol] - fit_data) ** 2) / nfree
278278
case _:
279279
raise RuntimeError("Unable to understand {type(fit)} as a fitting result")
280-
result_obj.popt = popt
281-
result_obj.perr = perr
282-
result_obj.nfev = nfev
283-
result_obj.chisq = chisq
284-
result_obj.nfree = nfree
280+
result_obj.results = {"popt": popt, "perr": perr, "nfev": nfev, "chisq": chisq, "nfree": nfree}
285281
return result_obj
286282

287283

288284
def _record_curve_fit_result(datafile, func, fit, settings):
289285
"""Annotate the DataFile object with the curve_fit result."""
290286
result_obj = _Curve_Fit_Result()
291287
result_obj = _normalise_model_func(func, settings.prefix, result_obj)
292-
result_obj = _normalise_fit_result(datafile, settings.columns.xcol, settings.columns.ycol, fit, result_obj)
288+
result_obj = _normalise_fit_result(datafile, settings, fit, result_obj)
293289

294290
result_obj.add_metadata(datafile)
295291

@@ -721,14 +717,15 @@ def p0_func(ydata,x=xdata):
721717
sigma = None
722718
for var in ["xcol", "ycol", "zcol", "xerr", "yerr", "zerr", "scale_covar"]:
723719
kwargs.pop(var, None)
724-
report = _Curve_Fit_Result(*_curve_fit(_func, xdat, ydat, **kwargs))
720+
popt, pcov, infodict, msg, ier = _curve_fit(_func, xdat, ydat, **kwargs)
721+
report = _Curve_Fit_Result()
722+
report.settings = settings
725723
report.func = func
726-
if p0 is None:
727-
report.p0 = np.ones(len(report.popt))
728-
else:
729-
report.p0 = p0
724+
report.results = {"popt": popt, "pcov": pcov, "mesg": msg, "ier": ier}
725+
report.infodict = infodict
726+
report.p0 = np.ones(len(report.popt)) if p0 is None else p0
730727
report.data = datafile
731-
report.residual_vals = data.y - report.fvec
728+
report.residual_vals = ydat - report.fvec
732729
report.chisq = (report.residual_vals**2).sum()
733730
report.nfree = len(datafile) - len(report.popt)
734731
report.chisq /= report.nfree
@@ -860,11 +857,15 @@ def differential_evolution(datafile, model, xcol=None, ycol=None, p0=None, sigma
860857
kwargs.pop("polish", None)
861858
kwargs["full_output"] = True
862859
kwargs["absolute_sigma"] = abs_sigma
863-
polish = _Curve_Fit_Result(*_curve_fit(model.func, data.x, data.y[0], sigma=data.e[0], p0=fit.x, **kwargs))
860+
popt, pcov, infodict, mesg, ier = _curve_fit(model.func, data.x, data.y[0], sigma=data.e[0], p0=fit.x, **kwargs)
861+
polish = _Curve_Fit_Result()
862+
polish.results = {"popt": popt, "pcov": pcov, "mesg": mesg, "ier": ier}
863+
polish.infodict = infodict
864+
polish.data = datafile
865+
polish.settings = settings
864866

865867
polish.func = model.func
866868
polish.p0 = p0
867-
polish.data = datafile
868869
polish.residual_vals = data.y - polish.fvec
869870
polish.chisq = (polish.residual_vals**2).sum()
870871
polish.nfree = len(datafile) - len(polish.popt)

scripts/PCAR-chi^2.ini

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ rescale_v: False
2121
# Warning, this may go badly wrong with certain DC waveforms
2222
remove_offset: False
2323
# Be clever about annotating the result on the plots
24+
# Decomposes the data into symmetric and antisymmetric parts before fitting
25+
decompose: True
2426
fancy_result: True
2527
#
2628
# Can switch between a least-squares fitting algorithm based on the lmfit module, or othogonal distance regression "odr"
@@ -76,9 +78,9 @@ units: meV
7678
value: 0.5
7779
vary: False
7880
min: 0.2
79-
max: 0.7
81+
max: 0.54
8082
label: P
81-
step: 0.05
83+
step: 0.025
8284

8385
[Z]
8486
value: 0.17

0 commit comments

Comments
 (0)