Coverage for larch/fitting/__init__.py: 52%

322 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-10-16 21:04 +0000

1#!/usr/bin/env python 

2 

3from copy import copy, deepcopy 

4import random 

5import numpy as np 

6from scipy.stats import f 

7 

8import lmfit 

9from lmfit import Parameter 

10from lmfit import (Parameters, Minimizer, conf_interval, 

11 ci_report, conf_interval2d) 

12 

13from lmfit.minimizer import MinimizerResult 

14from lmfit.model import (ModelResult, save_model, load_model, 

15 save_modelresult, load_modelresult) 

16from lmfit.confidence import f_compare 

17 

18from lmfit.printfuncs import gformat, getfloat_attr 

19from uncertainties import ufloat, correlated_values 

20from uncertainties import wrap as un_wrap 

21 

22from ..symboltable import Group, isgroup 

23 

24 

def isParameter(x):
    """Return True when *x* is an lmfit Parameter, or at least looks like one.

    The fallback compares only the class *name*, so an object whose class is
    called 'Parameter' passes even if it does not derive from lmfit's class.
    """
    if isinstance(x, Parameter):
        return True
    return x.__class__.__name__ == 'Parameter'

28 

def param_value(val):
    """Unwrap a Parameter to its plain value -- useful for 3rd party code.

    Follows ``.value`` for as long as the current object is a Parameter;
    any other object is handed back unchanged.
    """
    current = val
    while True:
        if not isinstance(current, Parameter):
            return current
        current = current.value

34 

def format_param(par, length=10, with_initial=True):
    """Format one parameter as a short, human-readable string.

    Non-Parameter objects are rendered with ``repr``.  For a Parameter:

    * neither varied nor constrained  -> ``"<value> (fixed)"``
    * otherwise                       -> ``"<value> +/-<stderr>"``, where the
      stderr reads ``unknown`` when none is available, followed by
      ``" (init=<initial value>)"`` for a freely varied parameter when
      `with_initial` is true, or by ``" = '<expr>'"`` for a constrained one.

    `length` is forwarded to ``gformat`` for every number that is printed.
    """
    fallback = repr(par)
    if not isParameter(par):
        return fallback

    text = gformat(par.value, length=length)
    constrained = par.expr is not None
    if not (par.vary or constrained):
        return f"{text} (fixed)"

    if par.stderr is None:
        err = 'unknown'
    else:
        err = gformat(par.stderr, length=length)
    text = f"{text} +/-{err}"

    if with_initial and par.vary and not constrained:
        text = f"{text} (init={gformat(par.init_value, length=length)})"
    if constrained:
        text = f"{text} = '{par.expr}'"
    return text

51 

52 

def stats_table(results, labels=None, csv_output=False, csv_delim=','):
    """
    build a text table comparing the fit statistics of several fit results

    One column is produced per entry of `results`, headed by the matching
    entry of `labels` (default: " Fit 1", " Fit 2", ...).  Raises ValueError
    when `labels` is given but its length differs from that of `results`.
    `csv_output` and `csv_delim` are passed on to format_table_columns().
    """
    stats = {'number of variables': 'nvarys',
             'chi-square': 'chi_square',
             'reduced chi-square': 'chi2_reduced',
             'r-factor': 'rfactor',
             'Akaike Info Crit': 'aic',
             'Bayesian Info Crit': 'bic'}

    nfits = len(results)
    if labels is None:
        labels = [f" Fit {i+1}" for i in range(nfits)]
    elif len(labels) != nfits:
        raise ValueError('labels must be a list that is the same length as results')

    # first column holds the row titles, then one column per fit
    columns = [['Statistics', *stats]]
    for label, result in zip(labels, results):
        column = [label]
        column.extend(getfloat_attr(result, attr) for attr in stats.values())
        columns.append(column)

    return format_table_columns(columns, csv_output=csv_output, csv_delim=csv_delim)

81 

