Coverage for larch/fitting/__init__.py: 52%
322 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-10-16 21:04 +0000
« prev ^ index » next coverage.py v7.6.0, created at 2024-10-16 21:04 +0000
1#!/usr/bin/env python
3from copy import copy, deepcopy
4import random
5import numpy as np
6from scipy.stats import f
8import lmfit
9from lmfit import Parameter
10from lmfit import (Parameters, Minimizer, conf_interval,
11 ci_report, conf_interval2d)
13from lmfit.minimizer import MinimizerResult
14from lmfit.model import (ModelResult, save_model, load_model,
15 save_modelresult, load_modelresult)
16from lmfit.confidence import f_compare
18from lmfit.printfuncs import gformat, getfloat_attr
19from uncertainties import ufloat, correlated_values
20from uncertainties import wrap as un_wrap
22from ..symboltable import Group, isgroup
def isParameter(x):
    """Return True if `x` is an lmfit Parameter.

    Besides a real ``isinstance`` check, any object whose class is named
    ``'Parameter'`` is also accepted, so Parameter-like classes from other
    modules are treated the same way.
    """
    if isinstance(x, Parameter):
        return True
    return x.__class__.__name__ == 'Parameter'
def param_value(val):
    "get param value -- useful for 3rd party code"
    # Unwrap nested Parameters until a non-Parameter value is reached.
    result = val
    while isinstance(result, Parameter):
        result = result.value
    return result
def format_param(par, length=10, with_initial=True):
    """Return a short text description of a Parameter for tables.

    Parameters
    ----------
    par : Parameter or any object
        Non-Parameter objects are returned as ``repr(par)``.
    length : int, optional
        Field length passed to ``gformat`` for numeric values (default 10).
    with_initial : bool, optional
        If True, a varied, unconstrained parameter also shows its initial
        value as ``(init=...)`` (default True).

    Returns
    -------
    str
        ``"<value> (fixed)"`` for a fixed parameter without an expression.
        Otherwise ``"<value> +/-<stderr>"`` (stderr shown as ``'unknown'``
        when missing), optionally followed by ``(init=...)`` or by
        ``= '<expr>'`` for a constrained parameter.
    """
    text = repr(par)
    if not isParameter(par):
        return text

    text = gformat(par.value, length=length)
    if not par.vary and par.expr is None:
        return f"{text} (fixed)"

    if par.stderr is None:
        err = 'unknown'
    else:
        err = gformat(par.stderr, length=length)
    text = f"{text} +/-{err}"

    if par.vary and par.expr is None and with_initial:
        text = f"{text} (init={gformat(par.init_value, length=length)})"
    if par.expr is not None:
        text = f"{text} = '{par.expr}'"
    return text
def stats_table(results, labels=None, csv_output=False, csv_delim=','):
    """
    Build a table comparing fit statistics across several fit results.

    Parameters
    ----------
    results : sequence
        Fit results; each statistic is read with ``getfloat_attr``.
    labels : sequence of str, optional
        Column headings, one per result.  Defaults to " Fit 1", " Fit 2", ...
    csv_output, csv_delim :
        Passed through to ``format_table_columns``.

    Returns
    -------
    str
        The formatted table.

    Raises
    ------
    ValueError
        If ``labels`` is given and its length differs from ``results``.
    """
    # (row title, result attribute) pairs, in display order
    stat_attrs = (('number of variables', 'nvarys'),
                  ('chi-square', 'chi_square'),
                  ('reduced chi-square', 'chi2_reduced'),
                  ('r-factor', 'rfactor'),
                  ('Akaike Info Crit', 'aic'),
                  ('Bayesian Info Crit', 'bic'))

    nfits = len(results)
    if labels is None:
        labels = [f" Fit {i+1}" for i in range(nfits)]
    elif len(labels) != nfits:
        raise ValueError('labels must be a list that is the same length as results')

    heading = ['Statistics']
    fit_columns = [[lab] for lab in labels]
    for title, attr in stat_attrs:
        heading.append(title)
        for column, result in zip(fit_columns, results):
            column.append(getfloat_attr(result, attr))

    return format_table_columns([heading] + fit_columns,
                                csv_output=csv_output, csv_delim=csv_delim)
