Coverage for larch/io/athena_project.py: 9%

621 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-10-16 21:04 +0000

1#!/usr/bin/env python 

2""" 

3Code to read and write Athena Project files 

4 

5""" 

6import io 

7import sys 

8import time 

9import json 

10import platform 

11from pathlib import Path 

12from fnmatch import fnmatch 

13from gzip import GzipFile 

14from copy import deepcopy 

15import numpy as np 

16from numpy.random import randint 

17 

18from larch import Group, repr_value 

19from larch import __version__ as larch_version 

20from larch.utils import bytes2str, str2bytes, fix_varname, asfloat, unixpath 

21 

22from xraydb import guess_edge 

23import asteval 

24 

# Opening and closing delimiters of a Perl-style hexadecimal character
# escape such as \x{e9}.
hexopen = r'\x{'
hexclose = '}'

# Translation table that turns the text of a Perl list literal into
# JSON-like text: '(' -> '[', ')' -> ']', "'" -> '"', and both ';' and
# newline become a single space.
alist2json = str.maketrans({
    '(': '[',
    ')': ']',
    ';': ' ',
    "'": '"',
    '\n': ' ',
})

29 

def plarray2json(text):
    """Decode the right-hand side of a Perl array assignment with json.

    Everything after the first '=' in `text` is stripped, translated
    with the module-level `alist2json` table and handed to `json.loads`.
    """
    rhs = text.split('=', 1)[1]
    as_json = rhs.strip().translate(alist2json)
    return json.loads(as_json)

32 

def parse_arglist(text):
    """Decode a Perl argument-list assignment with json.

    Takes the text after the first '=', strips surrounding whitespace,
    drops one trailing ';' if present, translates the remainder with the
    module-level `alist2json` table and returns `json.loads` of the result.
    """
    rhs = text.split('=', 1)[1].strip()
    rhs = rhs[:-1] if rhs.endswith(';') else rhs
    return json.loads(rhs.translate(alist2json))

38 

39 

40 

41ERR_MSG = "Error reading Athena Project File" 

42 

43 

def _read_raw_athena(filename):
    """try to read athena project file as plain text,
    to determine validity

    The file is first read as a gzip stream; if that fails for any
    reason it is read again as an ordinary text file.

    Arguments:
        filename: name of the file to read

    Returns:
        the file contents converted with `bytes2str`, or None if neither
        attempt succeeded.
    """
    text = None

    # try gzip.  Any failure (missing file, not gzip data, ...) simply
    # means "fall through to the plain-text attempt", so the broad
    # except is deliberate here.
    try:
        with GzipFile(unixpath(filename)) as fh:
            text = bytes2str(fh.read())
    except Exception:
        text = None

    if text is None:
        # try plain text file
        try:
            with open(filename, 'r') as fh:
                text = bytes2str(fh.read())
        except Exception:
            text = None

    return text

67 

68 

69def _test_athena_text(text): 

70 return "Athena project file -- " in text[:500] 

71 

72 

def is_athena_project(filename):
    """tests whether file is a valid Athena Project file

    Returns False when the file cannot be read at all, otherwise the
    result of the marker test on its text.
    """
    content = _read_raw_athena(filename)
    return content is not None and _test_athena_text(content)

79 

80 

def make_hashkey(length=5):
    """generate an 'athena hash key': random lower-case letters

    Arguments:
        length (int): number of letters in the key [5]

    Returns:
        string of `length` characters drawn from 'a'..'z'
    """
    # numpy's randint excludes the upper bound, so 123 is required for
    # 'z' (code point 122) to be a possible outcome.
    return ''.join(chr(randint(97, 123)) for _ in range(length))

85 