def paramgroups_table(pgroups, labels=None, csv_output=False, csv_delim=','):
    """
    build a text table comparing the parameters of several Feffit Parameter Groups

    Rows are the union of the names reported by ``dir()`` on each group, in
    first-seen order.  A cell shows 'N/A' when the group has no such attribute
    (or it is None); otherwise it is rendered with format_param().  Raises
    ValueError when `labels` is given but its length differs from `pgroups`.
    """
    nfits = len(pgroups)
    if labels is None:
        labels = [f" Fit {i+1}" for i in range(nfits)]
    elif len(labels) != nfits:
        raise ValueError('labels must be a list that is the same length as Parameter Groups')

    # dict.fromkeys de-duplicates while keeping the first-seen order
    parnames = list(dict.fromkeys(pname
                                  for pgroup in pgroups
                                  for pname in dir(pgroup)))

    columns = [['Parameter', *parnames]]
    for label, pgroup in zip(labels, pgroups):
        column = [label]
        for pname in parnames:
            par = getattr(pgroup, pname, None)
            if par is None:
                column.append('N/A')
            else:
                column.append(format_param(par, length=10, with_initial=False))
        columns.append(column)

    return format_table_columns(columns, csv_output=csv_output, csv_delim=csv_delim)

113 

114 

def format_table_columns(columns, csv_output=False, csv_delim=','):
    """Render a list of columns as a text table (or as delimited rows).

    Parameters
    ----------
    columns : list of lists
        One inner list per column; element 0 of each is the column heading.
        The number of rows is taken from the first column.  Cells are
        converted with ``str()``.  The input lists are not modified.
    csv_output : bool, optional
        If True, emit plain rows joined by `csv_delim` with no rule lines.
    csv_delim : str, optional
        Delimiter used when `csv_output` is True.

    Returns
    -------
    str
        Newline-joined table text; an empty string when `columns` is empty.

    Raises
    ------
    IndexError
        If a later column has fewer rows than the first column.
    """
    rjoin, edge = '|', '|'
    if csv_output:
        rjoin, edge = csv_delim, ''

    # work on string copies so the caller's lists are left intact
    cols = [[str(cell) for cell in col] for col in columns]
    if not cols:
        return ''

    nrows = len(cols[0])
    # each column is at least 5 wide; default=0 keeps empty columns from crashing
    widths = [max(5, max((len(cell) for cell in col), default=0)) for col in cols]

    lines = []
    rule = None
    if not csv_output:
        rule = '|' + '+'.join('-' * (width + 2) for width in widths) + '|'
        lines.append(rule)

    for irow in range(nrows):
        # width == precision: pad short cells and truncate long ones
        cells = rjoin.join(f" {col[irow]:{width}.{width}s} "
                           for col, width in zip(cols, widths))
        lines.append(edge + cells + edge)
        if rule is not None and irow == 0:
            # rule line separating the heading row from the data rows
            lines.append(rule)

    if rule is not None:
        lines.append(rule)
    return '\n'.join(lines)

140 

141 

def f_test(ndata, nvars, chisquare, chisquare0, nfix=1):
    """return the F-test value for the following input values:
         f = f_test(ndata, nparams, chisquare, chisquare0, nfix=1)

    ndata      = number of data points
    nvars      = number of variables in the fit that gave `chisquare0`
    chisquare  = chi-square of the fit being tested
    chisquare0 = chi-square of the reference (best) fit
    nfix       = the number of fixed parameters.

    The value is the F-distribution CDF evaluated at
        (chisquare/chisquare0 - 1) * nfree / nfix
    with (nfix, nfree) degrees of freedom, nfree = ndata - (nvars + nfix).
    """
    # Computed here rather than via lmfit's f_compare: the original call used
    # the old 5-argument form and always passed nfix=1, ignoring the caller.
    nfree = ndata - (nvars + nfix)
    nfix = 1.0 * nfix
    dchi = chisquare / chisquare0 - 1.0
    return f.cdf(dchi * nfree / nfix, nfix, nfree)

149 

def confidence_report(conf_vals, **kws):
    """return a formatted report of the confidence intervals
    calculated by confidence_intervals

    Extra keyword arguments are accepted but not used.
    """
    report = ci_report(conf_vals)
    return report

155 

