Coverage for larch/xrd/amcsd_utils.py: 29%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-10-16 21:04 +0000

1import os 

2import sqlite3 

3from base64 import b64encode, b64decode 

4 

5import numpy as np 

6 

7from sqlalchemy import MetaData, create_engine, func, text, and_ 

8from sqlalchemy.sql import select 

9from sqlalchemy.orm import sessionmaker 

10from sqlalchemy.pool import SingletonThreadPool 

11 

12try: 

13 from pymatgen.io.cif import CifParser 

14 from pymatgen.symmetry.analyzer import SpacegroupAnalyzer 

15 from pymatgen.core import Molecule, IMolecule, IStructure 

16 from pymatgen.core import __version__ as pmg_version 

17except: 

18 CifParser = SpacegroupAnalyzer = None 

19 Molecule = IMolecule = IStructure = None 

20 pmg_version = None 

21 

22from larch.utils.physical_constants import ATOM_SYMS, ATOM_NAMES 

23 

# NOTE(review): unclear from this file alone whether this versions the module
# or the database layout -- confirm before bumping.
__version__ = '1'

# Keyword options (occupancy_tolerance, site_tolerance); presumably passed to
# pymatgen's CifParser -- TODO confirm against the code that uses them.
PMG_CIF_OPTS = dict(occupancy_tolerance=10, site_tolerance=5e-3)

27 

28 

def make_engine(dbname):
    """Build a SQLAlchemy engine for the sqlite database file `dbname`.

    Args:
        dbname (string): name of the sqlite database file

    Returns:
        engine created with a SingletonThreadPool and with sqlite's
        same-thread check turned off
    """
    url = 'sqlite:///%s' % (dbname)
    options = {
        'poolclass': SingletonThreadPool,
        'connect_args': {'check_same_thread': False},
    }
    return create_engine(url, **options)

34 

def isAMCSD(dbname):
    """whether a file is a valid AMCSD database

    Args:
        dbname (string): name of AMCSD database file

    Returns:
        bool: is file a valid AMCSD database

    Notes:
        1. must be a sqlite db file, with tables
           'cif', 'elements', 'spacegroups'
        2. any error while opening or reflecting the file is reported
           as False, not raised
    """
    _tables = ('cif', 'elements', 'spacegroups')
    engine = None
    try:
        engine = make_engine(dbname)
        meta = MetaData()
        meta.reflect(bind=engine)
        return all(t in meta.tables for t in _tables)
    except Exception:
        # best-effort probe: an unreadable / non-sqlite file is simply
        # "not an AMCSD database".  KeyboardInterrupt and SystemExit are
        # deliberately not swallowed.
        return False
    finally:
        # release the connection opened only for this check
        if engine is not None:
            engine.dispose()

58 

59 

farray_scale = 4.e6


def encode_farray(dat):
    """Encode a list of fractional-coordinate strings as one base64 string.

    Every value is multiplied by ``farray_scale``, rounded, stored as int32,
    and the raw bytes are base64-encoded for saving to the db.  The result
    is meant to be decoded by decode_farray().  Precision is preserved to
    slightly better than 6 digits.

    Args:
        dat (iterable of str): values, nominally strictly on (-1, 1).
            '?' and '.' are stored as the out-of-bounds markers 2 and 3.
            A parenthesised suffix such as the '(3)' in '0.25(3)' is
            dropped.  Text that still cannot be read as a float is
            stored as 0.

    Returns:
        str: ASCII base64 text of the int32 array ('' for empty input).
    """
    work = []
    for d in dat:
        if d == '?':
            work.append(2.0)   # out-of-bounds as '?'
        elif d == '.':
            work.append(3.0)   # out-of-bounds as '.'
        else:
            # strip anything from the first parenthesis on; check for
            # either kind so a stray ')' is handled like '('
            if '(' in d or ')' in d:
                d = d.replace(')', ' : ').replace('(', ' : ')
                d = d.split(':')[0].strip()
            try:
                fval = float(d)
            except ValueError:
                fval = 0.0
            work.append(fval)
    x = (farray_scale*np.array(work, dtype=np.float64)).round()
    return b64encode(x.astype(np.int32).tobytes()).decode('ascii')

84 

def decode_farray(dat):
    """Decode a string encoded by encode_farray().

    The base64 text is read back as int32 values and divided by the
    module-level ``farray_scale``.

    Args:
        dat (str or bytes): base64 text produced by encode_farray()

    Returns:
        list of str: one entry per stored value; the markers 2 and 3 come
        back as '?' and '.', every other value as fixed-point text with
        6 decimals.
    """
    raw = b64decode(dat)
    if not raw:
        return []
    # np.frombuffer replaces the deprecated binary mode of np.fromstring;
    # the division makes a new float array, so the read-only view is fine
    arr = np.frombuffer(raw, dtype=np.int32)/farray_scale
    out = []
    for a in arr:
        if abs(a-2.0) < 1.e-5:
            out.append('?')
        elif abs(a-3.0) < 1.e-5:
            out.append('.')
        else:
            out.append(f"{a:f}")
    return out