def make_athena_args(group, hashkey=None, **kws):
    """make athena args line from a group

    Arguments:
        group: group providing `energy` and `e0`, and optionally
            `filename`, `edge_step`, `valence`, `lasso_yvalue`, `atsym`,
            `edge`, `energy_shift`, `pre_edge_details`, `autobk_details`
            and `xftf_details`.
        hashkey (str or None): value used for the 'datagroup', 'tag' and
            (initial) 'label' entries; a new key is generated when None.
        **kws: extra entries, applied last, overriding anything above.

    Returns:
        dict of Athena argument names to values.
    """
    from larch.xafs.xafsutils import etok

    if hashkey is None:
        hashkey = make_hashkey()

    # default args.  'bkg_rbkg' and 'bkg_slope' are listed once, with the
    # values that are effective after all defaults have been applied.
    args = {
        'annotation': '',
        'beamline': '',
        'beamline_identified': '0',
        'bft_dr': '0.0',
        'bft_rmax': '3',
        'bft_rmin': '1',
        'bft_rwindow': 'hanning',
        'bkg_algorithm': 'autobk',
        'bkg_cl': '0',
        'bkg_clamp1': '0',
        'bkg_clamp2': '24',
        'bkg_delta_eshift': '0',
        'bkg_dk': '1',
        'bkg_e0_fraction': '0.5',
        'bkg_eshift': '0',
        'bkg_fixstep': '0',
        'bkg_flatten': '1',
        'bkg_former_e0': '0',
        'bkg_funnorm': '0',
        'bkg_int': '7.',
        'bkg_kw': '1',
        'bkg_kwindow': 'hanning',
        'bkg_nclamp': '5',
        'bkg_rbkg': '1.0',
        'bkg_slope': '0',
        'bkg_stan': 'None',
        'bkg_tie_e0': '0',
        'bkg_nc0': '0',
        'bkg_nc1': '0',
        'bkg_nc2': '0',
        'bkg_nc3': '0',
        'bkg_pre1': '-150',
        'bkg_pre2': '-30',
        'bkg_nor1': '150',
        'bkg_nor2': '800',
        'bkg_nnorm': '1',
        'bkg_nvict': '0',
        'prjrecord': 'athena.prj, 1',
        'chi_column': '',
        'chi_string': '',
        'collided': '0',
        'columns': '',
        'daq': '',
        'denominator': '1',
        'display': '0',
        'energy': '',
        'energy_string': '',
        'epsk': '',
        'epsr': '',
        'fft_dk': '4',
        'fft_edge': 'k',
        'fft_kmax': '15.',
        'fft_kmin': '2.00',
        'fft_kwindow': 'kaiser-bessel',
        'fft_pc': '0',
        'fft_pcpathgroup': '',
        'fft_pctype': 'central',
        'forcekey': '0',
        'from_athena': '1',
        'from_yaml': '0',
        'frozen': '0',
        'generated': '0',
        'i0_scale': '1',
        'i0_string': '1',
        'importance': '1',
        'inv': '0',
        'is_col': '1',
        'is_fit': '0',
        'is_kev': '0',
        'is_merge': '',
        'is_nor': '0',
        'is_pixel': '0',
        'is_special': '0',
        'is_xmu': '1',
        'ln': '0',
        'mark': '0',
        'marked': '0',
        'maxk': '15',
        'merge_weight': '1',
        'multiplier': '1',
        'nidp': '5',
        'nknots': '4',
        'numerator': '',
        'plot_scale': '1',
        'plot_yoffset': '0',
        'plotkey': '',
        'plotspaces': 'any',
        'provenance': '',
        'quenched': '0',
        'quickmerge': '0',
        'read_as_raw': '0',
        'rebinned': '0',
        'recommended_kmax': '1',
        'recordtype': 'mu(E)',
        'referencegroup': '',
        'rmax_out': '10',
        'signal_scale': '1',
        'signal_string': '-1',
        'trouble': '',
        'tying': '0',
        'unreadable': '0',
        'update_bft': '1',
        'update_bkg': '1',
        'update_columns': '0',
        'update_data': '0',
        'update_fft': '1',
        'update_norm': '1',
        'xdi_will_be_cloned': '0',
        'xdifile': '',
        'xmu_string': '',
        'valence': '',
        'lasso_yvalue': '',
        'atsym': '',
        'edge': '',
    }

    for idkey in ('datagroup', 'tag', 'label'):
        args[idkey] = hashkey

    energy = getattr(group, 'energy', [])
    npts = len(energy)
    args['npts'] = npts
    if npts > 0:
        args['xmin'] = '%.1f' % min(energy)
        args['xmax'] = '%.1f' % max(energy)

    # athena argument name -> attribute name on the larch group
    main_map = {
        'source': 'filename',
        'file': 'filename',
        'label': 'filename',
        'bkg_e0': 'e0',
        'bkg_step': 'edge_step',
        'bkg_fitted_step': 'edge_step',
        'valence': 'valence',
        'lasso_yvalue': 'lasso_yvalue',
        'atsym': 'atsym',
        'edge': 'edge',
    }
    for athena_name, larch_name in main_map.items():
        value = getattr(group, larch_name, None)
        if value is not None:
            args[athena_name] = value

    # athena 'bkg_*' suffix -> attribute name on group.pre_edge_details
    bkg_map = {
        'nnorm': 'nnorm',
        'nor1': 'norm1',
        'nor2': 'norm2',
        'pre1': 'pre1',
        'pre2': 'pre2',
        'nvict': 'nvict',
    }
    if hasattr(group, 'pre_edge_details'):
        details = group.pre_edge_details
        for suffix, larch_name in bkg_map.items():
            value = getattr(details, larch_name, None)
            if value is not None:
                args['bkg_%s' % suffix] = value

    # spline range: from e0 up to the last energy point, in energy and
    # (via etok) in k
    emax = max(group.energy) - group.e0
    args['bkg_spl1e'] = '0'
    args['bkg_spl2e'] = '%.5f' % emax
    args['bkg_spl1'] = '0'
    args['bkg_spl2'] = '%.5f' % etok(emax)
    args['bkg_eshift'] = getattr(group, 'energy_shift', 0.0)

    # values recorded from a previous autobk call override the defaults
    autobk_args = getattr(getattr(group, 'autobk_details', None),
                          'call_args', None)
    if autobk_args is not None:
        for athena_name, call_name in (('bkg_rbkg', 'rbkg'),
                                       ('bkg_spl1', 'kmin'),
                                       ('bkg_spl2', 'kmax'),
                                       ('bkg_kw', 'kweight'),
                                       ('bkg_dk', 'dk'),
                                       ('bkg_kwindow', 'win'),
                                       ('bkg_nclamp', 'nclamp'),
                                       ('bkg_clamp1', 'clamp_lo'),
                                       ('bkg_clamp2', 'clamp_hi')):
            args[athena_name] = autobk_args[call_name]

    # likewise for a previous xftf call
    xftf_args = getattr(getattr(group, 'xftf_details', None),
                        'call_args', None)
    if xftf_args is not None:
        for athena_name, call_name in (('fft_kmin', 'kmin'),
                                       ('fft_kmax', 'kmax'),
                                       ('fft_kw', 'kweight'),
                                       ('fft_dk', 'dk'),
                                       ('fft_kwindow', 'window')):
            args[athena_name] = xftf_args[call_name]

    args.update(kws)
    return args