def paramgroups_table(pgroups, labels=None, csv_output=False, csv_delim=','):
    """
    Build a table comparing the parameters of several Parameter Groups
    (for example, one group per fit result).

    Parameters
    ----------
    pgroups : sequence
        Parameter groups; their attribute names are taken from ``dir()``.
    labels : sequence of str, optional
        Column headings, one per group.  Defaults to " Fit 1", " Fit 2", ...
    csv_output, csv_delim :
        Passed through to ``format_table_columns``.

    Returns
    -------
    str
        The formatted table.  A name missing from a group (or set to None)
        is shown as 'N/A'.

    Raises
    ------
    ValueError
        If ``labels`` is given and its length differs from ``pgroups``.
    """
    nfits = len(pgroups)
    if labels is None:
        labels = [f" Fit {i+1}" for i in range(nfits)]
    elif len(labels) != nfits:
        raise ValueError('labels must be a list that is the same length as Parameter Groups')

    # union of attribute names over all groups, keeping first-seen order
    parnames = list(dict.fromkeys(pname for pgroup in pgroups
                                  for pname in dir(pgroup)))

    columns = [['Parameter']] + [[lab] for lab in labels]
    for pname in parnames:
        columns[0].append(pname)
        for column, pgroup in zip(columns[1:], pgroups):
            par = getattr(pgroup, pname, None)
            if par is None:
                column.append('N/A')
            else:
                column.append(format_param(par, length=10, with_initial=False))

    return format_table_columns(columns, csv_output=csv_output, csv_delim=csv_delim)
def format_table_columns(columns, csv_output=False, csv_delim=','):
    """Format a list of columns as a text table (or delimiter-separated text).

    Parameters
    ----------
    columns : list of lists of str
        One list per column.  Element 0 of each column is its heading and
        the remaining elements are the row values.  Every column must be at
        least as long as the first one.  The input lists are NOT modified.
    csv_output : bool, optional
        If True, separate cells with `csv_delim` and omit the ASCII-art
        borders (default False).
    csv_delim : str, optional
        Cell separator used when `csv_output` is True (default ',').

    Returns
    -------
    str
        The table, one line per row, joined with newlines.  Each cell is
        padded to its column width, which is at least 5 characters.  In
        table mode a border line is drawn above the heading, below the
        heading and after the last row.

    Raises
    ------
    IndexError
        If a column is shorter than the first column.
    """
    hjoin, rjoin, edge = '+', '|', '|'
    if csv_output:
        hjoin, rjoin, edge = csv_delim, csv_delim, ''

    widths = [max(5, max(len(cell) for cell in col)) for col in columns]
    nrows = len(columns[0])

    buff = []
    header = None
    if not csv_output:
        header = edge + hjoin.join('-'*(width+2) for width in widths) + edge
        buff.append(header)

    # Index into the columns rather than popping from them: the previous
    # version emptied the caller's lists and was quadratic in the row count.
    for irow in range(nrows):
        cells = [col[irow] for col in columns]
        row = rjoin.join(f" {cell:{width}.{width}s} "
                         for cell, width in zip(cells, widths))
        buff.append(edge + row + edge)
        if header is not None and irow == 0:
            # separate the heading row from the data rows
            buff.append(header)
    if header is not None:
        buff.append(header)
    return '\n'.join(buff)
