Coverage for larch/larchlib.py: 64%

436 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-10-16 21:04 +0000

1#!/usr/bin/env python 

2""" 

3Helper classes for larch interpreter 

4""" 

5import sys, os, time 

6from datetime import datetime 

7import ast 

8import numpy as np 

9import traceback 

10import toml 

11import inspect 

12from collections import namedtuple 

13from pathlib import Path 

14 

15import ctypes 

16import ctypes.util 

17 

18from .symboltable import Group, isgroup 

19from .site_config import user_larchdir 

20from .closure import Closure 

21from .utils import uname, bindir, get_cwd, read_textfile 

22 

# Module-level flag: True only when termcolor's `colored` could be imported
# and color output has not been disabled for this platform.
HAS_TERMCOLOR = False
try:
    from termcolor import colored
    if uname == 'win':
        # HACK (hopefully temporary):
        # disable color output for Windows command terminal
        # because it interferes with wx event loop.
        # NOTE(review): the module named below is presumably not installed
        # anywhere, so this import raises ImportError and the handler at the
        # bottom leaves HAS_TERMCOLOR False -- confirm this is the intent.
        import CannotUseTermcolorOnWindowsWithWx
        # os.environ.pop('TERM')
        # import colorama
        # colorama.init()
    HAS_TERMCOLOR = True
except ImportError:
    # either termcolor is missing or the Windows hack above fired
    HAS_TERMCOLOR = False

37 

38 

class Empty:
    """Placeholder type whose instances are always falsy.

    Used to build unique sentinel objects that can be told apart from
    ``None`` by identity while still evaluating as False.
    """
    def __bool__(self):
        # Python 3 truth-value hook
        return False

    # Python 2 spelling of the same hook, kept for backward compatibility
    __nonzero__ = __bool__


# holder for 'returned None' from Larch procedure
ReturnedNone = Empty()

44 

def get_filetext(fname, lineno):
    """try to extract line from source text file

    Arguments
    ---------
    fname:   name of text file to read
    lineno:  1-based line number; values past the end of the file are
             clamped to the last line

    Returns
    -------
    text of the requested line without its trailing newline, or the string
    '<could not find text>' if the file cannot be read or has no such line.
    """
    out = '<could not find text>'
    try:
        # context manager closes the file even if reading fails
        with open(fname, 'r') as ftmp:
            lines = ftmp.readlines()
        lineno = min(lineno, len(lines)) - 1
        # strip only the line terminator: the last line of a file may not
        # have one, and slicing off a character would truncate it
        out = lines[lineno].rstrip('\n')
    except (OSError, IndexError, TypeError, ValueError):
        # best effort: unreadable/undecodable file, empty file or bad
        # arguments all fall back to the placeholder text
        pass
    return out

57 