206 

207 

def athena_array(group, arrname):
    """Return attribute `arrname` of `group`, or None when it is missing.

    The array is handed back unchanged; conversion to Athena's text
    form happens when the project is written.
    """
    return getattr(group, arrname, None)

215 

216 

def format_dict(d):
    """Format a dictionary for an Athena Project file.

    Keys are emitted in sorted order as a flat, comma-separated sequence
    of single-quoted key and value strings; a value of None is written
    as an empty string.
    """
    pieces = []
    for key in sorted(d):
        value = d[key]
        if value is None:
            value = ''
        pieces.extend(("'%s'" % key, "'%s'" % value))
    return ','.join(pieces)

226 

def format_array(arr):
    """Format an array for an Athena Project file.

    Each element is written as a single-quoted string and the elements
    are joined with commas.
    """
    return ','.join("'%s'" % item for item in arr)

231 

def clean_bkg_params(grp):
    """Fill in missing background parameters on `grp`, in place.

    Attributes that already exist keep their value; missing ones get
    the defaults listed below.  `kwindow` falls back to `win` and then
    to 'hanning' when it is missing or None.  `clamp1` and `clamp2` are
    converted to float, or set to 1 when that is not possible.

    Returns:
        the same `grp` object.
    """
    defaults = (('nnorm', 2), ('nvict', 0), ('e0', -1), ('rbkg', 1),
                ('pre1', -150), ('pre2', -25), ('nor1', 100),
                ('nor2', 1200), ('spl1', 0), ('spl2', 30),
                ('kw', 1), ('dk', 3), ('flatten', 0))
    for attr, fallback in defaults:
        setattr(grp, attr, getattr(grp, attr, fallback))

    if getattr(grp, 'kwindow', None) is None:
        grp.kwindow = getattr(grp, 'win', 'hanning')

    for attr in ('clamp1', 'clamp2'):
        try:
            setattr(grp, attr, float(getattr(grp, attr)))
        except Exception:
            setattr(grp, attr, 1)

    return grp

259 

260 

def clean_fft_params(grp):
    """Fill in missing Fourier-transform parameters on `grp`, in place.

    Existing attributes are left as they are; missing ones are set to
    kmin=0, kmax=25, kweight=2, dk=3, kwindow='hanning'.

    Returns:
        the same `grp` object.
    """
    defaults = (('kmin', 0), ('kmax', 25), ('kweight', 2),
                ('dk', 3), ('kwindow', 'hanning'))
    for attr, fallback in defaults:
        setattr(grp, attr, getattr(grp, attr, fallback))
    return grp

268 

269 

def text2list(text):
    """Turn the right-hand side of a Perl assignment into evaluable text.

    The text after the first '=' is stripped, one trailing ';' is
    removed, '=>' is replaced by ':', and newlines, carriage returns and
    tabs become spaces.  Perl-style hexadecimal character escapes such
    as ``\\x{e9}`` are replaced by the character they name.

    Arguments:
        text (str): one assignment line, e.g. ``@x = ('1','2');``

    Returns:
        str: the cleaned-up right-hand side.

    Raises:
        ValueError: if `text` contains no '='.
    """
    import re

    _, txt = [part.strip() for part in text.split('=', 1)]
    if txt.endswith(';'):
        txt = txt[:-1]
    txt = (txt.replace('=>', ':').replace('\n', ' ')
              .replace('\r', ' ').replace('\t', ' '))

    def _decode(match):
        # keep the escape untouched if it is beyond the Unicode range
        codepoint = int(match.group(1), 16)
        if codepoint > 0x10FFFF:
            return match.group(0)
        return chr(codepoint)

    # re-cast unicode stored by perl (\x{e9} -> chr(0xe9)); escapes whose
    # content is not 1-6 hex digits are left as they are
    return re.sub(r'\\x\{([0-9a-fA-F]{1,6})\}', _decode, txt)

290 

291 