def f_test(ndata, nvars, chisquare, chisquare0, nfix=1):
    """return the F-test value for the following input values:
       f = f_test(ndata, nvars, chisquare, chisquare0, nfix=1)

    Parameters
    ----------
    ndata : int
        number of data points.
    nvars : int
        number of variables in the fit that produced `chisquare`.
    chisquare : float
        chi-square being tested against the reference `chisquare0`.
    chisquare0 : float
        reference chi-square.
    nfix : int, optional
        the number of fixed parameters (default 1).  Must be positive.

    Returns
    -------
    float
        F-distribution CDF of ``(chisquare/chisquare0 - 1) * nfree / nfix``
        with ``(nfix, nfree)`` degrees of freedom, where
        ``nfree = ndata - nvars - nfix``.

    Raises
    ------
    ValueError
        If `nfix` is not positive.
    """
    # Computed directly with scipy.stats.f rather than via
    # lmfit.confidence.f_compare: recent lmfit changed that function's
    # signature to f_compare(best_fit, new_fit), and the old call also
    # always passed nfix=1, ignoring the caller's value.
    if nfix <= 0:
        raise ValueError("nfix must be a positive number of fixed parameters")
    nfix = 1.0 * nfix
    nfree = ndata - (nvars + nfix)
    dchi = chisquare / chisquare0 - 1.0
    return f.cdf(dchi * nfree / nfix, nfix, nfree)
def confidence_report(conf_vals, **kws):
    """Return a formatted report of confidence intervals, as calculated
    by confidence_intervals.

    This is a thin wrapper around lmfit's ``ci_report``.  Extra keyword
    arguments are accepted and ignored.
    """
    report = ci_report(conf_vals)
    return report
def asteval_with_uncertainties(*vals, **kwargs):
    """Calculate object value, given values for variables.

    This is used by the uncertainties package to calculate the
    uncertainty in an object even with a complicated expression.

    Parameters
    ----------
    *vals
        Values assigned, in order, to the symbols listed in ``_names``.
    _obj : keyword
        Object whose ``_expr_ast`` is evaluated.
    _pars : keyword
        Parameters container; must provide ``_asteval`` (with ``symtable``
        and ``eval``) and ``values()`` whose items have ``_getval()``.
    _names : keyword
        Symbol names matching ``vals``.

    Returns
    -------
    The result of evaluating ``_obj._expr_ast``, or 0 when any of the
    required inputs (or ``_obj._expr_ast``) is missing.
    """
    _obj = kwargs.get('_obj', None)
    _pars = kwargs.get('_pars', None)
    _names = kwargs.get('_names', None)
    # Use getattr so a missing _pars hits the `return 0` fallback below
    # instead of raising AttributeError on `None._asteval`.
    _asteval = getattr(_pars, '_asteval', None)
    if (_obj is None or _pars is None or _names is None or
            _asteval is None or getattr(_obj, '_expr_ast', None) is None):
        return 0
    for val, name in zip(vals, _names):
        _asteval.symtable[name] = val

    # re-evaluate all constraint parameters to
    # force the propagation of uncertainties
    for par in _pars.values():
        par._getval()
    return _asteval.eval(_obj._expr_ast)
# asteval_with_uncertainties wrapped with uncertainties.wrap (imported as
# un_wrap).  eval_stderr calls it with ufloat positional arguments plus the
# _obj/_names/_pars keywords, and reads `.std_dev` from the result.
wrap_ueval = un_wrap(asteval_with_uncertainties)
def eval_stderr(obj, uvars, _names, _pars):
    """Evaluate uncertainty and set ``.stderr`` for a parameter `obj`.

    Parameters
    ----------
    obj : Parameter
        Target parameter.  Anything that is not a Parameter, or has no
        parsed expression (``_expr_ast``), is left untouched.
    uvars : list of uncertainties.ufloat
        Uncertain values, in the same order as `_names`.
    _names : list of str
        Parameter names matching `uvars`.
    _pars : Parameters
        Parameter objects keyed by name.

    The expression in ``obj._expr_ast`` is evaluated through the
    uncertainties-wrapped evaluator.  If the result has no usable
    ``std_dev``, ``obj.stderr`` is set to 0.
    """
    if isinstance(obj, Parameter) and getattr(obj, '_expr_ast', None) is not None:
        uval = wrap_ueval(*uvars, _obj=obj, _names=_names, _pars=_pars)
        try:
            obj.stderr = uval.std_dev
        except Exception:
            obj.stderr = 0