class LarchExceptionHolder:
    """basic exception handler

    Stores the information describing one error (AST node, message, file
    name, function, expression text, exception class and line number)
    together with the value of sys.exc_info() at construction time, and
    formats it on request with get_error().
    """
    def __init__(self, node=None, msg='', fname='<stdin>',
                 func=None, expr=None, exc=None, lineno=0):
        self.node = node
        self.fname = fname
        self.func = func
        self.expr = expr
        self.msg = msg
        self.exc = exc
        self.lineno = lineno
        # snapshot of the exception being handled (if any) right now
        self.exc_info = sys.exc_info()

        # fill in exception class / message from the active exception
        # when the caller did not supply them
        if self.exc is None and self.exc_info[0] is not None:
            self.exc = self.exc_info[0]
        if self.msg in ('', None) and self.exc_info[1] is not None:
            self.msg = self.exc_info[1]

    def get_error(self):
        """retrieve error data

        Returns
        -------
        tuple of (exception name, formatted multi-line error text)
        """
        col_offset = -1
        e_type, e_val, e_tb = self.exc_info
        if self.node is not None:
            col_offset = getattr(self.node, 'col_offset', -1)
        try:
            exc_name = self.exc.__name__
        except AttributeError:
            exc_name = str(self.exc)
        if exc_name in (None, 'None'):
            exc_name = 'UnknownError'

        out = []
        fname = self.fname

        if isinstance(self.expr, ast.AST):
            self.expr = 'In compiled script'
        if self.expr is None:
            out.append('unknown error\n')
        elif '\n' in self.expr:
            out.append("\n%s" % self.expr)
        else:
            out.append(" %s" % self.expr)
        if col_offset > 0:
            out.append("%s^^^" % ((col_offset)*' '))

        fline = ' File %s, line %i' % (fname, self.lineno)
        if self.func is not None:
            func = self.func
            fname = self.fname
            if fname is None:
                # no file name given: try to find the module file that
                # defines the function (unwrapping a Closure first)
                if isinstance(func, Closure):
                    func = func.func
                try:
                    fname = inspect.getmodule(func).__file__
                except AttributeError:
                    # getmodule() returned None or module has no __file__
                    fname = 'unknown'
            if fname.endswith('.pyc'):
                fname = fname[:-1]

            if hasattr(self.func, 'name'):
                dec = ''
                if isinstance(self.func, Procedure):
                    dec = 'procedure '
                pname = self.func.name
                # look up the source line in the resolved file name
                ftext = get_filetext(fname, self.lineno)
                fline = "%s, in %s%s\n%s" % (fline, dec, pname, ftext)

        out.append(fline)

        # keep traceback frames, except those inside the interpreter
        # machinery installed under sys.prefix
        tblist = []
        for tb in traceback.extract_tb(self.exc_info[2]):
            if not (sys.prefix in tb[0] and
                    ('ast.py' in tb[0] or
                     Path('larch', 'utils').as_posix() in tb[0] or
                     Path('larch', 'interpreter').as_posix() in tb[0] or
                     Path('larch', 'symboltable').as_posix() in tb[0])):
                tblist.append(tb)
        if len(tblist) > 0:
            out.append(''.join(traceback.format_list(tblist)))

        # try to get last error message, as from e_val.args
        ex_msg = getattr(e_val, 'args', None)
        if ex_msg is not None:
            try:
                # args may hold non-string values (KeyError(1), OSError(2, ..))
                ex_msg = ' '.join(str(arg) for arg in ex_msg)
            except TypeError:
                pass

        if ex_msg is None:
            ex_msg = getattr(e_val, 'message', None)
        if ex_msg is None:
            ex_msg = self.msg
        out.append("%s: %s" % (exc_name, ex_msg))

        out.append("")
        return (exc_name, '\n'.join(out))

158 

159 

160 

class StdWriter(object):
    """Standard writer method for Larch,
    to be used in place of sys.stdout

    supports methods:
      set_textstyle(mode)  # one of 'text', 'text2', 'error', 'comment'
      write(text)
      flush()
    """
    valid_termcolors = ('grey', 'red', 'green', 'yellow',
                        'blue', 'magenta', 'cyan', 'white')

    termcolor_attrs = ('bold', 'underline', 'blink', 'reverse')

    def __init__(self, stdout=None, has_color=True, _larch=None):
        """wrap `stdout` (default sys.stdout); color is used only if
        `has_color` is true and termcolor was importable"""
        if stdout is None:
            stdout = sys.stdout
        self.has_color = has_color and HAS_TERMCOLOR
        self.writer = stdout
        self._larch = _larch
        self.textstyle = None

    def set_textstyle(self, mode='text'):
        """ set text style for output

        Looks up `mode` in the interpreter's _sys.display.colors mapping.
        Without color support, or without an interpreter to read the
        styles from, the style is cleared and text is written unchanged.
        """
        if not self.has_color or self._larch is None:
            self.textstyle = None
            return
        display_colors = self._larch.symtable._sys.display.colors
        self.textstyle = display_colors.get(mode, {})

    def write(self, text):
        """write text to writer
        write('hello')
        """
        if self.textstyle is not None and HAS_TERMCOLOR:
            text = colored(text, **self.textstyle)
        self.writer.write(text)

    def flush(self):
        """flush the underlying writer"""
        self.writer.flush()