def asteval_with_uncertainties(*vals, **kwargs):
    """Calculate object value, given values for variables.

    This is used by the uncertainties package to calculate the
    uncertainty in an object even with a complicated expression.

    Positional `vals` are paired with the names in ``_names`` and written
    into the asteval symbol table of ``_pars``; every parameter in ``_pars``
    is then re-evaluated and the expression AST of ``_obj`` is evaluated.

    Keyword arguments
    -----------------
    _obj   : object carrying the parsed expression in ``_expr_ast``
    _pars  : parameters container with an ``_asteval`` interpreter
    _names : names matching `vals`, in order

    Returns 0 when any of these is missing or None, or when there is no
    interpreter or no parsed expression.
    """
    _obj = kwargs.get('_obj', None)
    _pars = kwargs.get('_pars', None)
    _names = kwargs.get('_names', None)
    # check for None *before* touching _pars._asteval, so that a missing
    # _pars gives the documented 0 rather than an AttributeError
    if _obj is None or _pars is None or _names is None:
        return 0
    _asteval = _pars._asteval
    if _asteval is None or _obj._expr_ast is None:
        return 0

    for val, name in zip(vals, _names):
        _asteval.symtable[name] = val

    # re-evaluate all constraint parameters to
    # force the propagation of uncertainties
    for par in _pars.values():
        par._getval()
    return _asteval.eval(_obj._expr_ast)

177 

178 

# asteval_with_uncertainties wrapped by uncertainties.wrap, so that it can be
# called with ufloat arguments; eval_stderr() reads `.std_dev` from its result.
wrap_ueval = un_wrap(asteval_with_uncertainties)

180 

181 

def eval_stderr(obj, uvars, _names, _pars):
    """Compute the uncertainty of a constrained parameter and store it.

    `obj` is the parameter whose ``.stderr`` is to be set, `uvars` the list
    of uncertain values (``uncertainties.ufloat``), `_names` the parameter
    names matching `uvars`, and `_pars` the dictionary of parameter objects
    keyed by name.

    Nothing happens unless `obj` is a Parameter carrying a parsed expression
    in ``obj._expr_ast``.  Otherwise the expression is evaluated through the
    uncertainties-wrapped evaluator and the resulting standard deviation is
    stored; if that fails, ``.stderr`` is set to 0.
    """
    has_expression = (isinstance(obj, Parameter)
                      and getattr(obj, '_expr_ast', None) is not None)
    if not has_expression:
        return

    propagated = wrap_ueval(*uvars, _obj=obj, _names=_names, _pars=_pars)
    try:
        obj.stderr = propagated.std_dev
    except Exception:
        # e.g. the evaluation returned a plain number with no std_dev
        obj.stderr = 0

201 

202 

class ParameterGroup(Group):
    """
    Group for Fitting Parameters

    Every Parameter assigned as an attribute is also registered in an
    lmfit ``Parameters`` object kept in ``__params__``; every other
    (non-dunder) attribute is copied into that object's asteval symbol
    table, so that constraint expressions can refer to it.
    """
    def __init__(self, name=None, **kws):
        """Create the group; each keyword argument becomes an attribute."""
        if name is not None:
            self.__name__ = name
        if '_larch' in kws:
            kws.pop('_larch')
        self.__params__ = Parameters()
        Group.__init__(self)
        # Expressions are stripped while the attributes are being set and
        # restored afterwards, once every symbol has been defined.
        self.__exprsave__ = {}
        for key, val in kws.items():
            expr = getattr(val, 'expr', None)
            if expr is not None:
                self.__exprsave__[key] = expr
                val.expr = None
            setattr(self, key, val)

        for key, val in self.__exprsave__.items():
            self.__params__[key].expr = val

    def __repr__(self):
        return '<Param Group {:s}>'.format(self.__name__)

    def __setattr__(self, name, val):
        """Set an attribute, mirroring it into ``__params__``.

        A Parameter is re-created inside ``__params__`` under `name`, and
        that new object (not the one passed in) is what gets stored.
        """
        if isParameter(val):
            if val.name != name:
                # allow 'a=Parameter(2, ..)' to mean Parameter(name='a', value=2, ...)
                nval = None
                try:
                    nval = float(val.name)
                except (ValueError, TypeError):
                    pass
                if nval is not None:
                    val.value = nval
            # remember `skip`, which is carried over by hand below because it
            # is not among the arguments passed to Parameters.add()
            skip = getattr(val, 'skip', None)
            self.__params__.add(name, value=val.value, vary=val.vary, min=val.min,
                                max=val.max, expr=val.expr, brute_step=val.brute_step)
            val = self.__params__[name]

            val.skip = skip
        elif hasattr(self, '__params__') and not name.startswith('__'):
            # plain value: make it visible to constraint expressions
            self.__params__._asteval.symtable[name] = val
        self.__dict__[name] = val

    def __delattr__(self, name):
        """Remove an attribute and, if present, its entry in ``__params__``."""
        self.__dict__.pop(name)
        if name in self.__params__:
            self.__params__.pop(name)

    def __add(self, name, value=None, vary=True, min=-np.inf, max=np.inf,
              expr=None, stderr=None, correl=None, brute_step=None, skip=None):
        """Add a parameter to ``__params__`` and expose it as an attribute.

        A string `value` given without `expr` is taken to be the expression.
        """
        # NOTE(review): nesting below was reconstructed from a listing that
        # had lost its indentation -- confirm against the original file.
        if expr is None and isinstance(value, str):
            expr = value
            value = None
        if self.__params__ is not None:
            self.__params__.add(name, value=value, vary=vary, min=min, max=max,
                                expr=expr, brute_step=brute_step)
            self.__params__[name].stderr = stderr
            self.__params__[name].correl = correl
            self.__params__[name].skip = skip
            self.__dict__[name] = self.__params__[name]