class ParameterGroup(Group):
    """
    Group for Fitting Parameters

    Every Parameter assigned as an attribute is (re)created inside the
    group's own ``lmfit.Parameters`` instance, ``self.__params__``.  Any other
    attribute whose name does not start with '__' is copied into that
    instance's asteval symbol table, so constraint expressions can refer
    to it.
    """
    def __init__(self, name=None, **kws):
        if '_larch' in kws:
            kws.pop('_larch')
        # __params__ must exist before other attributes are set, since
        # __setattr__ routes Parameters into it.
        self.__params__ = Parameters()
        Group.__init__(self)
        # Assign the name only after Group.__init__ has run, so that a name
        # supplied by the caller cannot be replaced by anything that
        # Group.__init__ assigns.
        if name is not None:
            self.__name__ = name
        self.__exprsave__ = {}
        for key, val in kws.items():
            expr = getattr(val, 'expr', None)
            if expr is not None:
                # Defer constraint expressions until every keyword parameter
                # has been added, because an expression may reference a
                # parameter that appears later in `kws`.
                # NOTE(review): this clears `expr` on the caller's object.
                self.__exprsave__[key] = expr
                val.expr = None
            setattr(self, key, val)

        for key, val in self.__exprsave__.items():
            self.__params__[key].expr = val

    def __repr__(self):
        return '<Param Group {}>'.format(self.__name__)

    def __setattr__(self, name, val):
        if isParameter(val):
            if val.name != name:
                # allow 'a=Parameter(2, ..)' to mean Parameter(name='a', value=2, ...)
                nval = None
                try:
                    nval = float(val.name)
                except (ValueError, TypeError):
                    pass
                if nval is not None:
                    val.value = nval
            skip = getattr(val, 'skip', None)
            # register a new Parameter named `name` in __params__ and store
            # that object (not the caller's `val`) on the group
            self.__params__.add(name, value=val.value, vary=val.vary, min=val.min,
                                max=val.max, expr=val.expr, brute_step=val.brute_step)
            val = self.__params__[name]

            val.skip = skip
        elif hasattr(self, '__params__') and not name.startswith('__'):
            # make plain values visible to constraint expressions
            self.__params__._asteval.symtable[name] = val
        self.__dict__[name] = val

    def __delattr__(self, name):
        """Remove attribute `name`, and its Parameter if one is registered.

        Raises AttributeError (not KeyError) when the attribute is missing,
        as Python's attribute protocol expects.
        """
        try:
            self.__dict__.pop(name)
        except KeyError:
            raise AttributeError(name) from None
        if name in self.__params__:
            self.__params__.pop(name)

    def __add(self, name, value=None, vary=True, min=-np.inf, max=np.inf,
              expr=None, stderr=None, correl=None, brute_step=None, skip=None):
        # NOTE(review): the double-underscore name is mangled to
        # `_ParameterGroup__add`; no caller is visible in this module.
        # A string `value` with no `expr` is treated as the expression.
        if expr is None and isinstance(value, str):
            expr = value
            value = None
        if self.__params__ is not None:
            self.__params__.add(name, value=value, vary=vary, min=min, max=max,
                                expr=expr, brute_step=brute_step)
            self.__params__[name].stderr = stderr
            self.__params__[name].correl = correl
            self.__params__[name].skip = skip
            self.__dict__[name] = self.__params__[name]
def param_group(**kws):
    """create a parameter group

    All keyword arguments are forwarded to ``ParameterGroup``.
    """
    group = ParameterGroup(**kws)
    return group
def randstr(n):
    """Return a string of `n` random lowercase ASCII letters ('a'-'z').

    Uses the module-level ``random`` generator, so results are
    reproducible after ``random.seed``.
    """
    letters = []
    for _ in range(n):
        letters.append(chr(random.randint(97, 122)))
    return ''.join(letters)