199 

200 

class Procedure(object):
    """larch procedure: function

    Holds the name, argument specification and body (a sequence of nodes
    run through the interpreter's run() method) of a procedure defined in
    Larch code.  Calling the instance binds the arguments in a fresh
    Group, runs the body nodes in order and returns the procedure's value.
    """
    def __init__(self, name, _larch=None, doc=None,
                 fname='<stdin>', lineno=0,
                 body=None, args=None, kwargs=None,
                 vararg=None, varkws=None):
        self.name = name
        self._larch = _larch
        self.modgroup = _larch.symtable._sys.moduleGroup
        self.body = body
        self.argnames = args      # names of positional arguments
        self.kwargs = kwargs      # sequence of (name, default) pairs
        self.vararg = vararg      # name for extra positional args, or None
        self.varkws = varkws      # name for extra keyword args, or None
        self.__doc__ = doc
        self.lineno = lineno
        self.__file__ = fname
        self.__name__ = name

    def __repr__(self):
        return "<Procedure %s, file=%s>" % (self.name, self.__file__)

    def _signature(self):
        """return call signature text, such as 'name(x, *args, key=1, **kws)'"""
        # collect every piece first so that separators are only placed
        # between pieces, even when there are no positional names
        parts = list(self.argnames)
        if self.vararg is not None:
            parts.append("*%s" % self.vararg)
        parts.extend("%s=%s" % (k, repr(v)) for k, v in self.kwargs)
        if self.varkws is not None:
            parts.append("**%s" % self.varkws)
        return "%s(%s)" % (self.name, ', '.join(parts))

    def raise_exc(self, **kws):
        """report an error for this procedure through the interpreter,
        filling in this procedure's line number, function and file name"""
        ekws = dict(lineno=self.lineno, func=self, fname=self.__file__)
        ekws.update(kws)
        self._larch.raise_exception(None, **ekws)

    def __call__(self, *args, **kwargs):
        """bind arguments, run the procedure body, and return its value

        Argument errors are reported with raise_exc(); if that call
        returns, None is returned.
        """
        # msg = 'Cannot run Procedure %s' % self.name
        lgroup = Group()
        lgroup.__name__ = hex(id(lgroup))
        args = list(args)
        nargs = len(args)
        nkws = len(kwargs)
        nargs_expected = len(self.argnames)

        # case 1: too few arguments, but the correct keyword given
        if (nargs < nargs_expected) and nkws > 0:
            for name in self.argnames[nargs:]:
                if name in kwargs:
                    args.append(kwargs.pop(name))
            nargs = len(args)
            nargs_expected = len(self.argnames)
            nkws = len(kwargs)

        # case 2: multiple values for named argument
        if len(self.argnames) > 0:
            msg = "%s() got multiple values for keyword argument '%s'"
            for targ in self.argnames:
                if targ in kwargs:
                    self.raise_exc(exc=TypeError,
                                   msg=msg % (self.name, targ))
                    return

        # case 3: too few args given
        if nargs < nargs_expected:
            mod = 'at least'
            if len(self.kwargs) == 0:
                mod = 'exactly'
            msg = '%s() expected %s %i arguments (got %i)'
            self.raise_exc(exc=TypeError,
                           msg=msg%(self.name, mod, nargs_expected, nargs))
            return

        # case 4: more args given than expected, varargs not given
        if nargs > nargs_expected and self.vararg is None:
            if nargs - nargs_expected > len(self.kwargs):
                msg = 'too many arguments for %s() expected at most %i, got %i'
                msg = msg % (self.name, len(self.kwargs)+nargs_expected, nargs)
                self.raise_exc(exc=TypeError, msg=msg)
                return
            # extra positional values fill keyword arguments, in order
            for i, xarg in enumerate(args[nargs_expected:]):
                kw_name = self.kwargs[i][0]
                if kw_name not in kwargs:
                    kwargs[kw_name] = xarg

        # positional arguments are consumed from the front of args
        for argname in self.argnames:
            if len(args) > 0:
                setattr(lgroup, argname, args.pop(0))
        try:
            if self.vararg is not None:
                setattr(lgroup, self.vararg, tuple(args))

            for key, val in self.kwargs:
                if key in kwargs:
                    val = kwargs.pop(key)
                setattr(lgroup, key, val)

            if self.varkws is not None:
                setattr(lgroup, self.varkws, kwargs)
            elif len(kwargs) > 0:
                msg = 'extra keyword arguments for procedure %s (%s)'
                msg = msg % (self.name, ','.join(list(kwargs.keys())))
                self.raise_exc(exc=TypeError, msg=msg)
                return

        except (ValueError, LookupError, TypeError,
                NameError, AttributeError):
            msg = 'incorrect arguments for procedure %s' % self.name
            self.raise_exc(msg=msg)
            return

        stable = self._larch.symtable
        stable.save_frame()
        stable.set_frame((lgroup, self.modgroup))
        retval = None
        self._larch.retval = None
        self._larch._calldepth += 1
        self._larch.debug = True
        try:
            for node in self.body:
                self._larch.run(node, fname=self.__file__, func=self,
                                lineno=node.lineno+self.lineno-1,
                                with_raise=False)
                if len(self._larch.error) > 0:
                    break
                if self._larch.retval is not None:
                    retval = self._larch.retval
                    # ReturnedNone marks an explicit 'return None'
                    if retval is ReturnedNone:
                        retval = None
                    break
        finally:
            # always undo the frame / call-depth changes, even if running
            # a node raised (for example on KeyboardInterrupt)
            stable.restore_frame()
            self._larch._calldepth -= 1
            self._larch.debug = False
            self._larch.retval = None
        del lgroup
        return retval