def parse_perlathena(text, filename):
    """
    parse old (Perl-style) athena file format to Group of Groups

    Each assignment line is cleaned with `text2list` and evaluated with
    an asteval interpreter (created with minimal=True).

    Arguments:
        text (str): full text of the project file
        filename: name of the file, used in messages and docstrings

    Returns:
        Group with attributes `journal`, `header`, `group_names` and one
        sub-group per record found in the file.

    Raises:
        ValueError: if the first line is not an Athena header line or
            its version string cannot be split into three parts.
    """
    aout = io.StringIO()
    aeval = asteval.Interpreter(minimal=True, writer=aout, err_writer=aout,
                                max_statement_length=12543000)

    lines = text.split('\n')
    athenagroups = []
    raw = {'name': ''}
    vline = lines.pop(0)
    if "Athena project file -- " not in vline:
        raise ValueError("%s '%s': invalid Athena File" % (ERR_MSG, filename))

    # the version is only validated (it must look like 'a.b.c')
    program = 'Demeter' if 'Demeter' in vline else 'Athena'
    try:
        vs = vline.split("Athena project file -- %s version" % program)[1]
        _major, _minor, _fix = vs.split('.')
    except (IndexError, ValueError):
        raise ValueError("%s '%s': cannot read version" % (ERR_MSG, filename))

    header = [vline]
    journal = ['']
    is_header = True
    for t in lines:
        if t.startswith('#') or len(t) < 2 or 'undef' in t:
            if is_header:
                header.append(t)
            continue
        is_header = False
        words = t.split()
        if not words:
            # whitespace-only line: nothing to parse
            continue
        key = words[0].strip()
        key = key.replace('$', '').replace('@', '').replace('%', '').strip()
        if key == 'old_group':
            raw['name'] = aeval(text2list(t))
        elif key == '[record]':
            athenagroups.append(raw)
            raw = {'name': ''}
        elif key == 'journal':
            try:
                journal = aeval(text2list(t))
            except ValueError:
                pass
            if len(aeval.error) > 0:
                # plain '{filename}' so that Path objects format too
                print(f" warning: may not read journal from '{filename}' completely")
                journal = text2list(t)
        elif key == 'args':
            raw['args'] = aeval(text2list(t))
        elif key == 'xdi':
            raw['xdi'] = t
        elif key in ('x', 'y', 'i0', 'signal', 'stddev'):
            raw[key] = np.array([float(x) for x in aeval(text2list(t))])
        elif key in ('1;', 'indicator', 'lcf_data', 'plot_features'):
            pass
        else:
            print(" do not know what to do with key '%s' at '%s'" % (key, raw['name']))

    # the journal may be the raw text (fallback above) or None; wrap it
    # so that the join below works on lines, not on single characters
    if journal is None:
        journal = ['']
    elif isinstance(journal, str):
        journal = [journal]

    out = Group()
    out.__doc__ = """XAFS Data from Athena Project File %s""" % (filename)
    out.journal = '\n'.join(journal)
    out.group_names = []
    out.header = '\n'.join(header)
    for dat in athenagroups:
        label = dat.get('name', 'unknown')
        this = Group(energy=dat['x'], mu=dat['y'],
                     athena_params=Group(id=label, bkg=Group(), fft=Group()))

        for arrname in ('i0', 'signal', 'stddev'):
            if arrname in dat:
                setattr(this, arrname, dat[arrname])

        if 'args' in dat:
            # args is a flat sequence: key, value, key, value, ...
            arglist = dat['args']
            for i in range(len(arglist)//2):
                key, val = arglist[2*i], arglist[2*i+1]
                if key.startswith('bkg_'):
                    setattr(this.athena_params.bkg, key[4:], asfloat(val))
                elif key.startswith('fft_'):
                    setattr(this.athena_params.fft, key[4:], asfloat(val))
                elif key == 'label':
                    label = this.label = val
                elif key in ('valence', 'lasso_yvalue', 'epsk', 'epsr'):
                    setattr(this, key, asfloat(val))
                elif key in ('atsym', 'edge'):
                    setattr(this, key, val)
                else:
                    setattr(this.athena_params, key, asfloat(val))
        this.__doc__ = """Athena Group Name %s (key='%s')""" % (label, dat['name'])
        if label.startswith(' '):
            label = 'd_' + label.strip()
        name = fix_varname(label)
        if name.startswith('_'):
            name = 'd' + name
        setattr(out, name, this)
        out.group_names.append(name)
    return out

399 

400 

def parse_perlathena_old(text, filename):
    """
    parse old athena file format to Group of Groups

    This variant decodes each assignment with `plarray2json` /
    `parse_arglist` (json based) rather than evaluating it.

    Arguments:
        text (str): full text of the project file
        filename: name of the file, used in messages and docstrings

    Returns:
        Group with attributes `journal`, `header`, `group_names` and one
        sub-group per record found in the file.

    Raises:
        ValueError: if the first line is not an Athena header line or
            its version string cannot be split into three parts.
    """
    lines = text.split('\n')
    athenagroups = []
    raw = {'name': ''}
    vline = lines.pop(0)
    if "Athena project file -- " not in vline:
        raise ValueError("%s '%s': invalid Athena File" % (ERR_MSG, filename))

    # the version is only validated (it must look like 'a.b.c')
    program = 'Demeter' if 'Demeter' in vline else 'Athena'
    try:
        vs = vline.split("Athena project file -- %s version" % program)[1]
        _major, _minor, _fix = vs.split('.')
    except (IndexError, ValueError):
        raise ValueError("%s '%s': cannot read version" % (ERR_MSG, filename))

    header = [vline]
    journal = ['']
    is_header = True
    for t in lines:
        if t.startswith('#') or len(t) < 2 or 'undef' in t:
            if is_header:
                header.append(t)
            continue
        is_header = False
        words = t.split()
        if not words:
            # whitespace-only line: nothing to parse
            continue
        key = words[0].strip()
        key = key.replace('$', '').replace('@', '').replace('%', '').strip()
        if key == 'old_group':
            raw['name'] = plarray2json(t)
        elif key == '[record]':
            athenagroups.append(raw)
            raw = {'name': ''}
        elif key == 'journal':
            journal = parse_arglist(t)
        elif key == 'args':
            raw['args'] = parse_arglist(t)
        elif key == 'xdi':
            raw['xdi'] = t
        elif key in ('x', 'y', 'i0', 'signal', 'stddev'):
            raw[key] = np.array([float(x) for x in plarray2json(t)])
        elif key in ('1;', 'indicator', 'lcf_data', 'plot_features'):
            pass
        else:
            print(" do not know what to do with key '%s' at '%s'" % (key, raw['name']))

    out = Group()
    out.__doc__ = """XAFS Data from Athena Project File %s""" % (filename)
    out.journal = '\n'.join(journal)
    out.group_names = []
    out.header = '\n'.join(header)
    for dat in athenagroups:
        label = dat.get('name', 'unknown')
        this = Group(energy=dat['x'], mu=dat['y'],
                     athena_params=Group(id=label, bkg=Group(), fft=Group()))

        for arrname in ('i0', 'signal', 'stddev'):
            if arrname in dat:
                setattr(this, arrname, dat[arrname])

        if 'args' in dat:
            # args is a flat sequence: key, value, key, value, ...
            arglist = dat['args']
            for i in range(len(arglist)//2):
                key, val = arglist[2*i], arglist[2*i+1]
                if key.startswith('bkg_'):
                    setattr(this.athena_params.bkg, key[4:], asfloat(val))
                elif key.startswith('fft_'):
                    setattr(this.athena_params.fft, key[4:], asfloat(val))
                elif key == 'label':
                    label = this.label = val
                elif key in ('valence', 'lasso_yvalue', 'epsk', 'epsr'):
                    setattr(this, key, asfloat(val))
                elif key in ('atsym', 'edge'):
                    setattr(this, key, val)
                else:
                    setattr(this.athena_params, key, asfloat(val))
        this.__doc__ = """Athena Group Name %s (key='%s')""" % (label, dat['name'])
        name = fix_varname(label)
        if name.startswith('_'):
            name = 'd' + name
        setattr(out, name, this)
        out.group_names.append(name)

    return out

494 

495 

def parse_jsonathena(text, filename):
    """parse a JSON-style athena file

    Arguments:
        text (str): JSON text of the project file
        filename: name of the file, used in the docstring of the result

    Returns:
        Group with attributes `journal`, `header`, `group_names` and one
        sub-group for each name listed under the '_____order...' key.
    """
    jsdict = json.loads(text)

    out = Group()
    out.__doc__ = """XAFS Data from Athena Project File %s""" % (filename)

    header = []
    athena_names = []
    # default used when the file has no '_____journ...' entry
    journal = ''
    for key, val in jsdict.items():
        if key.startswith('_____head'):
            header.append(val)
        elif key.startswith('_____journ'):
            journal = val
        elif key.startswith('_____order'):
            athena_names = val

    out.journal = journal
    out.header = '\n'.join(header)
    out.group_names = []
    for name in athena_names:
        label = name
        dat = jsdict[name]
        x = np.array(dat['x'], dtype='float64')
        y = np.array(dat['y'], dtype='float64')
        this = Group(energy=x, mu=y,
                     athena_params=Group(id=name, bkg=Group(), fft=Group()))

        for arrname in ('i0', 'signal', 'stddev'):
            if arrname in dat:
                setattr(this, arrname, np.array(dat[arrname], dtype='float64'))

        if 'args' in dat:
            for key, val in dat['args'].items():
                if key.startswith('bkg_'):
                    setattr(this.athena_params.bkg, key[4:], asfloat(val))
                elif key.startswith('fft_'):
                    setattr(this.athena_params.fft, key[4:], asfloat(val))
                elif key == 'label':
                    label = this.label = val
                elif key in ('valence', 'lasso_yvalue', 'epsk', 'epsr'):
                    setattr(this, key, asfloat(val))
                elif key in ('atsym', 'edge'):
                    setattr(this, key, val)
                else:
                    setattr(this.athena_params, key, asfloat(val))
        this.__doc__ = """Athena Group Name %s (key='%s')""" % (label, name)
        # attribute name on the output group, derived from the label
        gname = fix_varname(label)
        if gname.startswith('_'):
            gname = 'd' + gname
        setattr(out, gname, this)
        out.group_names.append(gname)
    return out

551 

552 

class AthenaGroup(Group):
    """A special Group for handling datasets loaded from Athena project files

    Besides normal attribute access, sub-groups can be read and set with
    string keys (``grp['name']``).  `keys()`, `values()`, `items()` and
    the HTML representation work on the `groups` attribute, which must
    be a dict of name -> group.
    """

    def __init__(self, **kws):
        """Constructor

        All keyword arguments are forwarded unchanged to `Group`.
        """
        super().__init__(**kws)

    def _repr_html_(self):
        """HTML representation for Jupyter notebook

        One table row per entry of `self.groups`, showing its name, its
        `filename` (or `label`, or 'unknown') and a check mark when its
        `sel` attribute equals 1.
        """
        from html import escape

        html = ["<table><tr><td><b>Group Name</b></td>",
                "<td><b>Label/File Name</b></td>",
                "<td><b>Selected</b></td></tr>"]
        for name, grp in self.groups.items():
            sel = "\u2714" if getattr(grp, 'sel', None) == 1 else ""
            fname = getattr(grp, 'filename', getattr(grp, 'label', 'unknown'))
            # names and labels come from the project file: escape them so
            # that characters such as '<' or '&' cannot break the table
            html.append("<tr><td>%s</td><td>%s</td><td>%s</td></tr>"
                        % (escape(str(name)), escape(str(fname)), sel))
        html.append("</table>")
        return '\n'.join(html)

    def __getitem__(self, key):
        """Return attribute `key`; integer keys raise IndexError."""
        if isinstance(key, int):
            raise IndexError("AthenaGroup does not support integer indexing")

        return getattr(self, key)

    def __setitem__(self, key, value):
        """Set attribute `key` to `value`; integer keys raise IndexError."""
        if isinstance(key, int):
            raise IndexError("AthenaGroup does not support integer indexing")

        return setattr(self, key, value)

    def keys(self):
        """Return the names in `self.groups` as a list."""
        return list(self.groups.keys())

    def values(self):
        """Return the groups in `self.groups` as a list."""
        return list(self.groups.values())

    def items(self):
        """Return the (name, group) pairs of `self.groups` as a list."""
        return list(self.groups.items())

606 

class AthenaProject(object):
    """read and write Athena Project files, mapping to Larch group
    containing sub-groups for each spectra / record

    note that two generations of Project files are supported for reading:

       1. Perl save file (custom format?)
       2. JSON format

    In addition, project files may be Gzipped or not.

    Files are written by `save()` in the Perl-style text format,
    Gzipped by default.
    """

    def __init__(self, filename=None):
        """Create a project; if `filename` names an existing Athena
        project file it is read immediately."""
        self.groups = {}
        self.group_names = []
        self.header = None
        self.journal = None
        self.filename = filename
        if filename is not None:
            if Path(filename).exists() and is_athena_project(filename):
                self.read(filename)

    def add_group(self, group, signal=None):
        """add Larch group (presumably XAFS data) to Athena project

        Arguments:
            group: group with `energy` (or `energy_orig`) and one of
                `mu`, `mutrans`, `mufluor`.
            signal (str or None): name of the attribute of `group` to
                store as signal; by default `i1` or `itrans` is used for
                `mu` / `mutrans` data when present.

        Raises:
            ValueError: if no energy or no mu-like array is found.

        Note that `group` is modified: `pre_edge` is run on it when it
        has no `e0` / `edge_step`, and the attributes `args`, `x`, `y`,
        `i0`, `signal` and `sel` are set.
        """
        from larch.xafs import pre_edge

        x = athena_array(group, 'energy')
        if hasattr(group, 'energy_orig'):
            x = athena_array(group, 'energy_orig')
        yname = None
        for _name in ('mu', 'mutrans', 'mufluor'):
            if hasattr(group, _name):
                yname = _name
                break
        if x is None or yname is None:
            raise ValueError("can only add XAFS data to Athena project")

        y = athena_array(group, yname)
        i0 = athena_array(group, 'i0')
        if signal is not None:
            signal = athena_array(group, signal)
        elif yname in ('mu', 'mutrans'):
            sname = None
            for _name in ('i1', 'itrans'):
                if hasattr(group, _name):
                    sname = _name
                    break
            if sname is not None:
                signal = athena_array(group, sname)

        # every group in the project needs a unique key
        hashkey = getattr(group, 'id', None)
        while hashkey is None or hashkey in self.groups:
            hashkey = make_hashkey()

        # fill in data from pre-edge subtraction
        if not (hasattr(group, 'e0') and hasattr(group, 'edge_step')):
            pre_edge(group)
        group.args = make_athena_args(group, hashkey)

        # fix parameters that are incompatible with athena
        group.args['bkg_nnorm'] = max(0, min(3, int(group.args['bkg_nnorm'])))

        _elem, _edge = guess_edge(group.e0)
        group.args['bkg_z'] = _elem
        group.x = x
        group.y = y
        group.i0 = i0
        group.signal = signal

        # add a selection flag
        group.sel = 1
        self.groups[hashkey] = group

    def save(self, filename=None, use_gzip=True):
        """write the project to a file

        Arguments:
            filename: name of the output file; defaults to (and replaces)
                `self.filename`.
            use_gzip (bool): whether to Gzip the output [True]

        Raises:
            ValueError: if no filename is available.

        Only groups that have an `args` attribute (as set by
        `add_group`) are written.
        """
        if filename is not None:
            self.filename = filename
        if self.filename is None:
            raise ValueError("%s: no filename given" % "cannot save Athena Project")

        iso_now = time.strftime('%Y-%m-%dT%H:%M:%S')
        pyosversion = "Python %s on %s" % (platform.python_version(),
                                           platform.platform())

        buff = ["# Athena project file -- Demeter version 0.9.24",
                "# This file created at %s" % iso_now,
                "# Using Larch version %s, %s" % (larch_version, pyosversion)]

        for key, dat in self.groups.items():
            if not hasattr(dat, 'args'):
                continue
            buff.append("")
            groupname = getattr(dat, 'groupname', key)

            buff.append("$old_group = '%s';" % groupname)
            buff.append("@args = (%s);" % format_dict(dat.args))
            buff.append("@x = (%s);" % format_array(dat.x))
            buff.append("@y = (%s);" % format_array(dat.y))
            for arrname in ('i0', 'signal', 'stddev'):
                arr = getattr(dat, arrname, None)
                if arr is not None:
                    buff.append("@%s = (%s);" % (arrname, format_array(arr)))
            buff.append("[record] # ")

        buff.extend(["", "@journal = {};", "", "1;", "", "",
                     "# Local Variables:", "# truncate-lines: t",
                     "# End:", ""])

        # the payload is bytes, so the plain (non-gzip) file must be
        # opened in binary mode as well
        fopen = GzipFile if use_gzip else open
        with fopen(unixpath(self.filename), 'wb') as fh:
            fh.write(str2bytes("\n".join([bytes2str(t) for t in buff])))

    def read(self, filename=None, match=None, do_preedge=True, do_bkg=False,
             do_fft=False, use_hashkey=False):
        """
        read Athena project to group of groups, one for each Athena dataset
        in the project file.  This supports both gzipped and unzipped files
        and old-style perl-like project files and new-style JSON project files

        Arguments:
            filename (string): name of Athena Project file; defaults to
                               (and replaces) `self.filename`
            match (string): pattern to use to limit imported groups (see Note 1)
            do_preedge (bool): whether to do pre-edge subtraction [True]
            do_bkg (bool): whether to do XAFS background subtraction [False]
            do_fft (bool): whether to do XAFS Fast Fourier transform [False]
            use_hashkey (bool): whether to use Athena's hash key as the
                           group name instead of the Athena label [False]
        Returns:
            None, fills in attributes `header`, `journal`, `filename`, `groups`

        Raises:
            ValueError: no filename available, or not an Athena file
            IOError / OSError: file not found or not readable

        Notes:
            1. To limit the imported groups, use the pattern in `match`,
               using '*' to match 'all', '?' to match any single character,
               or [sequence] to match any of a sequence of letters.  Matching
               is insensitive to case, and done with Python's fnmatch module.
            2. do_preedge, do_bkg, and do_fft will attempt to reproduce the
               pre-edge, background subtraction, and FFT from Athena by using
               the parameters saved in the project file.  The FFT is only
               done when the background subtraction is done.
            3. use_hashkey=True will name groups from the internal 5 character
               string used by Athena, instead of the group label.

        Example:
            1. read in all groups from a project file:
               cr_data = read_athena('My Cr Project.prj')

            2. read in only the "merged" data from a Project, do BKG and FFT:
               zn_data = read_athena('Zn on Stuff.prj', match='*merge*', do_bkg=True, do_fft=True)
        """
        if filename is not None:
            self.filename = filename
        if self.filename is None:
            raise ValueError("%s: no filename given" % "cannot read Athena Project")
        # from here on always use the stored name, so that read() with
        # no argument works for a project created with a filename
        fname = self.filename
        if not Path(fname).exists():
            raise IOError("%s '%s': cannot find file" % (ERR_MSG, fname))

        from larch.xafs import pre_edge, autobk, xftf

        text = _read_raw_athena(fname)
        # failed to read:
        if text is None:
            raise OSError("failed to read '%s'" % fname)
        if not _test_athena_text(text):
            raise ValueError("%s '%s': invalid Athena File" % (ERR_MSG, fname))

        # decode JSON or Perl format
        data = None
        if '____header' in text[:500]:
            # best effort: a file that fails to parse as JSON is handed
            # to the Perl-format parser below
            try:
                data = parse_jsonathena(text, fname)
            except Exception:
                pass

        if data is None:
            data = parse_perlathena(text, fname)

        if data is None:
            raise ValueError("cannot read file '%s' as Athena Project File" % (fname))

        self.header = data.header
        self.journal = data.journal
        self.group_names = data.group_names

        if match is not None:
            match = match.lower()
        for gname in data.group_names:
            oname = gname
            if match is not None:
                if not fnmatch(gname.lower(), match):
                    continue
            this = getattr(data, gname)

            this.athena_id = this.athena_params.id
            if use_hashkey:
                oname = this.athena_params.id
            is_xmu = bool(int(getattr(this.athena_params, 'is_xmu', 1.0)))
            is_chi = bool(int(getattr(this.athena_params, 'is_chi', 0.0)))
            is_xmu = is_xmu and not is_chi
            for aname in ('is_xmudat', 'is_bkg', 'is_diff',
                          'is_proj', 'is_pixel', 'is_rsp'):
                val = bool(int(getattr(this.athena_params, aname, 0.0)))
                is_xmu = is_xmu and not val

            if is_xmu and (do_preedge or do_bkg):
                pars = clean_bkg_params(this.athena_params.bkg)
                this.energy_shift = getattr(this.athena_params.bkg, 'eshift', 0.)
                pre_edge(this, e0=float(pars.e0),
                         pre1=float(pars.pre1), pre2=float(pars.pre2),
                         norm1=float(pars.nor1), norm2=float(pars.nor2),
                         nnorm=float(pars.nnorm),
                         nvict=float(pars.nvict),
                         make_flat=bool(pars.flatten))
                if do_bkg and hasattr(pars, 'rbkg'):
                    autobk(this, e0=float(pars.e0), rbkg=float(pars.rbkg),
                           kmin=float(pars.spl1), kmax=float(pars.spl2),
                           kweight=float(pars.kw), dk=float(pars.dk),
                           clamp_lo=float(pars.clamp1),
                           clamp_hi=float(pars.clamp2))
                    if do_fft:
                        pars = clean_fft_params(this.athena_params.fft)
                        kweight = 2
                        if hasattr(pars, 'kw'):
                            kweight = float(pars.kw)
                        xftf(this, kmin=float(pars.kmin),
                             kmax=float(pars.kmax), kweight=kweight,
                             window=pars.kwindow, dk=float(pars.dk))
            if is_chi:
                # the record holds chi(k): rename the arrays
                this.k = this.energy*1.0
                this.chi = this.mu*1.0
                del this.energy
                del this.mu

            # add a selection flag and XAS datatypes, as used by Larix
            this.sel = 1
            # NOTE(review): 'xas' is kept for chi(k) records as well;
            # confirm whether a different datatype is expected for them
            this.datatype = 'xas'
            this.filename = getattr(this, 'label', 'unknown')
            if is_chi:
                # energy/mu no longer exist for chi(k) records
                this.xdat = 1.0*this.k
                this.ydat = 1.0*this.chi
                this.plot_xlabel = 'k'
                this.plot_ylabel = 'chi'
            else:
                this.xdat = 1.0*this.energy
                this.ydat = 1.0*this.mu
                this.plot_xlabel = 'energy'
                this.plot_ylabel = 'mu'
            this.yerr = 1.0
            self.groups[oname] = this

    def as_group(self):
        """convert AthenaProject to Larch group

        Returns an `AthenaGroup` holding `filename`, `journal`, `header`,
        a `groups` dict, and each group also as an attribute.
        """
        out = AthenaGroup()
        out.__doc__ = """XAFS Data from Athena Project File %s""" % (self.filename)
        out.filename = self.filename
        out.journal = self.journal
        out.header = self.header
        out.groups = {}

        for name, group in self.groups.items():
            out.groups[name] = group
            setattr(out, name, group)
        return out

    def as_dict(self):
        """convert AthenaProject to a nested dictionary

        Returns a dict with keys '_doc', 'filename', 'journal', 'header'
        and 'groups'.  Each entry of 'groups' is a deep copy of the
        group's attributes in which every attribute that is itself a
        Group is moved into the '_params' sub-dictionary (with any
        '_params' suffix removed from its name).  The groups of the
        project are not modified.
        """
        out = dict()
        out["_doc"] = """XAFS Data from Athena Project File %s""" % (self.filename)
        out["filename"] = self.filename  # str
        out["journal"] = self.journal  # str
        out["header"] = self.header  # str
        out["groups"] = dict()

        par_key = "_params"
        for name, group in self.groups.items():
            # work on a copy so that the live group keeps its __name__
            gdict = dict(group.__dict__)
            gdict.pop("__name__", None)
            gout = deepcopy(gdict)
            gout[par_key] = dict()
            for subname, subgroup in gdict.items():
                if isinstance(subgroup, Group):
                    subdict = gout.pop(subname).__dict__
                    subdict.pop("__name__", None)
                    # group all parameters in a common dictionary
                    par_name = subname.split(par_key)[0]
                    gout[par_key][par_name] = subdict
            out["groups"][name] = gout

        return out

890 

891 

def read_athena(filename, match=None, do_preedge=True, do_bkg=False,
                do_fft=False, use_hashkey=False):
    """read athena project file
    returns a Group of Groups, one for each Athena Group in the project file

    Arguments:
        filename (string): name of Athena Project file
        match (string): pattern to use to limit imported groups (see Note 1)
        do_preedge (bool): whether to do pre-edge subtraction [True]
        do_bkg (bool): whether to do XAFS background subtraction [False]
        do_fft (bool): whether to do XAFS Fast Fourier transform [False]
        use_hashkey (bool): whether to use Athena's hash key as the
                       group name instead of the Athena label [False]

    Returns:
        group of groups each named according the label used by Athena.

    Raises:
        IOError: if `filename` does not exist.

    Notes:
        1. To limit the imported groups, use the pattern in `match`,
           using '*' to match 'all', '?' to match any single character,
           or [sequence] to match any of a sequence of letters.  The match
           will always be insensitive to case.
        2. do_preedge, do_bkg, and do_fft will attempt to reproduce the
           pre-edge, background subtraction, and FFT from Athena by using
           the parameters saved in the project file.
        3. use_hashkey=True will name groups from the internal 5 character
           string used by Athena, instead of the group label.

    Example:
        1. read in all groups from a project file:
           cr_data = read_athena('My Cr Project.prj')

        2. read in only the "merged" data from a Project, and do BKG and FFT:
           zn_data = read_athena('Zn on Stuff.prj', match='*merge*', do_bkg=True, do_fft=True)

    """
    if not Path(filename).exists():
        raise IOError("%s '%s': cannot find file" % (ERR_MSG, filename))

    read_options = dict(match=match, do_preedge=do_preedge, do_bkg=do_bkg,
                        do_fft=do_fft, use_hashkey=use_hashkey)
    project = AthenaProject()
    project.read(filename, **read_options)
    return project.as_group()

936 

937 

def create_athena(filename=None):
    """Create an `AthenaProject` for `filename` and return it."""
    project = AthenaProject(filename=filename)
    return project

941 

942 

def extract_athenagroup(dgroup):
    '''deprecated -- no longer needed (extract xas group from athena group)

    Sets `datatype`, `filename`, `xdat`, `ydat`, `yerr`, `plot_xlabel`
    and `plot_ylabel` on `dgroup` from its `label`, `energy` and `mu`,
    and returns the same object.
    '''
    dgroup.datatype = 'xas'
    dgroup.filename = getattr(dgroup, 'label', 'unknown')
    # xdat / ydat are new values computed from energy / mu
    for target, source in (('xdat', 'energy'), ('ydat', 'mu')):
        setattr(dgroup, target, 1.0*getattr(dgroup, source))
    dgroup.yerr = 1.0
    dgroup.plot_xlabel, dgroup.plot_ylabel = 'energy', 'mu'
    return dgroup

953#enddef