Coverage for larch/utils/jsonutils.py: 53%

268 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-10-16 21:04 +0000

1#!/usr/bin/env python 

2""" 

3 json utilities for larch objects 

4""" 

5import json 

6import io 

7import numpy as np 

8import h5py 

9from datetime import datetime 

10from collections import namedtuple 

11from pathlib import Path, PosixPath 

12from types import ModuleType 

13import importlib 

14import logging 

15 

16 

17from lmfit import Parameter, Parameters 

18from lmfit.model import Model, ModelResult 

19from lmfit.minimizer import Minimizer, MinimizerResult 

20from lmfit.parameter import SCIPY_FUNCTIONS 

21 

22from larch import Group, isgroup 

23from larch.utils.logging import getLogger 

24from larch.utils.logging import _levels as LoggingLevels 

25 

26HAS_STATE = {} 

27LarchGroupTypes = {} 

28 

def setup_larchtypes():
    """Fill the module-level HAS_STATE and LarchGroupTypes registries.

    HAS_STATE maps class names to classes that are re-created through
    ``__setstate__`` when decoding; LarchGroupTypes maps class names to
    the Group-like classes that are rebuilt from keyword arguments.

    The work is done only while either registry is still empty.  All
    imports happen inside the function, at call time rather than at
    module import time.  The scikit-learn classes are optional and are
    skipped silently when scikit-learn cannot be imported.
    """
    global HAS_STATE, LarchGroupTypes
    if HAS_STATE and LarchGroupTypes:
        # both registries were populated by an earlier call
        return

    try:
        from sklearn.cross_decomposition import PLSRegression
        from sklearn.linear_model import LassoLarsCV, LassoLars, Lasso
    except ImportError:
        pass
    else:
        HAS_STATE.update({'PLSRegression': PLSRegression,
                          'LassoLarsCV': LassoLarsCV,
                          'LassoLars': LassoLars,
                          'Lasso': Lasso})

    # each registry entry is written right after its import, so a failing
    # import later in the sequence leaves the earlier entries in place
    from larch import Journal, Group
    HAS_STATE['Journal'] = Journal

    from larch.xafs.feffutils import FeffCalcResults
    HAS_STATE['FeffCalcResults'] = FeffCalcResults

    from larch.xafs import FeffDatFile, FeffPathGroup
    HAS_STATE['FeffDatFile'] = FeffDatFile
    HAS_STATE['FeffPathGroup'] = FeffPathGroup

    from larch import ParameterGroup
    from larch.io.athena_project import AthenaGroup
    from larch.xafs import FeffitDataSet, TransformGroup

    # rebinds the module-level name to a freshly built dict
    LarchGroupTypes = dict(Group=Group,
                           AthenaGroup=AthenaGroup,
                           ParameterGroup=ParameterGroup,
                           FeffitDataSet=FeffitDataSet,
                           TransformGroup=TransformGroup,
                           MinimizerResult=MinimizerResult,
                           Minimizer=Minimizer,
                           FeffDatFile=FeffDatFile,
                           FeffPathGroup=FeffPathGroup)

67 

68 

def unpack_minimizer(out):
    """Rebuild an lmfit Minimizer from a dict of its saved attributes.

    Arguments
    ---------
    out : dict of decoded Minimizer attributes.  Entries are popped from
          it as they are used, so the dict is modified in place.

    Returns
    -------
    a Minimizer built from the 'userfcn' and 'params' entries, with the
    saved fit-status attributes set on it.

    Raises
    ------
    KeyError if 'params', one of the constructor options, or one of the
    fit-status attributes is missing from `out`.
    """
    # names passed to the Minimizer constructor as keyword arguments
    option_names = ('scale_covar', 'max_nfev', 'nan_policy')
    # names restored on the Minimizer instance after it is built
    status_names = ('success', 'nfev', 'nfree', 'ndata', 'ier',
                    'errorbars', 'message', 'lmdif_message', 'chisqr',
                    'redchi', 'covar', 'userkws', 'userargs', 'result')

    params = out.pop('params')
    userfunc = out.pop('userfcn', None)

    kws = out.pop('kws', None)
    if kws is None:
        kws = {}

    # keyword arguments may also be stored under 'kw': merge when a dict
    if 'kw' in out:
        extra = out.pop('kw', None)
        if isinstance(extra, dict):
            kws.update(extra)

    for option in option_names:
        kws[option] = out.pop(option)

    minimizer = Minimizer(userfunc, params, **kws)
    for attr in status_names:
        setattr(minimizer, attr, out.pop(attr))
    return minimizer

88 

89 

