Coverage for larch/utils/jsonutils.py: 53%
268 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-10-16 21:04 +0000
« prev ^ index » next coverage.py v7.6.0, created at 2024-10-16 21:04 +0000
1#!/usr/bin/env python
2"""
3 json utilities for larch objects
4"""
5import json
6import io
7import numpy as np
8import h5py
9from datetime import datetime
10from collections import namedtuple
11from pathlib import Path, PosixPath
12from types import ModuleType
13import importlib
14import logging
17from lmfit import Parameter, Parameters
18from lmfit.model import Model, ModelResult
19from lmfit.minimizer import Minimizer, MinimizerResult
20from lmfit.parameter import SCIPY_FUNCTIONS
22from larch import Group, isgroup
23from larch.utils.logging import getLogger
24from larch.utils.logging import _levels as LoggingLevels
# Registry: class name -> class, for objects rebuilt in decode4js() by calling
# the class with no arguments and then ``__setstate__``.
HAS_STATE = {}

# Registry: class name -> class, for Group-like objects rebuilt in decode4js()
# from a dictionary of decoded attributes.
LarchGroupTypes = {}


def setup_larchtypes():
    """Populate the ``HAS_STATE`` and ``LarchGroupTypes`` registries.

    The work is done only when at least one of the two registries is still
    empty, so calling this repeatedly is cheap.  All larch / scikit-learn
    imports happen inside the function rather than at module import time
    (presumably to avoid circular imports -- TODO confirm).

    Both registries are filled *in place*.  Rebinding the module-level name
    to a new dict would leave an empty, stale dict in any module that did
    ``from larch.utils.jsonutils import LarchGroupTypes`` before this ran.
    """
    if HAS_STATE and LarchGroupTypes:
        return

    # scikit-learn is optional: its estimators are registered only if present.
    try:
        from sklearn.cross_decomposition import PLSRegression
        from sklearn.linear_model import LassoLarsCV, LassoLars, Lasso
    except ImportError:
        pass
    else:
        HAS_STATE.update({'PLSRegression': PLSRegression,
                          'LassoLarsCV': LassoLarsCV,
                          'LassoLars': LassoLars,
                          'Lasso': Lasso})

    from larch import Journal, Group
    from larch.xafs.feffutils import FeffCalcResults
    from larch.xafs import FeffDatFile, FeffPathGroup

    HAS_STATE.update({'Journal': Journal,
                      'FeffCalcResults': FeffCalcResults,
                      'FeffDatFile': FeffDatFile,
                      'FeffPathGroup': FeffPathGroup})

    from larch import ParameterGroup
    from larch.io.athena_project import AthenaGroup
    from larch.xafs import FeffitDataSet, TransformGroup

    LarchGroupTypes.update({'Group': Group,
                            'AthenaGroup': AthenaGroup,
                            'ParameterGroup': ParameterGroup,
                            'FeffitDataSet': FeffitDataSet,
                            'TransformGroup': TransformGroup,
                            'MinimizerResult': MinimizerResult,
                            'Minimizer': Minimizer,
                            'FeffDatFile': FeffDatFile,
                            'FeffPathGroup': FeffPathGroup,
                            })
def unpack_minimizer(out):
    """Rebuild an lmfit ``Minimizer`` from a dictionary of its attributes.

    A minimizer can appear in a few places in encoded data, so the
    reconstruction lives in this one helper.

    ``out`` is consumed: every key used here is popped from it.  The keys
    ``params``, ``scale_covar``, ``max_nfev``, ``nan_policy`` and all of the
    result attributes restored at the end are required (``KeyError`` if
    missing); ``userfcn``, ``kws`` and ``kw`` are optional.

    Returns the new ``Minimizer`` with the saved attributes set on it.
    """
    fit_params = out.pop('params')
    objective = out.pop('userfcn', None)

    init_kws = out.pop('kws', None)
    if init_kws is None:
        init_kws = {}

    # extra keyword arguments saved under 'kw' are merged into the init kwargs
    extra_kws = out.pop('kw', None)
    if isinstance(extra_kws, dict):
        init_kws.update(extra_kws)

    init_options = ('scale_covar', 'max_nfev', 'nan_policy')
    for option in init_options:
        init_kws[option] = out.pop(option)

    mini = Minimizer(objective, fit_params, **init_kws)

    # attributes that are copied back onto the instance after construction
    saved_attrs = ('success', 'nfev', 'nfree', 'ndata', 'ier',
                   'errorbars', 'message', 'lmdif_message', 'chisqr',
                   'redchi', 'covar', 'userkws', 'userargs', 'result')
    for attr in saved_attrs:
        setattr(mini, attr, out.pop(attr))
    return mini