class unnamedParameter(Parameter):
    """A Parameter that can be nameless"""
    def __init__(self, name=None, value=None, vary=True, min=-np.inf, max=np.inf,
                 expr=None, brute_step=None, user_data=None, skip=None):
        """Create a Parameter, generating a random name when none is given.

        Parameters
        ----------
        name : str or None
            Parameter name.  When None, ``randstr(8)`` supplies an
            8-character random lowercase name.
        value, vary, min, max, expr, brute_step, user_data
            Forwarded to ``lmfit.Parameter.__init__``.
        skip : optional
            Stored as ``self.skip``.  It is not passed to ``Parameter.__init__``.
        """
        if name is None:
            name = randstr(8)
        # The attributes below are all assigned before Parameter.__init__
        # is called.
        # NOTE(review): presumably this is so that `_init_bounds()` (an
        # lmfit-internal method) finds every attribute it reads, and
        # Parameter.__init__ then re-sets most of them.  Confirm against
        # the installed lmfit version before reordering anything here.
        self.name = name
        self.user_data = user_data
        self.init_value = value
        self.min = min
        self.max = max
        self.brute_step = brute_step
        self.vary = vary
        self.skip = skip
        self._expr = expr
        self._expr_ast = None
        self._expr_eval = None
        self._expr_deps = []
        self._delay_asteval = False
        self.stderr = None
        self.correl = None
        # identity function
        self.from_internal = lambda val: val
        self._val = value
        self._init_bounds()
        Parameter.__init__(self, name, value=value, vary=vary,
                           min=min, max=max, expr=expr,
                           brute_step=brute_step,
                           user_data=user_data)
def param(*args, **kws):
    """create a fitting Parameter as a Variable

    The first positional argument, if given, is interpreted as:
      * a string -> the constraint expression (``expr``)
      * a number -> the initial ``value``.  Python ints/floats and NumPy
                    integer/floating scalars (e.g. ``np.int64``,
                    ``np.float32``) are all accepted.
    Remaining positional arguments and all keyword arguments are passed to
    ``unnamedParameter``.  A ``_larch`` keyword is discarded.  ``vary``
    defaults to False here, so ``param()`` makes a fixed parameter unless
    ``vary=True`` is given.

    Raises
    ------
    ValueError
        If the first positional argument is neither a string nor a number.
    """
    if args:
        first, args = args[0], args[1:]
        if isinstance(first, str):
            kws['expr'] = first
        elif isinstance(first, (int, float, np.integer, np.floating)):
            kws['value'] = first
        else:
            raise ValueError("first argument to param() must be string or number")
    kws.pop('_larch', None)
    kws.setdefault('vary', False)
    return unnamedParameter(*args, **kws)
def guess(value, **kws):
    """create a fitting Parameter as a Variable.
    A minimum or maximum value for the variable value can be given:
       x = guess(10, min=0)
       y = guess(1.2, min=1, max=2)

    Equivalent to ``param(value, **kws)`` with ``vary`` forced to True.
    """
    kws['vary'] = True
    return param(value, **kws)
def is_param(obj):
    """Return whether `obj` is a Parameter (see ``isParameter``)."""
    answer = isParameter(obj)
    return answer
def dict2params(pars):
    """Convert a plain dict whose values are Parameter objects to Parameters.

    An ``lmfit.Parameters`` instance is returned unchanged.  For any other
    mapping, a new ``Parameters`` is built from the entries whose values
    are Parameter instances; all other entries are dropped.
    """
    if not isinstance(pars, Parameters):
        converted = Parameters()
        for name, item in pars.items():
            if isinstance(item, Parameter):
                converted[name] = item
        pars = converted
    return pars
