Coverage for larch/io/xas_data_source/nexus.py: 85%

117 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-10-16 21:04 +0000

1import re 

2from contextlib import contextmanager 

3from typing import Iterator, List, Optional, Tuple 

4import numpy 

5import h5py 

6from . import base 

7from . import hdf5_utils 

8 

9 

class NexusSingleXasDataSource(base.XasDataSource):
    """NeXus compliant HDF5 file. Each NXentry contains 1 XAS spectrum.

    Every item at the root of the file is treated as one scan. The 1D
    datasets of a scan are taken either from a named counter group or,
    when no counter group is configured, from the ``NXdetector`` and
    ``NXpositioner`` groups found under the scan's ``NXinstrument`` group.
    """

    TYPE = "HDF5-NEXUS"

    def __init__(
        self,
        filename: str,
        title_regex_pattern: Optional[str] = None,
        counter_group: Optional[str] = None,
        **kw,
    ) -> None:
        """
        :param filename: forwarded to the base class constructor.
        :param title_regex_pattern: when given, only scans whose ``title``
            matches this pattern (``re.match`` semantics, i.e. anchored at
            the start of the title) are listed by :meth:`get_scan_names`.
        :param counter_group: name, relative to each scan, of the group
            holding the 1D counters. When omitted (or empty) the counters
            are looked up under the scan's ``NXinstrument`` group instead.
        :param kw: forwarded to the base class constructor.
        """
        # File handle shared by nested `_open` contexts; None while closed.
        self._nxroot = None
        if title_regex_pattern:
            title_regex_pattern = re.compile(title_regex_pattern)
        self._title_regex_pattern = title_regex_pattern
        self._counter_group = counter_group
        # Name of the NXinstrument group, remembered after the first search.
        self._instrument = None
        # The attributes above are set first, before the base constructor runs.
        super().__init__(filename, **kw)

    def get_source_info(self) -> str:
        """Return a short human-readable description of the data source."""
        # NOTE(review): `_filename` is presumably set by the base class.
        return f"HDF5: {self._filename}"

    def get_scan(self, scan_name: str) -> Optional[base.XasScan]:
        """Load the scan called ``scan_name``.

        The 1D datasets of the scan are sorted by label and stacked into a
        single array (one row per label). The scan ``title`` is used as
        description, falling back on the scan name when it is missing or
        empty.

        :raises KeyError: when ``scan_name`` does not exist in the file.
        """
        with self._open() as nxroot:
            scan = nxroot[scan_name]
            # Sort on the label only so datasets are never compared.
            datasets = sorted(self._iter_datasets(scan), key=lambda tpl: tpl[0])
            if datasets:
                labels, data = zip(*datasets)
            else:
                labels, data = [], []
            description = self._get_string(scan, "title") or scan_name
            start_time = self._get_string(scan, "start_time")
            # The scan is built inside the `with` block: `data` holds
            # file-backed datasets that are read by `numpy.asarray`.
            return base.XasScan(
                name=scan_name,
                description=description,
                start_time=start_time,
                info=description,
                labels=list(labels),
                data=numpy.asarray(data),
            )

    def get_scan_names(self) -> List[str]:
        """Return the names of all selectable scans, in file order."""
        return list(self._iter_scan_names())

    def _iter_scan_names(self) -> Iterator[str]:
        """Yield the name of every root item that passes the title filter."""
        with self._open() as nxroot:
            for name in nxroot["/"]:  # index at "/" to preserve order
                try:
                    scan = nxroot[name]
                except KeyError:
                    continue  # broken link
                if self._title_regex_pattern is not None:
                    # A missing title is read as "" and filtered like any other.
                    title = self._get_string(scan, "title")
                    if not self._title_regex_pattern.match(title):
                        continue
                yield name

    @contextmanager
    def _open(self) -> Iterator[h5py.File]:
        """Re-entrant context to get access to the HDF5 file.

        The outermost context opens the file and closes it on exit; nested
        contexts reuse the handle that is already open.
        """
        if self._nxroot is not None:
            yield self._nxroot
            return
        with hdf5_utils.open(self._filename) as nxroot:
            self._nxroot = nxroot
            try:
                yield nxroot
            finally:
                # Forget the handle even when the caller's block raised.
                self._nxroot = None

    def _iter_datasets(self, scan: h5py.Group) -> Iterator[Tuple[str, h5py.Dataset]]:
        """Yield ``(label, dataset)`` for every 1D dataset of ``scan``."""
        if self._counter_group:
            yield from self._iter_counter_group(scan)
        else:
            yield from self._iter_instrument_group(scan)

    def _iter_counter_group(
        self, scan: h5py.Group
    ) -> Iterator[Tuple[str, h5py.Dataset]]:
        """Yield the 1D datasets that sit directly in the counter group."""
        try:
            counter_group = scan[self._counter_group]
        except KeyError:
            return  # broken link or not existing
        for name in counter_group:
            try:
                dset = counter_group[name]
            except KeyError:
                continue  # broken link
            # Items without `ndim` (e.g. sub-groups) are skipped.
            if getattr(dset, "ndim", None) == 1:
                yield name, dset

    def _iter_instrument_group(
        self, scan: h5py.Group
    ) -> Iterator[Tuple[str, h5py.Dataset]]:
        """Yield the 1D datasets of the detectors and positioners.

        For each ``NXdetector`` child of the instrument group the ``data``
        item is used, for each ``NXpositioner`` child the ``value`` item.
        The label is the name of the detector/positioner group.
        """
        instrument = self._get_instrument(scan)
        if instrument is None:
            return
        for name in instrument:
            try:
                detector = instrument[name]
            except KeyError:
                continue  # broken link
            nxclass = detector.attrs.get("NX_class")
            if nxclass not in ("NXdetector", "NXpositioner"):
                continue
            dset_name = "value" if nxclass == "NXpositioner" else "data"
            try:
                dset = detector[dset_name]
            except KeyError:
                continue  # no data
            # Same guard as for the counter group: an item without `ndim`
            # (e.g. `data` being a group) is skipped instead of raising
            # AttributeError.
            if getattr(dset, "ndim", None) == 1:
                yield name, dset

    def _get_instrument(self, scan: h5py.Group) -> Optional[h5py.Group]:
        """Return the ``NXinstrument`` group of ``scan``, or ``None``.

        The group name found by the first successful search is remembered
        and tried first for subsequent scans. When a scan has no item with
        that name, the search by NX_class is done again for that scan.
        """
        if self._instrument:
            try:
                return scan[self._instrument]
            except KeyError:
                # The remembered name does not exist in this scan: search
                # again instead of failing the whole scan.
                pass
        instrument = hdf5_utils.find_nexus_class(scan, "NXinstrument")
        if instrument is not None:
            # Keep only the last path component so that the name can be
            # used relative to any scan.
            self._instrument = instrument.name.split("/")[-1]
        return instrument

    def _get_string(self, group: h5py.Group, name) -> str:
        """Read item ``name`` of ``group`` as a string, "" when missing."""
        try:
            s = group[name][()]
        except KeyError:
            return ""
        return hdf5_utils.asstr(s)

145 

146 

class EsrfSingleXasDataSource(NexusSingleXasDataSource):
    """NeXus data source whose counter group defaults to ``measurement``."""

    TYPE = "HDF5-NEXUS-ESRF"

    def __init__(self, filename: str, **kw) -> None:
        """Forward to the parent class, filling in the default counter group.

        An explicit ``counter_group`` keyword argument takes precedence over
        the default.
        """
        if "counter_group" not in kw:
            kw["counter_group"] = "measurement"
        super().__init__(filename, **kw)

153 

154 

class SoleilSingleXasDataSource(NexusSingleXasDataSource):
    """NeXus data source whose counter group defaults to ``scan_data``."""

    TYPE = "HDF5-NEXUS-SOLEIL"

    def __init__(self, filename: str, **kw) -> None:
        """Forward to the parent class, filling in the default counter group.

        An explicit ``counter_group`` keyword argument takes precedence over
        the default.
        """
        if "counter_group" not in kw:
            kw["counter_group"] = "scan_data"
        super().__init__(filename, **kw)