267 

268 

def param_group(**kws):
    """create a parameter group from the given keyword arguments"""
    group = ParameterGroup(**kws)
    return group

272 

def randstr(n):
    """return a string of `n` random lowercase ASCII letters"""
    # code points 97..122 are 'a'..'z'
    letters = (chr(random.randint(97, 122)) for _ in range(n))
    return ''.join(letters)

275 

class unnamedParameter(Parameter):
    """A Parameter that can be nameless"""
    def __init__(self, name=None, value=None, vary=True, min=-np.inf, max=np.inf,
                 expr=None, brute_step=None, user_data=None, skip=None):
        """Create the Parameter; a missing name is replaced by 8 random letters."""
        if name is None:
            name = randstr(8)
        # The attributes below are assigned before Parameter.__init__ is
        # called at the end of this method.
        # NOTE(review): presumably needed so _init_bounds() can run first --
        # confirm against the lmfit version in use.
        self.name = name
        self.user_data = user_data
        self.init_value = value
        self.min = min
        self.max = max
        self.brute_step = brute_step
        self.vary = vary
        # `skip` is not passed on to Parameter.__init__ below; it is only
        # stored as an attribute here
        self.skip = skip
        self._expr = expr
        self._expr_ast = None
        self._expr_eval = None
        self._expr_deps = []
        self._delay_asteval = False
        self.stderr = None
        self.correl = None
        self.from_internal = lambda val: val
        self._val = value
        self._init_bounds()
        Parameter.__init__(self, name, value=value, vary=vary,
                           min=min, max=max, expr=expr,
                           brute_step=brute_step,
                           user_data=user_data)

304 

def param(*args, **kws):
    """create a fitting Parameter as a Variable

    An optional first positional argument sets the constraint expression
    (when it is a string) or the value (when it is an int or float); any
    other type raises ValueError.  Remaining positional and keyword
    arguments go to unnamedParameter.  A '_larch' keyword is discarded and
    'vary' defaults to False.
    """
    if args:
        first, args = args[0], args[1:]
        if isinstance(first, str):
            kws['expr'] = first
        elif isinstance(first, (int, float)):
            kws['value'] = first
        else:
            raise ValueError("first argument to param() must be string or number")

    kws.pop('_larch', None)
    kws.setdefault('vary', False)
    return unnamedParameter(*args, **kws)

322 

def guess(value, **kws):
    """create a fitting Parameter that is varied in the fit.

    Same as param(), except that ``vary`` is always forced to True.
    A minimum or maximum value for the variable value can be given:
       x = guess(10, min=0)
       y = guess(1.2, min=1, max=2)
    """
    kws['vary'] = True
    return param(value, **kws)