342 

343 

def add2path(envvar='PATH', dirname='.'):
    """add specified dir to beginning of a path-like environment variable

    Arguments
    ---------
    envvar:   name of environment variable to modify, for example 'PATH',
              'DYLD_LIBRARY_PATH' or 'LD_LIBRARY_PATH'  ['PATH']
    dirname:  directory to prepend; stored as an absolute path  ['.']

    Returns
    -------
    previous value of the variable ('' if it was unset), for restoration
    """
    oldpath = os.environ.get(envvar, '')
    newdir = Path(dirname).absolute().as_posix()
    if oldpath == '':
        # store the absolute path here too, so the entry stays valid
        # after the working directory changes
        os.environ[envvar] = newdir
    else:
        # os.pathsep is ';' on Windows and ':' elsewhere
        os.environ[envvar] = os.pathsep.join([newdir, oldpath])
    return oldpath

359 

360 

def isNamedClass(obj, cls):
    """name-based stand-in for isinstance(obj, cls)

    Returns True when the name of the object's class is the same as the
    name of `cls`, that is
        obj.__class__.__name__ == cls.__name__
    Only the names are compared: no subclass relationship is considered.
    """
    obj_classname = obj.__class__.__name__
    wanted_name = cls.__name__
    return obj_classname == wanted_name

368 

def get_dll(libname):
    """find and load a shared library

    Arguments
    ---------
    libname:  base name of library, without 'lib' prefix or file extension

    Returns
    -------
    loaded ctypes library object, or None if the library cannot be found
    """
    _dylib_formats = {'win': '%s.dll', 'linux': 'lib%s.so',
                      'darwin': 'lib%s.dylib'}

    loaddll = ctypes.cdll.LoadLibrary
    if uname == 'win':
        loaddll = ctypes.windll.LoadLibrary

    # normally, we expect the dll to be here in the larch dlls tree
    # if we find it there, use that one
    dllformat = _dylib_formats.get(uname)
    if dllformat is not None:
        dllpath = Path(bindir, dllformat % libname).absolute()
        if dllpath.exists():
            return loaddll(dllpath.as_posix())

    # if not found in the larch dlls tree, try your best!
    found = ctypes.util.find_library(libname)
    if found is None:
        # find_library() returns None when nothing matches
        return None
    dllpath = Path(found).absolute()
    if dllpath.exists():
        return loaddll(dllpath.as_posix())
    # find_library() may return a bare file name (no directory) that
    # only the system loader can resolve, so try loading it by name
    try:
        return loaddll(found)
    except OSError:
        return None