99 

def put_optarray(dat, attr):
    """Fetch optional entry `attr` from `dat`, encoded for storage.

    A missing entry, or one equal to '0', is returned as '0'; anything else
    is passed through encode_farray().
    """
    value = dat.get(attr, '0')
    return encode_farray(value) if value != '0' else value

105 

def get_optarray(dat):
    """Inverse of put_optarray() for a stored optional value.

    The placeholder values 0 and '0' are handed back untouched; any other
    value is decoded with decode_farray().
    """
    if dat in (0, '0'):
        return dat
    return decode_farray(dat)

110 

111 

# SQL statements that create the tables of the database.  create_amcsd()
# executes them one at a time, in the order listed here.  isAMCSD() treats a
# file as valid when it contains the 'cif', 'elements' and 'spacegroups'
# tables.
schema = (
    '''CREATE TABLE version (id integer primary key, tag text, date text, notes text);''',
    # NOTE(review): columns are declared as (id, z, name, symbol), while
    # create_amcsd() inserts (i, atz, sym, name) positionally -- confirm the
    # symbol/name order is what is intended.
    '''CREATE TABLE elements (
        id integer not null,
        z INTEGER NOT NULL,
        name VARCHAR(40),
        symbol VARCHAR(2) NOT NULL primary key);''',

    '''CREATE TABLE spacegroups (
        id INTEGER primary key,
        hm_notation VARCHAR(16) not null unique,
        symmetry_xyz text NOT NULL,
        category text );''',

    '''CREATE TABLE minerals (
        id INTEGER not null primary key,
        name text not null unique);''',

    '''CREATE TABLE authors (
        id INTEGER NOT NULL primary key,
        name text unique);''',
    '''CREATE TABLE publications (
        id INTEGER NOT NULL primary key,
        journalname text not null,
        volume text,
        year integer not null,
        page_first text,
        page_last text);''',

    # link table: pairs a publications.id with an authors.id
    '''CREATE TABLE publication_authors (
        publication_id INTEGER not null,
        author_id integer not null,
        FOREIGN KEY(publication_id) REFERENCES publications (id),
        FOREIGN KEY(author_id) REFERENCES authors (id));''',

    # one row per structure; refers to spacegroups, minerals and publications
    # by id.  The atoms_* columns are text; presumably they hold
    # encode_farray() output -- TODO confirm against the code that fills them.
    '''CREATE TABLE cif (
        id integer not null primary key,
        mineral_id INTEGER,
        spacegroup_id INTEGER,
        publication_id INTEGER,
        formula text,
        compound text,
        pub_title text,
        formula_title text,
        a text,
        b text,
        c text,
        alpha text,
        beta text,
        gamma text,
        cell_volume text,
        crystal_density text,
        atoms_sites text,
        atoms_x text,
        atoms_y text,
        atoms_z text,
        atoms_occupancy text,
        atoms_u_iso text,
        atoms_aniso_label text,
        atoms_aniso_u11 text,
        atoms_aniso_u22 text,
        atoms_aniso_u33 text,
        atoms_aniso_u12 text,
        atoms_aniso_u13 text,
        atoms_aniso_u23 text,
        qdat text,
        amcsd_url text,
        FOREIGN KEY(spacegroup_id) REFERENCES spacegroups (id),
        FOREIGN KEY(mineral_id) REFERENCES minerals (id),
        FOREIGN KEY(publication_id) REFERENCES publications (id));''',

    # NOTE(review): cif_id is declared text here, while cif.id is integer,
    # and no foreign key is declared -- confirm this is intended.
    '''CREATE TABLE cif_elements (
        cif_id text not null,
        element VARCHAR(2) not null);''',
    )

187 

188 

189def create_amcsd(dbname='test.db'): 

190 if os.path.exists(dbname): 

191 os.unlink(dbname) 

192 

193 conn = sqlite3.connect(dbname) 

194 cursor = conn.cursor() 

195 for s in schema: 

196 cursor.execute(s) 

197 

198 cursor.execute('insert into version values (?,?,?,?)', 

199 ('0', 'in progress', 'today', 'in progress')) 

200 

201 atz, i = 0, 0 

202 for sym, name in zip(ATOM_SYMS, ATOM_NAMES): 

203 i += 1 

204 atz += 1 

205 if sym == 'D': 

206 atz = 1 

207 cursor.execute('insert into elements values (?,?,?,?)', (i, atz, sym, name))