331 

def is_param(obj):
    """return True if `obj` is a Parameter (see isParameter)"""
    answer = isParameter(obj)
    return answer

335 

def dict2params(pars):
    """convert a plain dict whose values are Parameter objects into
    an lmfit Parameters object

    A Parameters instance is returned as is.  Otherwise only the entries
    whose value is a Parameter are copied; everything else is dropped.
    """
    if isinstance(pars, Parameters):
        return pars

    out = Parameters()
    selected = ((key, val) for key, val in pars.items()
                if isinstance(val, Parameter))
    for key, val in selected:
        out[key] = val
    return out

346 

def group2params(paramgroup):
    """take a Group of Parameter objects (and maybe other things)
    and put them into a lmfit.Parameters, ready for use in fitting

    Accepted inputs:
      * lmfit.Parameters -> returned unchanged
      * dict             -> new Parameters holding only the Parameter values
      * ParameterGroup   -> its own ``__params__`` object
      * None             -> empty Parameters
      * any other object -> every name from ``dir(paramgroup)`` is examined:
        Parameters are re-created in the result, all other attributes are
        placed in the asteval symbol table.  Attributes whose ``skip``
        attribute is set (anything but False or None) are left out.
    """
    if isinstance(paramgroup, Parameters):
        return paramgroup
    if isinstance(paramgroup, dict):
        params = Parameters()
        for key, val in paramgroup.items():
            if isinstance(val, Parameter):
                params[key] = val
        return params

    if isinstance(paramgroup, ParameterGroup):
        return paramgroup.__params__

    params = Parameters()
    if paramgroup is None:
        return params

    exprs = {}
    for name in dir(paramgroup):
        par = getattr(paramgroup, name)
        if getattr(par, 'skip', None) not in (False, None):
            continue
        if isParameter(par):
            params.add(name, value=par.value, vary=par.vary,
                       min=par.min, max=par.max,
                       brute_step=par.brute_step)
            # remember expressions only for parameters that were actually
            # added: a skipped parameter with an expression used to raise
            # KeyError when its expression was assigned afterwards
            if par.expr is not None:
                exprs[name] = par.expr
        else:
            params._asteval.symtable[name] = par

    # now set any expression (that is, after all symbols are defined)
    for name, expr in exprs.items():
        params[name].expr = expr

    return params

384 

def params2group(params, paramgroup):
    """fill Parameter objects in paramgroup with
    values from lmfit.Parameters

    For each name in `params` for which `paramgroup` has a Parameter
    attribute, the fit attributes (value, vary, stderr, bounds, expr, ...)
    are copied onto that attribute.  When `paramgroup` carries a
    ``__params__`` container, the attribute is stored there under the same
    name as well.  If a stderr is available, a ``uvalue`` ufloat is attached
    on a best-effort basis.  Returns None.
    """
    _params = getattr(paramgroup, '__params__', None)
    copied = ('value', 'vary', 'stderr', 'min', 'max', 'expr',
              'name', 'correl', 'brute_step', 'user_data')
    for name, param in params.items():
        this = getattr(paramgroup, name, None)
        if not isParameter(this):
            continue
        if _params is not None:
            _params[name] = this
        for attr in copied:
            setattr(this, attr, getattr(param, attr, None))
        if this.stderr is not None:
            try:
                this.uvalue = ufloat(this.value, this.stderr)
            except Exception:
                # best effort only; unlike a bare `except:` this lets
                # KeyboardInterrupt and SystemExit through
                pass

403 

404 

