Coverage for larch/wxxas/regress_panel.py: 0%
405 statements
Report created by coverage.py v7.6.0 at 2024-10-16 21:04 +0000
#!/usr/bin/env python
"""
Regression panel: train PLS or Lasso regression models that relate spectra
to an external variable, and use a model to predict that variable for
other data groups.
"""
5import sys
6import time
7from pathlib import Path
8import wx
9import wx.grid as wxgrid
10import numpy as np
11import pickle
12import base64
13from copy import deepcopy
14from functools import partial
16from larch import Group
17from larch.math import index_of
18from larch.wxlib import (BitmapButton, TextCtrl, FloatCtrl, get_icon,
19 SimpleText, pack, Button, HLine, Choice, Check,
20 NumericCombo, CEN, LEFT, Font, FileSave, FileOpen,
21 DataTableGrid, Popup, FONTSIZE_FW, ExceptionPopup)
22from larch.io import save_groups, read_groups, read_csv
23from larch.utils.strutils import fix_varname
24from larch.utils import get_cwd, gformat
26from .taskpanel import TaskPanel
27from .config import Linear_ArrayChoices, Regress_Choices
29CSV_WILDCARDS = "CSV Files(*.csv,*.dat)|*.csv*;*.dat|All files (*.*)|*.*"
30MODEL_WILDCARDS = "Regression Model Files(*.regmod,*.dat)|*.regmod*;*.dat|All files (*.*)|*.*"
32Plot_Choices = ['Mean Spectrum + Active Energies',
33 'Spectra Stack',
34 'Predicted External Varliable']
36MAX_ROWS = 1000
def make_steps(max=1, decades=6):
    """Return a descending 1-2-5 sequence of step values.

    The sequence starts at `max`; then, for each of `decades` decades below
    it, it adds 5, 2 and 1 times 10**-k of `max` (k = 1 .. decades).
    With the defaults this is [1.0, 0.5, 0.2, 0.1, 0.05, ..., 1e-06]
    (19 values), the same sequence this function has always produced.

    Parameters
    ----------
    max : float
        first (largest) value.  The name shadows the builtin inside this
        function but is kept so keyword callers keep working.
    decades : int
        number of decades to span below `max` (0 gives just [max]).

    Returns
    -------
    list of float
    """
    top = float(max)
    steps = [top]
    for i in range(decades):
        scale = 10**(-(1+i))
        # multiply the mantissa by the decade first so that top=1.0
        # reproduces exactly the values j*10**-(i+1)
        steps.extend([top*(j*scale) for j in (5, 2, 1)])
    return steps