def encode4js(obj):
    """Return a representation of ``obj`` that is ready for JSON encoding.

    ``None``, booleans, numbers and strings are returned as plain Python
    values, and ``bytes`` are decoded as UTF-8.  Every other supported object
    becomes a dictionary whose ``'__class__'`` entry names the kind of object,
    so that ``decode4js`` can recognize it.  Special handling exists for:

        numpy arrays and numpy scalars
        complex numbers
        datetime, pathlib paths, slices, open files, HDF5 files and groups
        lists, tuples, named tuples and dicts (encoded recursively)
        loggers, modules, types and callables
        lmfit Model, ModelResult, Minimizer, MinimizerResult,
              Parameters and Parameter
        Larch Groups

    Anything else is encoded through ``__getstate__``, then ``dumps``, and
    finally a generic dump of its non-callable, non-dunder attributes.
    """
    setup_larchtypes()
    if obj is None:
        return None

    if isinstance(obj, np.ndarray):
        out = {'__class__': 'Array', '__shape__': obj.shape,
               '__dtype__': obj.dtype.name}
        if 'complex' in obj.dtype.name:
            # complex arrays are stored as a [real, imag] pair of nested lists
            out['value'] = [obj.real.tolist(), obj.imag.tolist()]
        elif obj.dtype.name == 'object':
            out['value'] = [encode4js(item) for item in obj.flatten().tolist()]
        else:
            out['value'] = obj.flatten().tolist()
        return out

    # bool must be tested before int: bool is a subclass of int
    if isinstance(obj, (bool, np.bool_)):
        return bool(obj)
    # np.integer / np.floating cover every numpy scalar width, not just 32/64
    if isinstance(obj, (int, np.integer)):
        return int(obj)
    if isinstance(obj, (float, np.floating)):
        return float(obj)
    if isinstance(obj, str):
        return str(obj)
    if isinstance(obj, bytes):
        return obj.decode('utf-8')
    if isinstance(obj, datetime):
        return {'__class__': 'Datetime', 'isotime': obj.isoformat()}
    if isinstance(obj, (Path, PosixPath)):
        return {'__class__': 'Path', 'value': obj.as_posix()}
    if isinstance(obj, (complex, np.complexfloating)):
        # float() keeps the parts JSON-serializable for narrow numpy types
        return {'__class__': 'Complex',
                'value': (float(obj.real), float(obj.imag))}

    if isinstance(obj, io.IOBase):
        # in-memory streams have no ``name`` attribute
        out = {'__class__': 'IOBase', 'class': obj.__class__.__name__,
               'name': getattr(obj, 'name', None), 'closed': obj.closed,
               'readable': False, 'writable': False}
        # readable()/writable() raise ValueError on a closed file
        for key, probe in (('readable', obj.readable),
                           ('writable', obj.writable)):
            try:
                out[key] = probe()
            except ValueError:
                pass
        return out

    if isinstance(obj, h5py.File):
        return {'__class__': 'HDF5File',
                'value': (obj.name, obj.filename, obj.mode, obj.libver),
                'keys': list(obj.keys())}
    if isinstance(obj, h5py.Group):
        return {'__class__': 'HDF5Group',
                'value': (obj.name, obj.file.filename),
                'keys': list(obj.keys())}
    if isinstance(obj, slice):
        return {'__class__': 'Slice',
                'value': (obj.start, obj.stop, obj.step)}

    if isinstance(obj, list):
        return {'__class__': 'List',
                'value': [encode4js(item) for item in obj]}
    if isinstance(obj, tuple):
        items = [encode4js(item) for item in obj]
        if hasattr(obj, '_fields'):  # named tuple
            return {'__class__': 'NamedTuple',
                    '__name__': obj.__class__.__name__,
                    '_fields': obj._fields,
                    'value': items}
        return {'__class__': 'Tuple', 'value': items}
    if isinstance(obj, dict):
        out = {'__class__': 'Dict'}
        for key, val in obj.items():
            out[encode4js(key)] = encode4js(val)
        return out

    if isinstance(obj, logging.Logger):
        # default to 'DEBUG' when the numeric level has no known name
        level = 'DEBUG'
        for key, val in LoggingLevels.items():
            if obj.level == val:
                level = key
        return {'__class__': 'Logger', 'name': obj.name, 'level': level}

    if isinstance(obj, Model):
        return {'__class__': 'Model', 'value': obj.dumps()}
    if isinstance(obj, ModelResult):
        return {'__class__': 'ModelResult', 'value': obj.dumps()}

    if isinstance(obj, Minimizer) and not isinstance(obj, ModelResult):
        out = {'__class__': 'Minimizer'}
        for attr in ('userfcn', 'params', 'kw', 'scale_covar', 'max_nfev',
                     'nan_policy', 'success', 'nfev', 'nfree', 'ndata', 'ier',
                     'errorbars', 'message', 'lmdif_message', 'chisqr',
                     'redchi', 'covar', 'userkws', 'userargs', 'result'):
            out[attr] = encode4js(getattr(obj, attr, None))
        return out
    if isinstance(obj, MinimizerResult):
        out = {'__class__': 'MinimizerResult'}
        for attr in ('aborted', 'aic', 'bic', 'call_kws', 'chisqr',
                     'covar', 'errorbars', 'ier', 'init_vals',
                     'init_values', 'last_internal_values',
                     'lmdif_message', 'message', 'method', 'ndata', 'nfev',
                     'nfree', 'nvarys', 'params', 'redchi', 'residual',
                     'success', 'var_names'):
            out[attr] = encode4js(getattr(obj, attr, None))
        return out
    if isinstance(obj, Parameters):
        out = {'__class__': 'Parameters'}
        o_ast = obj._asteval
        out['unique_symbols'] = {key: encode4js(o_ast.symtable[key])
                                 for key in o_ast.user_defined_symbols()}
        out['params'] = [(p.name, p.__getstate__()) for p in obj.values()]
        return out
    if isinstance(obj, Parameter):
        return {'__class__': 'Parameter', 'name': obj.name,
                'state': obj.__getstate__()}

    if isgroup(obj):
        classname = getattr(obj.__class__, '__name__', 'Group')
        out = {'__class__': classname}
        if classname == 'ParameterGroup':
            # save in order of parameter names
            parnames = dir(obj)
            for par in obj.__params__.keys():
                if par in parnames:
                    out[par] = encode4js(getattr(obj, par))
        else:
            for item in dir(obj):
                out[item] = encode4js(getattr(obj, item))
        return out

    if isinstance(obj, ModuleType):
        return {'__class__': 'Module', 'value': obj.__name__}
    # NOTE(review): since Python 3.11 every object has a default
    # ``__getstate__``, so this branch likely catches all remaining
    # non-callable objects on newer interpreters -- confirm that is intended.
    if hasattr(obj, '__getstate__') and not callable(obj):
        return {'__class__': 'StatefulObject',
                '__type__': obj.__class__.__name__,
                'value': encode4js(obj.__getstate__())}
    if isinstance(obj, type):
        return {'__class__': 'Type', 'value': repr(obj),
                'module': getattr(obj, '__module__', None)}
    if callable(obj):
        return {'__class__': 'Method', '__name__': repr(obj)}
    if hasattr(obj, 'dumps'):
        print("Encode Warning: using dumps for ", obj)
        return {'__class__': 'DumpableObject', 'value': obj.dumps()}

    # last resort: dump every non-callable, non-dunder attribute
    print("Encode Warning: generic object dump for ", repr(obj))
    out = {'__class__': 'Object', '__repr__': repr(obj),
           '__classname__': obj.__class__.__name__}
    for attr in dir(obj):
        if attr.startswith('__') and attr.endswith('__'):
            continue
        thing = getattr(obj, attr)
        if not callable(thing):
            out[attr] = encode4js(thing)
    return out