def minimize(fcn, paramgroup, method='leastsq', args=None, kws=None,
             scale_covar=True, iter_cb=None, reduce_fcn=None, nan_polcy='omit',
             **fit_kws):
    """
    wrapper around lmfit minimizer for Larch

    Arguments
    ---------
    fcn          objective function, called as fcn(paramgroup, *args, **kws)
                 (or fcn(params, *args, **kws) when `paramgroup` is an
                 lmfit.Parameters) and returning the residual
    paramgroup   ParameterGroup, Group, or lmfit.Parameters holding the
                 fitting parameters
    method       name of the lmfit fitting method ['leastsq']
    args, kws    extra positional / keyword arguments for `fcn`
    scale_covar  passed to lmfit.Minimizer [True]
    iter_cb      passed to lmfit.Minimizer [None]
    reduce_fcn   passed to lmfit.Minimizer [None]
    nan_polcy    (sic; spelling kept for compatibility) NaN policy passed to
                 lmfit.Minimizer ['omit']; a correctly spelled `nan_policy`
                 keyword is accepted as well and takes precedence
    fit_kws      any further keyword arguments for lmfit.Minimizer

    Returns
    -------
    Group with the Minimizer as `fitter`, the lmfit result as `fit_details`,
    `chi_square`, `chi_reduced`, and the attributes copied from the result.

    Raises
    ------
    ValueError if `paramgroup` is none of the accepted types.
    """
    # True when the caller passed lmfit.Parameters directly rather than a group
    use_raw_params = False
    if isinstance(paramgroup, ParameterGroup):
        params = paramgroup.__params__
    elif isgroup(paramgroup):
        params = group2params(paramgroup)
    elif isinstance(paramgroup, Parameters):
        params = paramgroup
        use_raw_params = True
    else:
        raise ValueError('minimize takes ParameterGroup, Group, or Parameters '
                         'as first argument')

    if args is None:
        args = ()
    if kws is None:
        kws = {}

    # a correctly spelled nan_policy would otherwise land in fit_kws and
    # collide with the keyword passed to Minimizer below
    nan_policy = fit_kws.pop('nan_policy', nan_polcy)

    def _residual(params):
        if use_raw_params:
            # no group to update: hand the trial parameters to fcn directly
            return fcn(params, *args, **kws)
        params2group(params, paramgroup)
        return fcn(paramgroup, *args, **kws)

    fitter = Minimizer(_residual, params, iter_cb=iter_cb,
                       scale_covar=scale_covar, reduce_fcn=reduce_fcn,
                       nan_policy=nan_policy, **fit_kws)

    result = fitter.minimize(method=method)
    if not use_raw_params:
        params2group(result.params, paramgroup)

    out = Group(name='minimize results', fitter=fitter, fit_details=result,
                chi_square=result.chisqr, chi_reduced=result.redchi)

    for attr in ('aic', 'bic', 'covar', 'params', 'nvarys',
                 'nfree', 'ndata', 'var_names', 'nfev', 'success',
                 'errorbars', 'message', 'lmdif_message', 'residual'):
        setattr(out, attr, getattr(result, attr, None))
    return out

443 

def fit_report(fit_result, modelpars=None, show_correl=True, min_correl=0.1,
               sort_pars=True, **kws):
    """generate a report of fitting results
    wrapper around lmfit.fit_report

    The report lists the best-fit parameter values together with their
    uncertainties and correlations.

    Parameters
    ----------
    fit_result : result from fit
       Fit Group output from fit, or lmfit.MinimizerResult returned from
       a fit.  An lmfit ModelResult, an object with a `params` attribute
       holding lmfit Parameters, or a group of Parameters also work.
    modelpars : Parameters, optional
       Known Model Parameters.
    show_correl : bool, optional
       Whether to show list of sorted correlations (default is True).
    min_correl : float, optional
       Smallest correlation in absolute value to show (default is 0.1).
    sort_pars : bool or callable, optional
       Whether to show parameter names sorted in alphanumerical order. If
       False, the parameters are listed in the order in which they were
       added to the Parameters dictionary. If callable, this (one
       argument) function is used to extract a comparison key from each
       list element.

    Additional keyword arguments are accepted but not used.

    Returns
    -------
    string
       Multi-line text of fit report, or a one-line message when no
       report can be made from `fit_result`.
    """
    opts = dict(modelpars=modelpars, show_correl=show_correl,
                min_correl=min_correl, sort_pars=sort_pars)

    details = getattr(fit_result, 'fit_details', fit_result)
    if isinstance(details, MinimizerResult):
        return lmfit.fit_report(details, **opts)
    if isinstance(details, ModelResult):
        return details.fit_report(**opts)

    pars = getattr(fit_result, 'params', fit_result)
    if isinstance(pars, Parameters):
        return lmfit.fit_report(pars, **opts)

    # last resort: treat fit_result as a group of Parameters
    try:
        return lmfit.fit_report(group2params(fit_result), **opts)
    except (ValueError, AttributeError):
        return "Cannot make fit report with %s" % repr(fit_result)