def group2params(paramgroup):
    """take a Group of Parameter objects (and maybe other things)
    and put them into a lmfit.Parameters, ready for use in fitting

    Parameters
    ----------
    paramgroup : Parameters, dict, ParameterGroup, Group or None
        * Parameters: returned unchanged.
        * dict: entries whose values are Parameter instances are copied
          into a new Parameters.
        * ParameterGroup: its own ``__params__`` is returned.
        * any other object: each name in ``dir(paramgroup)`` whose ``skip``
          attribute is None/False is used.  Parameters are re-created,
          and other values become asteval symbols.
        * None: an empty Parameters is returned.

    Returns
    -------
    lmfit.Parameters
    """
    if isinstance(paramgroup, Parameters):
        return paramgroup
    if isinstance(paramgroup, dict):
        params = Parameters()
        for key, val in paramgroup.items():
            if isinstance(val, Parameter):
                params[key] = val
        return params

    if isinstance(paramgroup, ParameterGroup):
        return paramgroup.__params__

    params = Parameters()
    if paramgroup is None:
        return params

    # Collect expressions only for parameters that were actually added:
    # previously a skipped parameter with an `expr` made the second pass
    # raise KeyError on `params[name]`.
    exprs = {}
    for name in dir(paramgroup):
        par = getattr(paramgroup, name)
        if getattr(par, 'skip', None) not in (False, None):
            continue
        if isParameter(par):
            params.add(name, value=par.value, vary=par.vary,
                       min=par.min, max=par.max,
                       brute_step=par.brute_step)
            if par.expr is not None:
                exprs[name] = par.expr
        else:
            params._asteval.symtable[name] = par

    # now set any expression (that is, after all symbols are defined)
    for name, expr in exprs.items():
        params[name].expr = expr

    return params
def params2group(params, paramgroup):
    """fill Parameter objects in paramgroup with
    values from lmfit.Parameters

    For each name in `params` that is a Parameter attribute of
    `paramgroup`:
      * the group's Parameter is stored into ``paramgroup.__params__``
        (when that exists);
      * value, vary, stderr, min, max, expr, name, correl, brute_step and
        user_data are copied onto it (missing attributes become None);
      * when stderr is not None, ``uvalue`` is set to ``ufloat(value, stderr)``
        on a best-effort basis.
    Names that are not Parameters in `paramgroup` are ignored.
    """
    _params = getattr(paramgroup, '__params__', None)
    for name, par in params.items():
        this = getattr(paramgroup, name, None)
        if not isParameter(this):
            continue
        if _params is not None:
            _params[name] = this
        for attr in ('value', 'vary', 'stderr', 'min', 'max', 'expr',
                     'name', 'correl', 'brute_step', 'user_data'):
            setattr(this, attr, getattr(par, attr, None))
        if this.stderr is not None:
            try:
                this.uvalue = ufloat(this.value, this.stderr)
            except Exception:
                # Best effort: a value/stderr that uncertainties rejects
                # just leaves uvalue unset.  Unlike the former bare
                # `except:`, this does not swallow KeyboardInterrupt or
                # SystemExit.
                pass
def minimize(fcn, paramgroup, method='leastsq', args=None, kws=None,
             scale_covar=True, iter_cb=None, reduce_fcn=None, nan_polcy='omit',
             **fit_kws):
    """
    wrapper around lmfit minimizer for Larch

    Parameters
    ----------
    fcn : callable
        Objective, called as ``fcn(paramgroup, *args, **kws)``.
    paramgroup : ParameterGroup, Group or lmfit.Parameters
        Parameters to fit.  The fitted values are written back into it.
    method : str, optional
        Minimization method passed to ``Minimizer.minimize`` (default 'leastsq').
    args, kws : optional
        Extra positional / keyword arguments for `fcn`.
    scale_covar, iter_cb, reduce_fcn : optional
        Passed to ``lmfit.Minimizer``.
    nan_polcy : str, optional
        NaN policy passed to ``lmfit.Minimizer`` (default 'omit').  The
        historical misspelling is kept for compatibility.  A correctly
        spelled ``nan_policy=`` keyword is also accepted and takes
        precedence.
    **fit_kws
        Any other keywords are passed to ``lmfit.Minimizer``.

    Returns
    -------
    Group
        With ``fitter``, ``fit_details``, ``chi_square`` and ``chi_reduced``,
        plus aic, bic, covar, params, nvarys, nfree, ndata, var_names, nfev,
        success, errorbars, message, lmdif_message and residual copied from
        the lmfit result (None when missing).

    Raises
    ------
    ValueError
        If `paramgroup` is not a ParameterGroup, Group or Parameters.
    """
    if isinstance(paramgroup, ParameterGroup):
        params = paramgroup.__params__
    elif isgroup(paramgroup):
        params = group2params(paramgroup)
    elif isinstance(paramgroup, Parameters):
        params = paramgroup
    else:
        raise ValueError('minimize takes ParameterGroup, Group, or Parameters as second argument')

    if args is None:
        args = ()
    if kws is None:
        kws = {}
    # Accept the correctly spelled keyword too.  Before this, passing
    # nan_policy=... collided with the hard-coded keyword below (TypeError),
    # and `nan_polcy` itself was silently ignored.
    nan_policy = fit_kws.pop('nan_policy', nan_polcy)

    def _residual(params):
        params2group(params, paramgroup)
        return fcn(paramgroup, *args, **kws)

    fitter = Minimizer(_residual, params, iter_cb=iter_cb,
                       reduce_fcn=reduce_fcn, nan_policy=nan_policy,
                       scale_covar=scale_covar, **fit_kws)

    result = fitter.minimize(method=method)
    params2group(result.params, paramgroup)

    out = Group(name='minimize results', fitter=fitter, fit_details=result,
                chi_square=result.chisqr, chi_reduced=result.redchi)

    for attr in ('aic', 'bic', 'covar', 'params', 'nvarys',
                 'nfree', 'ndata', 'var_names', 'nfev', 'success',
                 'errorbars', 'message', 'lmdif_message', 'residual'):
        setattr(out, attr, getattr(result, attr, None))
    return out