def decode4js(obj):
    """Return the Python object described by an encoded object.

    Anything that is not a dict, and any dict without a ``'__class__'``
    entry, is returned unchanged.  Otherwise the ``'__class__'`` value selects
    how the object is rebuilt; nested values are decoded recursively where the
    branch calls for it.  For a class name that is not recognized, a message
    is printed and the dictionary contents (without ``'__class__'``) are
    returned.

    The input dictionary is not modified, so the same encoded data can be
    decoded more than once.

    NOTE: decoding ``'Module'`` imports the named module and decoding
    ``'IOBase'`` opens the named file (possibly for writing), so only decode
    data from a trusted source.
    """
    if not isinstance(obj, dict):
        return obj
    setup_larchtypes()
    classname = obj.get('__class__', None)
    if classname is None:
        return obj

    # work on a shallow copy without the class tag: the caller's data stays
    # intact and the tag never leaks into rebuilt dicts or constructor kwargs
    obj = {key: val for key, val in obj.items() if key != '__class__'}
    out = obj

    if classname == 'Complex':
        out = obj['value'][0] + 1j*obj['value'][1]
    elif classname in ('List', 'Tuple', 'NamedTuple'):
        out = [decode4js(item) for item in obj['value']]
        if classname == 'Tuple':
            out = tuple(out)
        elif classname == 'NamedTuple':
            out = namedtuple(obj['__name__'], obj['_fields'])(*out)
    elif classname == 'Array':
        if obj['__dtype__'].startswith('complex'):
            # complex arrays are stored as a [real, imag] pair
            re = np.asarray(obj['value'][0], dtype='double')
            im = np.asarray(obj['value'][1], dtype='double')
            out = re + 1j*im
        elif obj['__dtype__'].startswith('object'):
            val = [decode4js(v) for v in obj['value']]
            out = np.array(val, dtype=obj['__dtype__'])
        else:
            out = np.asarray(obj['value'], dtype=obj['__dtype__'])
        out.shape = obj['__shape__']
    elif classname in ('Dict', 'dict'):
        out = {key: decode4js(val) for key, val in obj.items()}
    elif classname == 'Datetime':
        out = datetime.fromisoformat(obj['isotime'])
    elif classname in ('Path', 'PosixPath'):
        out = Path(obj['value'])

    elif classname == 'IOBase':
        mode = 'r'
        if obj['readable'] and obj['writable']:
            mode = 'a'
        elif not obj['readable'] and obj['writable']:
            mode = 'w'
        out = open(obj['name'], mode=mode)
        if obj['closed']:
            out.close()

    elif classname == 'Module':
        # the module name is stored under 'value' by the encoder
        out = importlib.import_module(obj['value'])

    elif classname == 'Parameters':
        out = Parameters()
        out.clear()
        unique_symbols = {key: decode4js(obj['unique_symbols'][key])
                          for key in obj['unique_symbols']}

        state = {'unique_symbols': unique_symbols, 'params': []}
        for name, parstate in obj['params']:
            par = Parameter(decode4js(name))
            par.__setstate__(decode4js(parstate))
            state['params'].append(par)
        out.__setstate__(state)
    elif classname in ('Parameter', 'parameter'):
        name = decode4js(obj['name'])
        state = decode4js(obj['state'])
        out = Parameter(name)
        out.__setstate__(state)

    elif classname == 'Model':
        # a placeholder Model is needed only to get access to loads()
        mod = Model(lambda x: x)
        out = mod.loads(decode4js(obj['value']))

    elif classname == 'ModelResult':
        params = Parameters()
        res = ModelResult(Model(lambda x: x, None), params)
        out = res.loads(decode4js(obj['value']))

    elif classname == 'Logger':
        out = getLogger(obj['name'], level=obj['level'])

    elif classname == 'StatefulObject':
        dtype = obj.get('__type__')
        if dtype in HAS_STATE:
            out = HAS_STATE[dtype]()
            out.__setstate__(decode4js(obj.get('value')))
        elif dtype == 'Minimizer':
            # the saved state must be decoded before it can be unpacked
            out = unpack_minimizer(decode4js(obj['value']))
        else:
            print(f"Warning: cannot re-create stateful object of type '{dtype}'")

    elif classname in LarchGroupTypes:
        out = {}
        for key, val in obj.items():
            if (isinstance(val, dict) and
                    val.get('__class__', None) == 'Method' and
                    val.get('__name__', None) is not None):
                pass  # ignore class methods for subclassed Groups
            else:
                out[key] = decode4js(val)
        if classname == 'Minimizer':
            out = unpack_minimizer(out)
        elif classname == 'FeffDatFile':
            from larch.xafs import FeffDatFile
            path = FeffDatFile()
            path._set_from_dict(**out)
            out = path
        else:
            out = LarchGroupTypes[classname](**out)
    elif classname == 'Method':
        # only ufuncs are restored, by name, from lmfit's SCIPY_FUNCTIONS
        mname = obj.get('__name__', '')
        if 'ufunc' in mname:
            mname = mname.replace('<ufunc', '').replace('>', '').replace("'", "").strip()
            out = SCIPY_FUNCTIONS.get(mname, None)

    else:
        print("cannot decode ", classname, repr(obj)[:100])
    return out