390 

391 

def read_workdir(conffile):
    """read working dir from a config file in the users larch dir
    compare save_workdir(conffile) which will save this value

    can be used to ensure that application startup starts in
    last working directory

    The first line of the file is taken as the directory to change to.
    This is best effort: a missing or unreadable file, or a directory
    that no longer exists, leaves the working directory unchanged.
    """
    try:
        w_file = Path(user_larchdir, conffile).absolute()
        if w_file.exists():
            # same encoding that save_workdir() uses for writing
            with open(w_file, 'r', encoding=sys.getdefaultencoding()) as fh:
                workdir = fh.readline().rstrip('\r\n')
            if workdir:
                os.chdir(workdir)
    except Exception:
        # deliberately ignored: restoring the directory is optional
        pass

408 

def save_workdir(conffile):
    """write working dir to a config file in the users larch dir
    compare read_workdir(conffile) which will read this value

    can be used to ensure that application startup starts in
    last working directory

    This is best effort: failure to write the file is silently ignored.
    """
    try:
        w_file = Path(user_larchdir, conffile).absolute()
        # context manager closes the file even if the write fails
        with open(w_file, 'w', encoding=sys.getdefaultencoding()) as fh:
            fh.write("%s\n" % get_cwd())
    except Exception:
        # deliberately ignored: saving the directory is optional
        pass

424 

425 

def read_config(conffile):
    """read toml config file from users larch dir
    compare save_config(conffile) which will save such a config

    returns dictionary / configuration, or None if the file does not
    exist or its contents cannot be parsed as toml
    """
    cfile = Path(user_larchdir, conffile).absolute()
    out = None
    if cfile.exists():
        data = read_textfile(cfile)
        try:
            out = toml.loads(data)
        except Exception:
            # unparsable config is treated as "no config"; unlike a bare
            # except this lets KeyboardInterrupt / SystemExit through
            pass
    return out

441 

def save_config(conffile, config):
    """write toml config file in the users larch dir
    compare read_config(conffile) which will read this value

    Arguments
    ---------
    conffile:  name of config file, relative to the users larch dir
    config:    configuration to serialize with toml.dumps()

    The toml text is encoded as utf-8 and written in binary mode.
    Errors from serializing or writing are not caught here.
    """
    target = Path(user_larchdir, conffile).absolute()
    # serialize before opening, so a failure leaves any old file untouched
    payload = toml.dumps(config).encode('utf-8')
    with open(target, 'wb') as outfile:
        outfile.write(payload)

453 