def encode4js(obj):
    """Return a representation of `obj` that is ready for JSON encoding.

    None, bool, int, float and str values (including NumPy scalar
    versions) are returned as builtin Python values, and bytes are
    decoded as UTF-8.  Every other supported object becomes a dict with
    a '__class__' entry naming the kind of object, plus the data saved
    for it.  Special handling exists for:

       numpy arrays, datetimes, paths, complex numbers, file objects,
       HDF5 files and groups, slices, lists, tuples, named tuples, dicts,
       loggers, lmfit Model, ModelResult, Minimizer, MinimizerResult,
       Parameters and Parameter, Larch Groups, modules, objects with
       __getstate__, types, callables, and objects with a dumps() method.

    Any other object is dumped generically (with a printed warning) as
    its repr, its class name, and every non-callable attribute whose
    name is not of the form __name__.

    Arguments
    ---------
    obj : object to encode

    Returns
    -------
    a value built only from None, bool, int, float, str, list, tuple
    and dict, as long as the nested values are themselves encodable.
    """
    setup_larchtypes()
    if obj is None:
        return None
    if isinstance(obj, np.ndarray):
        out = {'__class__': 'Array', '__shape__': obj.shape,
               '__dtype__': obj.dtype.name}
        if 'complex' in obj.dtype.name:
            # complex arrays are saved as [real part, imaginary part]
            out['value'] = [(obj.real).tolist(), (obj.imag).tolist()]
        elif obj.dtype.name == 'object':
            out['value'] = [encode4js(i) for i in obj.flatten().tolist()]
        else:
            out['value'] = obj.flatten().tolist()
        return out
    # bool is tested before int because bool is a subclass of int
    elif isinstance(obj, (bool, np.bool_)):
        return bool(obj)
    # np.integer covers every NumPy integer width, not only int32/int64.
    # np.timedelta64 is also a subclass of np.integer but carries units,
    # so it is kept out of this branch.
    elif (isinstance(obj, (int, np.integer))
          and not isinstance(obj, np.timedelta64)):
        return int(obj)
    elif isinstance(obj, (float, np.floating)):
        return float(obj)
    elif isinstance(obj, str):
        return str(obj)
    elif isinstance(obj, bytes):
        return obj.decode('utf-8')
    elif isinstance(obj, datetime):
        return {'__class__': 'Datetime', 'isotime': obj.isoformat()}
    elif isinstance(obj, (Path, PosixPath)):
        return {'__class__': 'Path', 'value': obj.as_posix()}
    elif isinstance(obj, (complex, np.complexfloating)):
        return {'__class__': 'Complex', 'value': (obj.real, obj.imag)}
    elif isinstance(obj, io.IOBase):
        # in-memory streams have no 'name' attribute
        out = {'__class__': 'IOBase', 'class': obj.__class__.__name__,
               'name': getattr(obj, 'name', None), 'closed': obj.closed,
               'readable': False, 'writable': False}
        # readable() and writable() raise ValueError on a closed file:
        # the flag is then left as False
        for key, probe in (('readable', obj.readable),
                           ('writable', obj.writable)):
            try:
                out[key] = probe()
            except ValueError:
                pass
        return out
    elif isinstance(obj, h5py.File):
        return {'__class__': 'HDF5File',
                'value': (obj.name, obj.filename, obj.mode, obj.libver),
                'keys': list(obj.keys())}
    elif isinstance(obj, h5py.Group):
        return {'__class__': 'HDF5Group',
                'value': (obj.name, obj.file.filename),
                'keys': list(obj.keys())}
    elif isinstance(obj, slice):
        return {'__class__': 'Slice',
                'value': (obj.start, obj.stop, obj.step)}

    elif isinstance(obj, list):
        return {'__class__': 'List',
                'value': [encode4js(item) for item in obj]}
    elif isinstance(obj, tuple):
        if hasattr(obj, '_fields'):  # named tuple!
            return {'__class__': 'NamedTuple',
                    '__name__': obj.__class__.__name__,
                    '_fields': obj._fields,
                    'value': [encode4js(item) for item in obj]}
        else:
            return {'__class__': 'Tuple',
                    'value': [encode4js(item) for item in obj]}
    elif isinstance(obj, dict):
        out = {'__class__': 'Dict'}
        # keys are encoded too, so a key must encode to a hashable value
        for key, val in obj.items():
            out[encode4js(key)] = encode4js(val)
        return out
    elif isinstance(obj, logging.Logger):
        # use the name of the matching level, defaulting to 'DEBUG'
        level = 'DEBUG'
        for key, val in LoggingLevels.items():
            if obj.level == val:
                level = key
        return {'__class__': 'Logger', 'name': obj.name, 'level': level}
    elif isinstance(obj, Model):
        return {'__class__': 'Model', 'value': obj.dumps()}
    elif isinstance(obj, ModelResult):
        return {'__class__': 'ModelResult', 'value': obj.dumps()}

    elif isinstance(obj, Minimizer) and not isinstance(obj, ModelResult):
        out = {'__class__': 'Minimizer'}

        # attributes that are missing on the instance are saved as None
        for attr in ('userfcn', 'params', 'kw', 'scale_covar', 'max_nfev',
                     'nan_policy', 'success', 'nfev', 'nfree', 'ndata', 'ier',
                     'errorbars', 'message', 'lmdif_message', 'chisqr',
                     'redchi', 'covar', 'userkws', 'userargs', 'result'):
            out[attr] = encode4js(getattr(obj, attr, None))
        return out
    elif isinstance(obj, MinimizerResult):
        out = {'__class__': 'MinimizerResult'}
        for attr in ('aborted', 'aic', 'bic', 'call_kws', 'chisqr',
                     'covar', 'errorbars', 'ier', 'init_vals',
                     'init_values', 'last_internal_values',
                     'lmdif_message', 'message', 'method', 'ndata', 'nfev',
                     'nfree', 'nvarys', 'params', 'redchi', 'residual',
                     'success', 'var_names'):
            out[attr] = encode4js(getattr(obj, attr, None))
        return out
    elif isinstance(obj, Parameters):
        out = {'__class__': 'Parameters'}
        o_ast = obj._asteval
        out['unique_symbols'] = {key: encode4js(o_ast.symtable[key])
                                 for key in o_ast.user_defined_symbols()}
        out['params'] = [(p.name, p.__getstate__()) for p in obj.values()]
        return out
    elif isinstance(obj, Parameter):
        return {'__class__': 'Parameter', 'name': obj.name,
                'state': obj.__getstate__()}
    elif isgroup(obj):
        try:
            classname = obj.__class__.__name__
        except Exception:
            # best effort: fall back to the generic group name
            classname = 'Group'
        out = {'__class__': classname}

        if classname == 'ParameterGroup':  # save in order of parameter names
            parnames = dir(obj)
            for par in obj.__params__.keys():
                if par in parnames:
                    out[par] = encode4js(getattr(obj, par))
        else:
            for item in dir(obj):
                out[item] = encode4js(getattr(obj, item))
        return out

    elif isinstance(obj, ModuleType):
        return {'__class__': 'Module', 'value': obj.__name__}
    elif hasattr(obj, '__getstate__') and not callable(obj):
        return {'__class__': 'StatefulObject',
                '__type__': obj.__class__.__name__,
                'value': encode4js(obj.__getstate__())}
    elif isinstance(obj, type):
        return {'__class__': 'Type', 'value': repr(obj),
                'module': getattr(obj, '__module__', None)}
    elif callable(obj):
        return {'__class__': 'Method', '__name__': repr(obj)}
    elif hasattr(obj, 'dumps'):
        print("Encode Warning: using dumps for ", obj)
        return {'__class__': 'DumpableObject', 'value': obj.dumps()}
    else:
        print("Encode Warning: generic object dump for ", repr(obj))
        out = {'__class__': 'Object', '__repr__': repr(obj),
               '__classname__': obj.__class__.__name__}
        for attr in dir(obj):
            if attr.startswith('__') and attr.endswith('__'):
                continue
            thing = getattr(obj, attr)
            if not callable(thing):
                out[attr] = encode4js(thing)
        return out