44class RegressionPanel(TaskPanel):
45 """Regression Panel"""
46 def __init__(self, parent, controller, **kws):
47 TaskPanel.__init__(self, parent, controller, panel='regression', **kws)
48 self.result = None
49 self.save_csvfile = 'RegressionData.csv'
50 self.save_modelfile = 'Model.regmod'
52 def process(self, dgroup, **kws):
53 """ handle processing"""
54 if self.skip_process:
55 return
56 self.skip_process = True
57 form = self.read_form()
    def build_display(self):
        """Create the panel's widgets and lay them out.

        Widgets are stored in self.wids.  The controls cover the array to
        regress on and its energy range, the regression method with its
        PLS / Lasso options, cross-validation folds and repeats, the
        attribute name, the data table of groups, buttons to load/save CSV
        data and models, and buttons to train a model and to predict for
        the selected groups.  The panel is then packed into a vertical sizer.
        """
        panel = self.panel
        wids = self.wids
        # block process() while widgets are created; cleared at the end
        self.skip_process = True

        # array used for the regression (keys of Linear_ArrayChoices)
        wids['fitspace'] = Choice(panel, choices=list(Linear_ArrayChoices.keys()),
                                  action=self.onFitSpace, size=(175, -1))
        wids['fitspace'].SetSelection(0)
        # wids['plotchoice'] = Choice(panel, choices=Plot_Choices,
        #                             size=(250, -1), action=self.onPlot)

        wids['method'] = Choice(panel, choices=Regress_Choices, size=(250, -1),
                                action=self.onRegressMethod)
        # start with the second entry of Regress_Choices selected
        wids['method'].SetSelection(1)
        add_text = self.add_text

        opts = dict(digits=2, increment=1.0)  # NOTE(review): not used below
        defaults = self.get_defaultconfig()

        # energy-range widgets; this appears to create wids['fitspace_label'],
        # self.elo_wids and self.ehi_wids, which are placed in the layout below
        self.make_fit_xspace_widgets(elo=defaults['elo_rel'], ehi=defaults['ehi_rel'])

        # Lasso alpha: combo box offering the make_steps() values 1, 0.5, 0.2, ...
        wids['alpha'] = NumericCombo(panel, make_steps(), fmt='%.6g',
                                     default_val=0.01, width=100)

        wids['auto_scale_pls'] = Check(panel, default=True, label='auto scale?')
        wids['auto_alpha'] = Check(panel, default=False, label='auto alpha?')

        wids['fit_intercept'] = Check(panel, default=True, label='fit intercept?')

        wids['save_csv'] = Button(panel, 'Save CSV File', size=(150, -1),
                                  action=self.onSaveCSV)
        wids['load_csv'] = Button(panel, 'Load CSV File', size=(150, -1),
                                  action=self.onLoadCSV)

        # disabled until a model exists (use_regmodel enables it)
        wids['save_model'] = Button(panel, 'Save Model', size=(150, -1),
                                    action=self.onSaveModel)
        wids['save_model'].Disable()

        wids['load_model'] = Button(panel, 'Load Model', size=(150, -1),
                                    action=self.onLoadModel)

        wids['train_model'] = Button(panel, 'Train Model From These Data',
                                     size=(275, -1), action=self.onTrainModel)

        # disabled until a model is trained or loaded
        wids['fit_group'] = Button(panel, 'Predict Variable for Selected Groups',
                                   size=(275, -1), action=self.onPredictGroups)
        wids['fit_group'].Disable()

        # integer spin controls; other methods read them as wids['cv_folds'],
        # wids['cv_repeats'] and wids['ncomps'], so add_floatspin presumably
        # also registers them in self.wids -- confirm in TaskPanel
        w_cvfolds = self.add_floatspin('cv_folds', digits=0, with_pin=False,
                                       value=0, increment=1, min_val=-1)

        w_cvreps = self.add_floatspin('cv_repeats', digits=0, with_pin=False,
                                      value=0, increment=1, min_val=-1)

        w_ncomps = self.add_floatspin('ncomps', digits=0, with_pin=False,
                                      value=3, increment=1, min_val=1)

        # name of the group attribute holding the external variable
        wids['varname'] = wx.TextCtrl(panel, -1, 'valence', size=(150, -1))
        wids['stat1'] = SimpleText(panel, ' - - - ')
        wids['stat2'] = SimpleText(panel, ' - - - ')

        # data table columns: name, external value, predicted value, training flag
        collabels = [' File Group Name ', 'External Value',
                     'Predicted Value', 'Training?']
        colsizes = [325, 110, 110, 90]
        coltypes = ['str', 'float:12,4', 'float:12,4', 'str']
        coldefs = ['', 0.0, 0.0, '']

        self.font_fixedwidth = wx.Font(FONTSIZE_FW, wx.MODERN, wx.NORMAL, wx.NORMAL)

        wids['table'] = DataTableGrid(panel, nrows=MAX_ROWS,
                                      collabels=collabels,
                                      datatypes=coltypes,
                                      defaults=coldefs,
                                      colsizes=colsizes)
        wids['table'].SetMinSize((700, 225))
        wids['table'].SetFont(self.font_fixedwidth)

        wids['use_selected'] = Button(panel, 'Use Selected Groups',
                                      size=(150, -1), action=self.onFillTable)

        # ---- layout ----
        panel.Add(SimpleText(panel, 'Feature Regression, Model Selection',
                             size=(350, -1), **self.titleopts), style=LEFT, dcol=4)

        add_text('Array to Use: ', newrow=True)
        panel.Add(wids['fitspace'], dcol=4)

        panel.Add(wids['fitspace_label'], newrow=True)
        panel.Add(self.elo_wids)
        add_text(' : ', newrow=False)
        panel.Add(self.ehi_wids, dcol=3)
        add_text('Regression Method:')
        panel.Add(wids['method'], dcol=4)
        add_text('PLS # components: ')
        panel.Add(w_ncomps)
        panel.Add(wids['auto_scale_pls'], dcol=2)
        add_text('Lasso Alpha: ')
        panel.Add(wids['alpha'])
        panel.Add(wids['auto_alpha'], dcol=2)
        panel.Add(wids['fit_intercept'])

        add_text('Cross Validation: ')
        add_text(' # folds, # repeats: ', newrow=False)
        panel.Add(w_cvfolds, dcol=2)
        panel.Add(w_cvreps)

        panel.Add(HLine(panel, size=(600, 2)), dcol=6, newrow=True)

        add_text('Build Model: ', newrow=True)
        panel.Add(wids['use_selected'], dcol=2)
        add_text('Attribute Name: ', newrow=False)
        panel.Add(wids['varname'], dcol=4)

        add_text('Read/Save Data: ', newrow=True)
        panel.Add(wids['load_csv'], dcol=3)
        panel.Add(wids['save_csv'], dcol=2)

        panel.Add(wids['table'], newrow=True, dcol=5)  # , drow=3)

        panel.Add(HLine(panel, size=(550, 2)), dcol=5, newrow=True)
        panel.Add((5, 5), newrow=True)
        add_text('Train Model : ')
        panel.Add(wids['train_model'], dcol=3)
        panel.Add(wids['load_model'])

        add_text('Use This Model : ')
        panel.Add(wids['fit_group'], dcol=3)
        panel.Add(wids['save_model'])
        add_text('Fit Statistics : ')
        panel.Add(wids['stat1'], dcol=4)
        panel.Add((5, 5), newrow=True)
        panel.Add(wids['stat2'], dcol=4)
        panel.pack()

        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add((10, 10), 0, LEFT, 3)
        sizer.Add(panel, 1, LEFT, 3)
        pack(self, sizer)
        # sync enabled/disabled widgets with the initially selected method
        self.onRegressMethod()
        self.skip_process = False