def fit_report(fit_result, modelpars=None, show_correl=True, min_correl=0.1,
               sort_pars=True, **kws):
    """generate a report of fitting results
    wrapper around lmfit.fit_report

    The report contains the best-fit values for the parameters and their
    uncertainties and correlations.

    Parameters
    ----------
    fit_result : result from fit
       Fit Group output from fit, or lmfit.MinimizerResult returned from a fit.
    modelpars : Parameters, optional
       Known Model Parameters.
    show_correl : bool, optional
       Whether to show list of sorted correlations (default is True).
    min_correl : float, optional
       Smallest correlation in absolute value to show (default is 0.1).
    sort_pars : bool or callable, optional
       Whether to show parameter names sorted in alphanumerical order. If
       False, then the parameters will be listed in the order they
       were added to the Parameters dictionary. If callable, then this (one
       argument) function is used to extract a comparison key from each
       list element.
    **kws
       Accepted and ignored.

    Returns
    -------
    string
       Multi-line text of fit report, or a "Cannot make fit report" message
       if no parameters can be extracted from `fit_result`.
    """
    opts = dict(modelpars=modelpars, show_correl=show_correl,
                min_correl=min_correl, sort_pars=sort_pars)

    # 1) a larch fit Group carries the lmfit result in `fit_details`
    details = getattr(fit_result, 'fit_details', fit_result)
    if isinstance(details, MinimizerResult):
        return lmfit.fit_report(details, **opts)
    if isinstance(details, ModelResult):
        return details.fit_report(**opts)

    # 2) otherwise report on a Parameters object, if one can be found
    pars = getattr(fit_result, 'params', fit_result)
    if isinstance(pars, Parameters):
        return lmfit.fit_report(pars, **opts)
    try:
        return lmfit.fit_report(group2params(fit_result), **opts)
    except (ValueError, AttributeError):
        return "Cannot make fit report with %s" % repr(fit_result)
def confidence_intervals(fit_result, sigmas=(1, 2, 3), **kws):
    """calculate the confidence intervals from a fit
    for supplied sigma values

    wrapper around lmfit.conf_interval

    Parameters
    ----------
    fit_result : Group
        output of minimize(); must have ``fitter`` and ``fit_details``.
    sigmas : sequence, optional
        sigma levels to compute (default (1, 2, 3)).
    **kws
        passed to ``lmfit.conf_interval``.

    Raises
    ------
    ValueError
        If `fit_result` lacks ``fitter`` or ``fit_details``.  This matches
        chi2_map, instead of failing with an obscure error inside lmfit.
    """
    fitter = getattr(fit_result, 'fitter', None)
    result = getattr(fit_result, 'fit_details', None)
    if fitter is None or result is None:
        raise ValueError("confidence_intervals needs valid fit result as first argument")
    return conf_interval(fitter, result, sigmas=sigmas, **kws)