def parse_group_args(arg0, members=None, group=None, defaults=None,
                     fcn_name=None, check_outputs=True):
    """parse arguments for functions supporting First Argument Group convention

    That is, if the first argument is a Larch Group and contains members
    named in 'members', this will return data extracted from that group.

    Arguments
    ----------
    arg0:          first argument for function call.
    members:       list/tuple of names of required members (in order)
    defaults:      tuple of default values for remaining required
                   arguments past the first (in order)
    group:         group sent to parent function, used for outputs
    fcn_name:      name of parent function, used for error messages
    check_outputs: True/False (default True) setting whether a Warning should
                   be raised if any of the outputs (except for the final
                   group) are None or missing.  This effectively checks that
                   all expected inputs have been specified

    Returns
    -------
     list of output values in the order listed by members, followed by the
     output group (which could be None).

    Notes
    -----
    This implements the First Argument Group convention, used for many Larch functions.
    As an example, the function _xafs.find_e0 is defined like this:
       find_e0(energy, mu=None, group=None, ...)

    and uses this function as
       energy, mu, group = parse_group_args(energy, members=('energy', 'mu'),
                                            defaults=(mu,), group=group,
                                            fcn_name='find_e0', check_outputs=True)

    This allows the caller to use
         find_e0(grp)
    as a shorthand for
         find_e0(grp.energy, grp.mu, group=grp)

    as long as the Group grp has member 'energy', and 'mu'.

    With 'check_outputs=True', the value for 'mu' is not actually allowed to be None.

    The defaults tuple should be passed so that correct values are assigned
    if the caller actually specifies arrays as for the full call signature.
    """
    if members is None:
        members = []
    if defaults is None:
        # the signature default: no values beyond arg0
        defaults = ()
    if isgroup(arg0, *members):
        if group is None:
            group = arg0
        out = [getattr(arg0, attr) for attr in members]
    else:
        out = [arg0] + list(defaults)

    # test that all outputs are non-None
    if check_outputs:
        _errmsg = "%s: needs First Argument Group or valid arguments for\n %s"
        if fcn_name is None:
            fcn_name = 'unknown function'
        for i, nam in enumerate(members):
            # a value that was never supplied counts the same as None
            if i >= len(out) or out[i] is None:
                raise Warning(_errmsg % (fcn_name, ', '.join(members)))

    out.append(group)
    return out

522 

def Make_CallArgs(skipped_args):
    """
    decorator to create a 'call_args' dictionary
    containing function arguments
    If a Group is included in the call arguments,
    these call_args will be added to the group's journal

    Arguments
    ---------
    skipped_args:  sequence of argument names left out of the recorded
                   call_args.  The first two names are also used as the
                   member names passed to parse_group_args() to find the
                   output group.
    """
    def wrap(fcn):
        def wrapper(*args, **kwargs):
            # run the wrapped function first: the bookkeeping below only
            # records how it was called
            result = fcn(*args, **kwargs)
            argspec = inspect.getfullargspec(fcn)

            # defaults is None for functions without default values
            defaults = argspec.defaults or ()
            offset = len(argspec.args) - len(defaults)

            call_args = {k: None for k in argspec.args[:offset]}
            call_args.update(zip(argspec.args[offset:], defaults))
            # zip() stops at the named parameters, so positional values
            # collected by *args in fcn are not recorded
            call_args.update(zip(argspec.args, args))
            call_args.update(kwargs)

            # list() copies the names and also accepts a tuple
            skipped = list(skipped_args)
            at0 = skipped[0]
            at1 = skipped[1]
            _, _, groupx = parse_group_args(call_args[at0],
                                            members=(at0, at1),
                                            defaults=(call_args[at1],),
                                            group=call_args.get('group'),
                                            fcn_name=fcn.__name__)

            for k in skipped + ['group', '_larch']:
                call_args.pop(k, None)

            if groupx is not None:
                fname = fcn.__name__
                if not hasattr(groupx, 'journal'):
                    groupx.journal = Journal()
                if not hasattr(groupx, 'callargs'):
                    groupx.callargs = Group()
                setattr(groupx.callargs, fname, call_args)
                groupx.journal.add(f'{fname}_callargs', call_args)

            return result
        wrapper.__doc__ = fcn.__doc__
        wrapper.__name__ = fcn.__name__
        wrapper._larchfunc_ = fcn
        wrapper.__filename__ = fcn.__code__.co_filename
        wrapper.__dict__.update(fcn.__dict__)
        return wrapper
    return wrap

576 

577 

def ensuremod(_larch, modname=None):
    """ensure that a group exists

    Creates the group `modname` in the interpreter's symbol table when a
    name is given and no such group is present yet.

    Returns the symbol table, or None when `_larch` is None.
    """
    if _larch is None:
        return None
    table = _larch.symtable
    group_missing = modname is not None and not table.has_group(modname)
    if group_missing:
        table.newgroup(modname)
    return table

585 

# one journal record: key, value, and the datetime it was recorded
Entry = namedtuple('Entry', ('key', 'value', 'datetime'))