500 

501 

def confidence_intervals(fit_result, sigmas=(1, 2, 3), **kws):
    """calculate confidence intervals for a fit at the supplied sigma levels

    wrapper around lmfit.conf_interval: the `fitter` and `fit_details`
    attributes of `fit_result` (None when absent) are passed to it along
    with `sigmas` and any extra keyword arguments.
    """
    fitter, result = (getattr(fit_result, attr, None)
                      for attr in ('fitter', 'fit_details'))
    return conf_interval(fitter, result, sigmas=sigmas, **kws)

511 

def chi2_map(fit_result, xname, yname, nx=21, ny=21, sigma=3, **kws):
    """generate a confidence map for any two parameters for a fit

    Arguments
    ==========
       fit_result  output of minimize() fit (must be run first)
       xname       name of variable parameter for x-axis
       yname       name of variable parameter for y-axis
       nx          number of steps in x [21]
       ny          number of steps in y [21]
       sigma       scale for uncertainty range [3]

    Any further keyword arguments are passed to lmfit.conf_interval2d.

    Returns
    =======
        xpts, ypts, map

    Raises
    ======
        ValueError if `fit_result` has no `fitter` or no `fit_details`.

    Notes
    =====
     1.  sigma sets the extent of values to explore:
              param.value +/- sigma * param.stderr
    """
    fitter = getattr(fit_result, 'fitter', None)
    result = getattr(fit_result, 'fit_details', None)
    if not (fitter is not None and result is not None):
        raise ValueError("chi2_map needs valid fit result as first argument")

    # one (value + sigma*stderr, value - sigma*stderr) pair per axis, x first
    limits = tuple((par.value + sigma * par.stderr,
                    par.value - sigma * par.stderr)
                   for par in (result.params[xname], result.params[yname]))

    return conf_interval2d(fitter, result, xname, yname, limits=limits,
                           nx=nx, ny=ny, nsigma=2*sigma, **kws)

547 

# NOTE(review): presumably the name of the Larch symbol group these exports
# are installed into -- confirm against the Larch plugin loader.
_larch_name = '_math'

# Mapping of exported name -> object.  Several objects appear under more
# than one name ('is_param'/'isparam', 'minimize'/'lm_minimize').
exports = {'param': param,
           'guess': guess,
           'param_group': param_group,
           'confidence_intervals': confidence_intervals,
           'confidence_report': confidence_report,
           'f_test': f_test, 'chi2_map': chi2_map,
           'is_param': isParameter,
           'isparam': isParameter,
           'minimize': minimize,
           'ufloat': ufloat,
           'fit_report': fit_report,
           'Parameters': Parameters,
           'Parameter': Parameter,
           'lm_minimize': minimize,
           'lm_save_model': save_model,
           'lm_load_model': load_model,
           'lm_save_modelresult': save_modelresult,
           'lm_load_modelresult': load_modelresult,
           }

# Also export the names below from lmfit.models; getattr with a None
# default skips any name that the installed lmfit does not provide.
for name in ('BreitWignerModel', 'ComplexConstantModel',
             'ConstantModel', 'DampedHarmonicOscillatorModel',
             'DampedOscillatorModel', 'DoniachModel',
             'ExponentialGaussianModel', 'ExponentialModel',
             'ExpressionModel', 'GaussianModel', 'Interpreter',
             'LinearModel', 'LognormalModel', 'LorentzianModel',
             'MoffatModel', 'ParabolicModel', 'Pearson7Model',
             'PolynomialModel', 'PowerLawModel',
             'PseudoVoigtModel', 'QuadraticModel',
             'RectangleModel', 'SkewedGaussianModel',
             'StepModel', 'StudentsTModel', 'VoigtModel'):
    val = getattr(lmfit.models, name, None)
    if val is not None:
        exports[name] = val

# the same mapping, keyed by the '_math' group name
_larch_builtins = {'_math': exports}