202 def onRegressMethod(self, evt=None):
203 meth = self.wids['method'].GetStringSelection()
204 use_lasso = meth.lower().startswith('lasso')
205 self.wids['alpha'].Enable(use_lasso)
206 self.wids['auto_alpha'].Enable(use_lasso)
207 self.wids['fit_intercept'].Enable(use_lasso)
208 self.wids['auto_scale_pls'].Enable(not use_lasso)
209 self.wids['ncomps'].Enable(not use_lasso)
211 def onFitSpace(self, evt=None):
212 fitspace = self.wids['fitspace'].GetStringSelection()
213 self.update_config(dict(fitspace=fitspace))
214 arrname = Linear_ArrayChoices.get(fitspace, 'norm')
215 self.update_fit_xspace(arrname)
218 def fill_form(self, dgroup=None, opts=None):
219 conf = deepcopy(self.get_config(dgroup=dgroup, with_erange=True))
220 if opts is None:
221 opts = {}
222 conf.update(opts)
223 self.dgroup = dgroup
224 self.skip_process = True
225 wids = self.wids
227 for attr in ('fitspace','method'):
228 if attr in conf:
229 wids[attr].SetStringSelection(conf[attr])
231 for attr in ('elo', 'ehi', 'alpha', 'varname', 'cv_folds', 'cv_repeats'):
232 val = conf.get(attr, None)
233 if val is not None:
234 if attr == 'alpha':
235 if val < 0:
236 val = 0.001
237 conf['auto_alpha'] = True
238 val = '%.6g' % val
239 if attr in wids:
240 wids[attr].SetValue(val)
242 use_lasso = conf['method'].lower().startswith('lasso')
244 for attr in ('auto_alpha', 'fit_intercept','auto_scale_pls'):
245 val = conf.get(attr, True)
246 if attr == 'auto_scale_pls':
247 val = val and not use_lasso
248 else:
249 val = val and use_lasso
250 wids[attr].SetValue(val)
251 self.onRegressMethod()
253 self.skip_process = False
255 def read_form(self):
256 dgroup = self.controller.get_group()
257 form = {'groupname': getattr(dgroup, 'groupname', 'No Group')}
259 for k in ('fitspace', 'method'):
260 form[k] = self.wids[k].GetStringSelection()
262 for k in ('elo', 'ehi', 'alpha', 'cv_folds',
263 'cv_repeats', 'ncomps', 'varname'):
264 form[k] = self.wids[k].GetValue()
266 form['alpha'] = float(form['alpha'])
267 if form['alpha'] < 0:
268 form['alpha'] = 1.e-3
270 for k in ('auto_scale_pls', 'auto_alpha', 'fit_intercept'):
271 form[k] = self.wids[k].IsChecked()
273 mname = form['method'].lower()
274 form['use_lars'] = 'lars' in mname
275 form['funcname'] = 'pls'
276 if mname.startswith('lasso'):
277 form['funcname'] = 'lasso'
278 if form['auto_alpha']:
279 form['alpha'] = None
281 return form
284 def onFillTable(self, event=None):
285 selected_groups = self.controller.filelist.GetCheckedStrings()
286 varname = fix_varname(self.wids['varname'].GetValue())
287 predname = varname + '_predicted'
288 grid_data = []
289 for fname in self.controller.filelist.GetCheckedStrings():
290 gname = self.controller.file_groups[fname]
291 grp = self.controller.get_group(gname)
292 grid_data.append([fname, getattr(grp, varname, 0.0),
293 getattr(grp, predname, 0.0), 'Yes'])
295 self.wids['table'].table.data = grid_data
296 self.wids['table'].table.View.Refresh()
298 def onTrainModel(self, event=None):
299 form = self.read_form()
300 self.update_config(form)
301 varname = form['varname']
302 predname = varname + '_predicted'
304 grid_data = self.wids['table'].table.data
305 groups = []
306 for fname, yval, pval, istrain in grid_data:
307 gname = self.controller.file_groups[fname]
308 grp = self.controller.get_group(gname)
309 setattr(grp, varname, yval)
310 setattr(grp, predname, pval)
311 groups.append(gname)
313 cmds = ['# train linear regression model',
314 'training_groups = [%s]' % ', '.join(groups)]
316 copts = ["varname='%s'" % varname, "xmin=%.4f" % form['elo'],
317 "xmax=%.4f" % form['ehi']]
319 arrname = Linear_ArrayChoices.get(form['fitspace'], 'norm')
320 copts.append("arrayname='%s'" % arrname)
322 if form['method'].lower().startswith('lasso'):
323 if form['auto_alpha']:
324 copts.append('alpha=None')
325 else:
326 copts.append('alpha=%.6g' % form['alpha'])
327 copts.append('use_lars=%s' % repr('lars' in form['method'].lower()))
328 copts.append('fit_intercept=%s' % repr(form['fit_intercept']))
329 else:
330 copts.append('ncomps=%d' % form['ncomps'])
331 copts.append('scale=%s' % repr(form['auto_scale_pls']))
333 callargs = ', '.join(copts)
335 cmds.append("reg_model = %s_train(training_groups, %s)" %
336 (form['funcname'], callargs))
338 self.larch_eval('\n'.join(cmds))
339 reg_model = self.larch_get('reg_model')
340 reg_model.form = form
341 self.use_regmodel(reg_model)
343 def use_regmodel(self, reg_model):
344 if reg_model is None:
345 return
346 opts = self.read_form()
348 if hasattr(reg_model, 'form'):
349 opts.update(reg_model.form)
351 self.write_message('Regression Model trained: %s' % opts['method'])
352 rmse_cv = reg_model.rmse_cv
353 if rmse_cv is not None:
354 rmse_cv = "%.4f" % rmse_cv
355 stat = "RMSE_CV = %s, RMSE = %.4f" % (rmse_cv, reg_model.rmse)
356 self.wids['stat1'].SetLabel(stat)
357 if opts['funcname'].startswith('lasso'):
358 stat = "Alpha = %.4f, %d active components"
359 self.wids['stat2'].SetLabel(stat % (reg_model.alpha,
360 len(reg_model.active)))
362 if opts['auto_alpha']:
363 self.wids['alpha'].add_choice(reg_model.alpha)
365 else:
366 self.wids['stat2'].SetLabel('- - - ')
367 training_groups = reg_model.groupnames
368 ntrain = len(training_groups)
369 grid_data = self.wids['table'].table.data
370 grid_new = []
371 for i in range(ntrain): # min(ntrain, len(grid_data))):
372 fname = training_groups[i]
373 istrain = 'Yes' if fname in training_groups else 'No'
374 grid_new.append( [fname, reg_model.ydat[i], reg_model.ypred[i], istrain])
375 self.wids['table'].table.data = grid_new
376 self.wids['table'].table.View.Refresh()
378 if reg_model.cv_folds not in (0, None):
379 self.wids['cv_folds'].SetValue(reg_model.cv_folds)
380 if reg_model.cv_repeats not in (0, None):
381 self.wids['cv_repeats'].SetValue(reg_model.cv_repeats)
383 self.wids['save_model'].Enable()
384 self.wids['fit_group'].Enable()
386 wx.CallAfter(self.onPlotModel, model=reg_model)
388 def onPanelExposed(self, **kws):
389 # called when notebook is selected
390 try:
391 fname = self.controller.filelist.GetStringSelection()
392 gname = self.controller.file_groups[fname]
393 dgroup = self.controller.get_group(gname)
394 self.ensure_xas_processed(dgroup)
395 self.fill_form(dgroup)
396 except:
397 pass # print(" Cannot Fill prepeak panel from group ")
399 reg_model = getattr(self.larch.symtable, 'reg_model', None)
400 if reg_model is not None:
401 self.use_regmodel(reg_model)
404 def onPredictGroups(self, event=None):
405 opts = self.read_form()
406 varname = opts['varname'] + '_predicted'
408 reg_model = self.larch_get('reg_model')
409 training_groups = reg_model.groupnames
411 grid_data = self.wids['table'].table.data
413 gent = {}
414 if len(grid_data[0][0].strip()) == 0:
415 grid_data = []
416 else:
417 for i, row in enumerate(grid_data):
418 gent[row[0]] = i
420 for fname in self.controller.filelist.GetCheckedStrings():
421 gname = self.controller.file_groups[fname]
422 grp = self.controller.get_group(gname)
423 extval = getattr(grp, opts['varname'], 0)
424 cmd = "%s.%s = %s_predict(%s, reg_model)" % (gname, varname,
425 opts['funcname'], gname)
426 self.larch_eval(cmd)
427 val = self.larch_get('%s.%s' % (gname, varname))
428 if fname in gent:
429 grid_data[gent[fname]][2] = val
430 else:
431 istrain = 'Yes' if fname in training_groups else 'No'
432 grid_data.append([fname, extval, val, istrain])
433 self.wids['table'].table.data = grid_data
434 self.wids['table'].table.View.Refresh()
436 def onSaveModel(self, event=None):
437 try:
438 reg_model = self.larch_get('reg_model')
439 except:
440 title = "No regresion model to save"
441 message = [f"Cannot get regression model to save"]
442 ExceptionPopup(self, title, message)
443 return
445 fname = FileSave(self, "Save Regression Model",
446 defaultDir=get_cwd(),
447 defaultFile=self.save_modelfile,
448 wildcard=MODEL_WILDCARDS)
450 if fname is None:
451 return
452 save_groups(fname, ['#regression model 1.0', reg_model])
453 self.write_message('Wrote Regression Model to %s ' % fname)
455 def onLoadModel(self, event=None):
456 fname = FileOpen(self, "Load Regression Model",
457 defaultDir=get_cwd(),
458 wildcard=MODEL_WILDCARDS)
460 if fname is None:
461 return
462 dat = read_groups(fname)
463 if len(dat) != 2 or not dat[0].startswith('#regression model'):
464 Popup(self, f" '{rfile}' is not a valid Regression model file",
465 "Invalid file")
467 reg_model = dat[1]
468 self.controller.symtable.reg_model = reg_model
470 self.write_message('Read Regression Model from %s ' % fname)
471 self.wids['fit_group'].Enable()
473 self.use_regmodel(reg_model)
475 def onLoadCSV(self, event=None):
476 fname = FileOpen(self, "Load CSV Data File",
477 defaultDir=get_cwd(),
478 wildcard=CSV_WILDCARDS)
479 if fname is None:
480 return
482 self.save_csvfile = Path(fname).name
483 varname = fix_varname(self.wids['varname'].GetValue())
484 csvgroup = read_csv(fname)
485 script = []
486 grid_data = []
487 for sname, yval in zip(csvgroup.col_01, csvgroup.col_02):
488 if sname.startswith('#'):
489 continue
490 if sname in self.controller.file_groups:
491 gname = self.controller.file_groups[sname]
492 script.append('%s.%s = %f' % (gname, varname, yval))
493 grid_data.append([sname, yval, 0])
495 self.larch_eval('\n'.join(script))
496 self.wids['table'].table.data = grid_data
497 self.wids['table'].table.View.Refresh()
498 self.write_message('Read CSV File %s ' % fname)
500 def onSaveCSV(self, event=None):
501 wildcard = 'CSV file (*.csv)|*.csv|All files (*.*)|*.*'
502 fname = FileSave(self, message='Save CSV Data File',
503 wildcard=wildcard,
504 default_file=self.save_csvfile)
505 if fname is None:
506 return
507 self.save_csvfile = Path(fname).name
508 buff = []
509 for row in self.wids['table'].table.data:
510 buff.append("%s, %s, %s" % (row[0], gformat(row[1]), gformat(row[2])))
511 buff.append('')
512 with open(fname, 'w', encoding=sys.getdefaultencoding()) as fh:
513 fh.write('\n'.join(buff))
514 self.write_message('Wrote CSV File %s ' % fname)
516 def onPlotModel(self, event=None, model=None):
517 opts = self.read_form()
518 if model is None:
519 return
520 opts.update(model.form)
522 ppanel = self.controller.get_display(win=1).panel
523 viewlims = ppanel.get_viewlimits()
524 plotcmd = ppanel.plot
526 d_ave = model.spectra.mean(axis=0)
527 d_std = model.spectra.std(axis=0)
528 ymin, ymax = (d_ave-d_std).min(), (d_ave+d_std).max()
530 if opts['funcname'].startswith('lasso'):
531 active = [int(i) for i in model.active]
532 active_coefs = (model.coefs[active])
533 active_coefs = active_coefs/max(abs(active_coefs))
534 ymin = min(active_coefs.min(), ymin)
535 ymax = max(active_coefs.max(), ymax)
537 else:
538 ymin = min(model.coefs.min(), ymin)
539 ymax = max(model.coefs.max(), ymax)
541 ymin = ymin - 0.02*(ymax-ymin)
542 ymax = ymax + 0.02*(ymax-ymin)
545 title = '%s Regression results' % (opts['method'])
547 ppanel.plot(model.x, d_ave, win=1, title=title,
548 label='mean spectra', xlabel='Energy (eV)',
549 ylabel=opts['fitspace'], show_legend=True,
550 ymin=ymin, ymax=ymax)
551 ppanel.axes.fill_between(model.x, d_ave-d_std, d_ave+d_std,
552 color='#1f77b433')
553 if opts['funcname'].startswith('lasso'):
554 ppanel.axes.bar(model.x[active], active_coefs,
555 1.0, color='#9f9f9f88',
556 label='coefficients')
557 else:
558 _, ncomps = model.coefs.shape
559 for i in range(ncomps):
560 ppanel.oplot(model.x, model.coefs[:, i], label='coef %d' % (i+1))
562 ppanel.canvas.draw()
564 ngoups = len(model.groupnames)
565 indices = np.arange(len(model.groupnames))
566 diff = model.ydat - model.ypred
567 sx = np.argsort(model.ydat)
569 ppanel = self.controller.get_display(win=2).panel
571 ppanel.plot(model.ydat[sx], indices, xlabel='valence',
572 label='experimental', linewidth=0, marker='o',
573 markersize=8, win=2, new=True, title=title)
575 ppanel.oplot(model.ypred[sx], indices, label='predicted',
576 labelxsxfontsize=7, markersize=6, marker='o',
577 linewidth=0, show_legend=True, new=False)
579 ppanel.axes.barh(indices, diff[sx], 0.5, color='#9f9f9f88')
580 ppanel.axes.set_yticks(indices)
581 ppanel.axes.set_yticklabels([model.groupnames[o] for o in sx])
582 ppanel.conf.auto_margins = False
583 ppanel.conf.set_margins(left=0.35, right=0.05, bottom=0.15, top=0.1)
584 ppanel.canvas.draw()
585 self.controller.set_focus()
588 def onCopyParam(self, name=None, evt=None):
589 conf = self.get_config()