244 

def decode4js(obj):
    """Return the Python object described by an encoded object.

    Only dicts that carry a '__class__' entry are decoded; any other
    value, including a list or a dict without that entry, is returned
    unchanged.  The '__class__' value selects how the remaining entries
    are turned back into an object.  The input dict itself is left
    unmodified.

    Arguments
    ---------
    obj : encoded object, typically the result of json.loads()

    Returns
    -------
    the decoded object.  For an unrecognized '__class__', or a stateful
    object or file that cannot be re-created, a message is printed and
    a dict of the remaining entries is returned.
    """
    if not isinstance(obj, dict):
        return obj
    setup_larchtypes()
    classname = obj.get('__class__', None)
    if classname is None:
        return obj

    # work on a copy without the '__class__' tag so that the caller's
    # dict is not modified and can be decoded again
    obj = {key: val for key, val in obj.items() if key != '__class__'}
    out = obj

    if classname == 'Complex':
        out = obj['value'][0] + 1j*obj['value'][1]
    elif classname in ('List', 'Tuple', 'NamedTuple'):
        out = [decode4js(item) for item in obj['value']]
        if classname == 'Tuple':
            out = tuple(out)
        elif classname == 'NamedTuple':
            out = namedtuple(obj['__name__'], obj['_fields'])(*out)
    elif classname == 'Array':
        if obj['__dtype__'].startswith('complex'):
            # complex arrays are stored as [real part, imaginary part]
            re = np.asarray(obj['value'][0], dtype='double')
            im = np.asarray(obj['value'][1], dtype='double')
            out = re + 1j*im
        elif obj['__dtype__'].startswith('object'):
            val = [decode4js(v) for v in obj['value']]
            out = np.array(val, dtype=obj['__dtype__'])
        else:
            out = np.asarray(obj['value'], dtype=obj['__dtype__'])
        out.shape = obj['__shape__']
    elif classname in ('Dict', 'dict'):
        out = {key: decode4js(val) for key, val in obj.items()}
    elif classname == 'Datetime':
        out = datetime.fromisoformat(obj['isotime'])
    elif classname in ('Path', 'PosixPath'):
        out = Path(obj['value'])
    elif classname == 'Slice':
        # stored as (start, stop, step)
        out = slice(*obj['value'])

    elif classname == 'IOBase':
        # Never open with mode 'w': re-creating a handle must not truncate
        # the file it names, so any writable handle is opened for append.
        mode = 'a' if obj['writable'] else 'r'
        try:
            out = open(obj['name'], mode=mode)
        except (OSError, TypeError, ValueError) as exc:
            # missing file or unusable name: keep the description dict
            print(f"Warning: cannot re-open file '{obj['name']}': {exc}")
        else:
            if obj['closed']:
                out.close()

    elif classname == 'Module':
        # NOTE(review): this imports whatever module the data names, which
        # runs that module's import-time code -- only decode trusted input.
        out = importlib.import_module(obj['value'])

    elif classname == 'Parameters':
        out = Parameters()
        out.clear()
        unique_symbols = {key: decode4js(val)
                          for key, val in obj['unique_symbols'].items()}

        state = {'unique_symbols': unique_symbols, 'params': []}
        for name, parstate in obj['params']:
            par = Parameter(decode4js(name))
            par.__setstate__(decode4js(parstate))
            state['params'].append(par)
        out.__setstate__(state)
    elif classname in ('Parameter', 'parameter'):
        name = decode4js(obj['name'])
        state = decode4js(obj['state'])
        out = Parameter(name)
        out.__setstate__(state)

    elif classname == 'Model':
        # a placeholder Model is needed only to call loads()
        mod = Model(lambda x: x)
        out = mod.loads(decode4js(obj['value']))

    elif classname == 'ModelResult':
        params = Parameters()
        res = ModelResult(Model(lambda x: x, None), params)
        out = res.loads(decode4js(obj['value']))

    elif classname == 'Logger':
        out = getLogger(obj['name'], level=obj['level'])

    elif classname == 'StatefulObject':
        dtype = obj.get('__type__')
        if dtype in HAS_STATE:
            out = HAS_STATE[dtype]()
            out.__setstate__(decode4js(obj.get('value')))
        elif dtype == 'Minimizer':
            # the saved state is itself encoded: decode it before unpacking
            out = unpack_minimizer(decode4js(obj['value']))
        else:
            print(f"Warning: cannot re-create stateful object of type '{dtype}'")

    elif classname in LarchGroupTypes:
        out = {}
        for key, val in obj.items():
            if (isinstance(val, dict) and
                    val.get('__class__', None) == 'Method' and
                    val.get('__name__', None) is not None):
                continue  # ignore class methods for subclassed Groups
            out[key] = decode4js(val)
        if classname == 'Minimizer':
            out = unpack_minimizer(out)
        elif classname == 'FeffDatFile':
            from larch.xafs import FeffDatFile
            path = FeffDatFile()
            path._set_from_dict(**out)
            out = path
        else:
            out = LarchGroupTypes[classname](**out)
    elif classname == 'Method':
        mname = obj.get('__name__', '')
        if 'ufunc' in mname:
            # reduce a repr such as "<ufunc 'name'>" to the bare name
            mname = mname.replace('<ufunc', '').replace('>', '').replace("'","").strip()
        # names not found in SCIPY_FUNCTIONS decode to None
        out = SCIPY_FUNCTIONS.get(mname, None)

    else:
        print("cannot decode ", classname, repr(obj)[:100])
    return out