def _get_dtime(dtime=None):
    """get datetime from input
    dtime can be:
         datetime :  used as is
         str      :  assumed to be isoformat
         float    :  assumed to unix timestamp
         None     :  means now
    """
    if isinstance(dtime, datetime):
        return dtime
    if isinstance(dtime, (int, float)):
        return datetime.fromtimestamp(dtime)
    elif isinstance(dtime, str):
        return datetime.fromisoformat(dtime)
    return datetime.now()


class Journal:
    """list of journal entries

    Entries are (key, value, datetime) records kept in the order they
    were added; a key may appear any number of times.
    """
    def __init__(self, *args, **kws):
        """build a journal from other Journals, from lists/tuples of
        (key, value[, datetime]) entries, and from key=value keywords"""
        self.data = []
        for arg in args:
            if isinstance(arg, Journal):
                for entry in arg.data:
                    self.add(entry.key, entry.value, dtime=entry.datetime)
            elif isinstance(arg, (list, tuple)):
                for entry in arg:
                    # the datetime is optional: default to 'now'
                    dtime = entry[2] if len(entry) > 2 else None
                    self.add(entry[0], entry[1], dtime=dtime)

        for k, v in kws.items():
            self.add(k, v)

    def tolist(self):
        "return list of (key, value, isoformat datetime string) tuples"
        return [(x.key, x.value, x.datetime.isoformat()) for x in self.data]

    def __repr__(self):
        return repr(self.tolist())

    def __iter__(self):
        return iter(self.data)

    def add(self, key, value, dtime=None):
        """add journal entry:
        key, value pair with optional datetime
        """
        self.data.append(Entry(key, value, _get_dtime(dtime)))

    def add_ifnew(self, key, value, dtime=None):
        """add journal entry unless it already matches latest
        value (and dtime if supplied)
        """
        needs_add = True
        latest = self.get(key, latest=True)
        if latest is not None:
            needs_add = (latest.value != value)
            if not needs_add and dtime is not None:
                dtime = _get_dtime(dtime)
                # Entry stores its time in the 'datetime' field
                needs_add = needs_add or (latest.datetime != dtime)

        if needs_add:
            self.add(key, value, dtime=dtime)

    def get(self, key, latest=True):
        """get journal entries by key

        Arguments
        ----------
        latest  [bool]   whether to return latest matching entry only [True]

        Notes:
        -------
        if latest is True, one value will be returned (None if there is
        no entry for key; the most recently added entry if several share
        the latest datetime),
        otherwise a list of entries (possibly length 1) will be returned.

        """
        matches = [x for x in self.data if x.key == key]
        if latest:
            newest = None
            tnewest = None
            for m in matches:
                tstamp = m.datetime.timestamp()
                # '>=' so that, on equal times, the later addition wins
                if tnewest is None or tstamp >= tnewest:
                    newest, tnewest = m, tstamp
            return newest
        return matches

    def keys(self):
        "return list of keys, one per entry"
        return [x.key for x in self.data]

    def values(self):
        "return list of values, one per entry"
        return [x.value for x in self.data]

    def items(self):
        "return list of (key, value) tuples, one per entry"
        return [(x.key, x.value) for x in self.data]

    def get_latest(self, key):
        "return latest entry for key, or None"
        return self.get(key, latest=True)

    def get_matches(self, key):
        "return list of all entries for key"
        return self.get(key, latest=False)

    def sorted(self, sortby='time'):
        "return all entries, sorted by time or alphabetically by key"
        if 'time' in sortby.lower():
            return sorted(self.data, key=lambda x: x.datetime.timestamp())
        else:
            return sorted(self.data, key=lambda x: x.key)

    def __getstate__(self):
        "get state for pickle / json encoding"
        return [(x.key, x.value, x.datetime.isoformat()) for x in self.data]

    def __setstate__(self, state):
        "set state from pickle / json encoding"
        self.data = []
        for key, value, dt in state:
            self.data.append(Entry(key, value, datetime.fromisoformat(dt)))