def chi2_map(fit_result, xname, yname, nx=21, ny=21, sigma=3, **kws):
    """generate a confidence map for any two parameters for a fit

    Arguments
    ==========
      fit_result  output of minimize() fit (must be run first)
      xname       name of variable parameter for x-axis
      yname       name of variable parameter for y-axis
      nx          number of steps in x [21]
      ny          number of steps in y [21]
      sigma       scale for uncertainty range [3]

    Returns
    =======
      xpts, ypts, map

    Raises
    ======
      ValueError  if fit_result lacks `fitter`/`fit_details`, or if either
                  parameter has no uncertainty (stderr is None), e.g. when
                  the fit could not estimate error bars.

    Notes
    =====
     1. sigma sets the extent of values to explore:
              param.value +/- sigma * param.stderr
     2. extra keyword arguments are passed to lmfit.conf_interval2d.
    """
    #
    fitter = getattr(fit_result, 'fitter', None)
    result = getattr(fit_result, 'fit_details', None)
    if fitter is None or result is None:
        raise ValueError("chi2_map needs valid fit result as first argument")

    x = result.params[xname]
    y = result.params[yname]
    # the map extent is derived from stderr, so it must be known
    for par, pname in ((x, xname), (y, yname)):
        if par.stderr is None:
            raise ValueError(f"chi2_map needs an uncertainty (stderr) for parameter '{pname}'")

    # (upper, lower) ordering, as lmfit.conf_interval2d expects for `limits`
    xrange = (x.value + sigma * x.stderr, x.value - sigma * x.stderr)
    yrange = (y.value + sigma * y.stderr, y.value - sigma * y.stderr)

    return conf_interval2d(fitter, result, xname, yname,
                           limits=(xrange, yrange),
                           nx=nx, ny=ny, nsigma=2*sigma, **kws)
# Target namespace name for these builtins.  It matches the key used in
# _larch_builtins at the end of this block (presumably read by larch's
# plugin loader -- confirm).
_larch_name = '_math'
# Names made available to larch, mapped to the Python objects they refer to.
# Aliases: 'is_param' and 'isparam' both map to isParameter, and 'minimize'
# and 'lm_minimize' map to the same function.
exports = {'param': param,
           'guess': guess,
           'param_group': param_group,
           'confidence_intervals': confidence_intervals,
           'confidence_report': confidence_report,
           'f_test': f_test, 'chi2_map': chi2_map,
           'is_param': isParameter,
           'isparam': isParameter,
           'minimize': minimize,
           'ufloat': ufloat,
           'fit_report': fit_report,
           'Parameters': Parameters,
           'Parameter': Parameter,
           'lm_minimize': minimize,
           'lm_save_model': save_model,
           'lm_load_model': load_model,
           'lm_save_modelresult': save_modelresult,
           'lm_load_modelresult': load_modelresult,
           }

# Also export each of these names found in the installed lmfit.models.
# Names missing from that lmfit version are skipped, so older or newer
# releases do not break the import.
for name in ('BreitWignerModel', 'ComplexConstantModel',
             'ConstantModel', 'DampedHarmonicOscillatorModel',
             'DampedOscillatorModel', 'DoniachModel',
             'ExponentialGaussianModel', 'ExponentialModel',
             'ExpressionModel', 'GaussianModel', 'Interpreter',
             'LinearModel', 'LognormalModel', 'LorentzianModel',
             'MoffatModel', 'ParabolicModel', 'Pearson7Model',
             'PolynomialModel', 'PowerLawModel',
             'PseudoVoigtModel', 'QuadraticModel',
             'RectangleModel', 'SkewedGaussianModel',
             'StepModel', 'StudentsTModel', 'VoigtModel'):
    val = getattr(lmfit.models, name, None)
    if val is not None:
        exports[name] = val

_larch_builtins = {'_math': exports}