Coverage for sparc/sparc_parsers/atoms.py: 97%

195 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-12 01:13 +0000

1"""Convert ase atoms to structured dict following SPARC format 

2and vice versa 

3""" 

4 

5from copy import deepcopy 

6from warnings import warn 

7 

8import numpy as np 

9from ase import Atom, Atoms 

10from ase.constraints import FixAtoms, FixedLine, FixedPlane 

11from ase.units import Bohr 

12 

13from .inpt import _inpt_cell_to_ase_cell 

14from .ion import _ion_coord_to_ase_pos 

15from .pseudopotential import find_pseudo_path 

16from .utils import make_reverse_mapping 

17 

18# from .sparc_parsers.ion import read_ion, write_ion 

19 

20 

def atoms_to_dict(
    atoms,
    sort=True,
    direct=False,
    wrap=False,
    ignore_constraints=False,
    psp_dir=None,
    pseudopotentials=None,
    comments="",
):
    """Given an ASE Atoms object, convert to SPARC ion and inpt data dict.

    Arguments:
        atoms: the ASE Atoms object to convert.
        sort: ``True`` re-calculates the sorting information (stable sort by
            chemical symbol); a list re-uses that list as the sorting
            information; a falsy value disables sorting.
        direct: if true, write fractional coordinates (``COORD_FRAC``),
            otherwise Cartesian coordinates divided by ``Bohr`` (``COORD``).
        wrap: forwarded to ``get_scaled_positions`` / ``get_positions``.
        ignore_constraints: if true, no ``RELAX`` field is written.
        psp_dir: search path for psp8 files.
        pseudopotentials: a mapping between symbol and psp file names,
            similar to QE like 'Na': 'Na-pbe.psp8'. If the file name does not
            contain path information, use psp_dir / filename, otherwise use
            the file path. ``None`` (the default) means an empty mapping.
        comments: a list of comment lines, or a single string that is split
            on newlines.

    Returns:
        ``{"ion": {...}, "inpt": {...}}`` where the ``ion`` part holds
        ``atom_blocks``, ``comments`` and ``sorting`` and the ``inpt`` part
        holds ``params`` (lattice vectors and boundary conditions) and an
        empty ``comments`` list.

    We don't do any env variable replacement for psp_dir, it should be
    handled by the explicit _write_ion_and_inpt() function.

    At this step, the copy_psp is not applied, since we don't yet know the
    location to write.
    """
    # A fresh dict per call: avoids sharing one mutable default object
    # between invocations.
    if pseudopotentials is None:
        pseudopotentials = {}

    # Step 1: if we should sort the atoms?
    # sort = True re-calculate the sorting information
    # sort = list re-uses the sorting information
    if sort:
        if isinstance(sort, list):
            sort_ = np.array(sort)
        else:
            sort_ = np.argsort(atoms.get_chemical_symbols(), kind="stable")
        resort_ = make_reverse_mapping(sort_)
        # This is the sorted atoms object
        atoms = atoms[sort_]
    else:
        sort_ = []
        resort_ = []

    # Step 2: determine the counts of each element
    symbol_counts = count_symbols(atoms.get_chemical_symbols())
    write_spin = np.any(atoms.get_initial_magnetic_moments() != 0)
    has_charge = np.any(atoms.get_initial_charges() != 0)
    if has_charge:
        warn(
            "SPARC currently doesn't support changing total number of electrons! "
            "via nomimal charges. The initial charges in the structure will be ignored."
        )

    relax_mask = relax_from_all_constraints(atoms.constraints, len(atoms))
    write_relax = (len(relax_mask) > 0) and (not ignore_constraints)

    atom_blocks = []
    # Step 3: write each block (one block per run of identical symbols)
    for symbol, start, end in symbol_counts:
        block_dict = {}
        block_dict["ATOM_TYPE"] = symbol
        block_dict["N_TYPE_ATOM"] = end - start
        # TODO: make pseudo finding work
        # TODO: write comment that psp file may not exist
        try:
            psp_file = find_pseudo_path(symbol, psp_dir, pseudopotentials)
            # TODO: add option to determine if psp file exists!
            block_dict["PSEUDO_POT"] = psp_file.resolve().as_posix()
        except Exception:
            # Deliberate best-effort: a missing pseudopotential must not
            # abort the conversion, a placeholder file name is written.
            warn(
                (
                    f"Failed to find pseudo potential file for symbol {symbol}. I will use a dummy file name"
                )
            )
            block_dict[
                "PSEUDO_POT"
            ] = f"{symbol}-dummy.psp8 # Please replace with real psp file name!"
        # TODO: atomic mass?
        p_atoms = atoms[start:end]
        if direct:
            pos = p_atoms.get_scaled_positions(wrap=wrap)
            block_dict["COORD_FRAC"] = pos
        else:
            # TODO: should we use default converter?
            pos = p_atoms.get_positions(wrap=wrap) / Bohr
            block_dict["COORD"] = pos
        if write_spin:
            # TODO: should we process atoms with already calculated magmoms?
            n_atom = len(p_atoms)
            # One row per atom, whatever the per-atom moment dimension is
            block_dict["SPIN"] = p_atoms.get_initial_magnetic_moments().reshape(
                n_atom, -1
            )
        if write_relax:
            block_dict["RELAX"] = relax_mask[start:end]
        atom_blocks.append(block_dict)

    # Step 4: inpt part
    # TODO: what if atoms does not have cell?
    cell_au = atoms.cell / Bohr
    inpt_blocks = {"LATVEC": cell_au, "LATVEC_SCALE": [1.0, 1.0, 1.0]}

    # Step 5: retrieve boundary condition information
    # TODO: have to use space to join the single keywords
    inpt_blocks.update(atoms_bc_to_sparc(atoms))

    if not isinstance(comments, list):
        comments = comments.split("\n")
    ion_data = {
        "atom_blocks": atom_blocks,
        "comments": comments,
        "sorting": {"sort": sort_, "resort": resort_},
    }
    inpt_data = {"params": inpt_blocks, "comments": []}
    return {"ion": ion_data, "inpt": inpt_data}

134 

135 

def dict_to_atoms(data_dict):
    """Given a SPARC struct dict, construct the ASE atoms object.

    Arguments:
        data_dict: dict with an ``"ion"`` part (``atom_blocks`` and an
            optional ``sorting`` entry) and an ``"inpt"`` part (``params``).
            The input is not modified; a deep copy is used for the
            coordinate conversion.

    Returns:
        an ASE ``Atoms`` object with cell, positions, initial magnetic
        moments, constraints and boundary conditions set.

    Raises:
        ValueError: if the ``resort`` mapping length differs from the number
            of atoms.

    Note: this method supports only 1 Atoms at a time
    """
    ase_cell = _inpt_cell_to_ase_cell(data_dict)
    new_data_dict = deepcopy(data_dict)
    _ion_coord_to_ase_pos(new_data_dict, ase_cell)
    # Now the real thing to construct an atom object
    atoms = Atoms()
    atoms.cell = ase_cell
    # Maps atom index (in .ion file order) -> relax mask of that atom
    relax_dict = {}

    atoms_count = 0
    for block in new_data_dict["ion"]["atom_blocks"]:
        element = block["ATOM_TYPE"]
        positions = block["_ase_positions"]
        if positions.ndim == 1:
            # A single atom may be stored as a flat vector
            positions = positions.reshape(1, -1)
        # Consider moving spins to another function
        spins = block.get("SPIN", None)
        if spins is None:
            spins = np.zeros(len(positions))
        for pos, spin in zip(positions, spins):
            # TODO: What about charge?
            atoms.append(Atom(symbol=element, position=pos, magmom=spin))
        # RELAX may be an ndarray or a plain (nested) list; coerce it before
        # reshaping into a 2d array with one row per atom.
        relax = np.asarray(block.get("RELAX", [])).reshape((-1, 3))
        for i, r in enumerate(relax, start=atoms_count):
            relax_dict[i] = r
        atoms_count += len(positions)

    natoms = len(atoms)
    sorting = data_dict["ion"].get("sorting") or {}
    resort = sorting.get("resort")
    # Resort may be missing, None or empty: fall back to the identity mapping
    if resort is None or len(resort) == 0:
        resort = np.arange(natoms)

    if len(resort) != natoms:
        # TODO: new exception
        raise ValueError(
            "Length of resort mapping is different from the number of atoms!"
        )
    # The atom positions read from .ion correspond to the `sort` order and we
    # use `resort` to transform back. The relax masks are keyed by .ion
    # order, so their keys are translated with the inverse mapping.
    sort = make_reverse_mapping(resort)
    sorted_relax_dict = {sort[i]: r for i, r in relax_dict.items()}

    # TODO: should we store the sorting information in SparcBundle?
    atoms = atoms[resort]
    atoms.constraints = constraints_from_relax(sorted_relax_dict)

    # @2023.08.31 add support for PBC
    # TODO: move to a more modular function
    # TODO: Datatype for BC in the API, should it be string, or string array?
    params = new_data_dict["inpt"]["params"]
    sparc_bc = params.get("BC", "P P P").split()
    twist_angle = float(params.get("TWIST_ANGLE", 0))
    modify_atoms_bc(atoms, sparc_bc, twist_angle)

    return atoms

205 

206 

def count_symbols(symbols):
    """Count the number of consecutive elements.

    Arguments:
        symbols: an indexable sequence (string or list) of symbols.

    Returns:
        a list of ``(element, start, end)`` tuples, one per run of identical
        consecutive symbols, where ``start`` is inclusive and ``end`` is
        exclusive. An empty input gives an empty list.

    For example, "CHCHHO" --> [('C', 0, 1), ('H', 1, 2), ('C', 2, 3), ('H', 3, 5), ('O', 5, 6)]
    """
    counts = []
    total = len(symbols)
    run_start = 0
    # Walk one position past the end so the final run is closed by the same
    # code path as every other run.
    for i in range(1, total + 1):
        if i == total or symbols[i] != symbols[run_start]:
            counts.append((symbols[run_start], run_start, i))
            run_start = i
    return counts

225 

226 

def constraints_from_relax(relax_dict):
    """
    Convert the SPARC RELAX fields to ASE's constraints

    Arguments
    relax_dict: mapping of atom index to a length-3 relax mask, i.e.
                {0: [True, True, True], 1: [True, False, False]}.
                Mask entries are cast to bool.

    Returns
    a list of ase constraints; empty if no atom has a frozen dimension.

    Supported ase constraints will be FixAtoms, FixedLine and FixedPlane.
    For FixAtoms, all indices sharing the fully-frozen mask are gathered
    into one constraint.

    Note: ase>=3.22 will have FixedLine and FixedPlane accepting only 1 index at a time!

    The relax vector must be already sorted!
    """
    if len(relax_dict) == 0:
        return []

    # gathered_indices is an intermediate dict that contains
    # key: relax mask (as a tuple of bools) if not all True
    # value: indices that share the same mask
    gathered_indices = {}
    for index, mask in relax_dict.items():
        mask = tuple(np.asarray(mask).astype(bool).tolist())
        if all(mask):
            # Fully relaxed atom: no constraint needed
            continue
        gathered_indices.setdefault(mask, []).append(index)

    cons_list = []
    for relax_type, indices in gathered_indices.items():
        degree_freedom = 3 - relax_type.count(False)

        # DegreeF == 0 --> fix atom
        if degree_freedom == 0:
            cons_list.append(FixAtoms(indices=indices))
        # DegreeF == 1 --> move along line, fix line.
        # The line direction is the single relaxed axis.
        elif degree_freedom == 1:
            for ind in indices:
                cons_list.append(FixedLine(ind, np.array(relax_type).astype(int)))
        # DegreeF == 2 --> move within plane, fix plane.
        # The plane normal is the single frozen axis.
        elif degree_freedom == 2:
            for ind in indices:
                cons_list.append(FixedPlane(ind, (~np.array(relax_type)).astype(int)))
    return cons_list

278 

279 

def relax_from_constraint(constraint):
    """returns dict of {atom_index: relax_dimensions} for the given constraint

    Arguments:
        constraint: an ASE constraint object. Only FixAtoms, FixedLine and
            FixedPlane are converted.

    Returns:
        a dict mapping each constrained atom index to a list of 3 bools
        (True = relaxed along that axis). An empty dict is returned, with a
        warning, when the constraint type is unsupported or its direction is
        not aligned with a coordinate axis.
    """
    type_name = constraint.todict()["name"]
    if isinstance(constraint, FixAtoms):
        dimensions = [False] * 3
        expected_free = 0
    elif isinstance(constraint, FixedLine):
        # Only supports orthogonal basis!
        # abs(): a line along -x is the same line as one along +x
        dimensions = [abs(d) == 1 for d in constraint.dir]
        expected_free = 1
    elif isinstance(constraint, FixedPlane):
        # abs(): a plane with normal -z is the same plane as with normal +z
        dimensions = [abs(d) != 1 for d in constraint.dir]
        expected_free = 2
    else:
        warn(
            f"The constraint type {type_name} is not supported by"
            " SPARC's .ion format. This constraint will be"
            " ignored"
        )
        return {}
    # Plain bools so that counting below does not depend on numpy scalar types
    dimensions = [bool(d) for d in dimensions]
    if dimensions.count(True) != expected_free:
        # The direction is not along a single axis
        warn(
            "SPARC's .ion filetype can only support freezing entire "
            f"dimensions (x,y,z). The {type_name} constraint will be ignored"
        )
        return {}
    # Each atom index gets its own copy of the mask
    return {i: list(dimensions) for i in constraint.get_indices()}

307 

308 

def relax_from_all_constraints(constraints, natoms):
    """converts ASE atom constraints to SPARC relaxed dimensions for the atoms

    Arguments:
        constraints: sequence of ASE constraint objects.
        natoms: total number of atoms.

    Returns:
        an empty list if there are no constraints, otherwise a list of
        ``natoms`` length-3 masks (True = relaxed along that axis).

    Raises:
        ValueError: if a constraint refers to an atom index >= natoms.
    """
    if len(constraints) == 0:
        return []

    # Assume relaxed in all dimensions for all atoms. Each atom gets its own
    # row object so that rows are never aliased to each other.
    relax = [[True, True, True] for _ in range(natoms)]
    for c in constraints:
        for atom_index, rdims in relax_from_constraint(c).items():
            if atom_index >= natoms:
                raise ValueError(
                    (
                        "Number of total atoms smaller than the constraint indices!\n"
                        "Please check your input"
                    )
                )
            # There might be multiple constraints applied on one index,
            # always make it more constrained
            relax[atom_index] = list(np.bitwise_and(relax[atom_index], rdims))
    return relax

330 

331 

def modify_atoms_bc(atoms, sparc_bc, twist_angle=0):
    """Modify the atoms boundary conditions in-place from the bc information

    Arguments:
        atoms: object whose ``pbc`` attribute and ``info`` dict are updated.
        sparc_bc: iterable of single BC keywords from inpt (case-insensitive).
        twist_angle: the helix twist angle in inpt; stored in
            ``atoms.info["twist_angle (rad/Bohr)"]`` only when non-zero.

    Raises:
        ValueError: if a keyword is not one of P, D, C, H. In that case
            ``atoms`` is left untouched.

    conversion rules:
    BC: P --> pbc=True
    BC: D, H, C --> pbc=False

    The upper-cased keywords are stored in ``atoms.info["sparc_bc"]``.
    """
    # SPARC keyword --> ASE periodicity flag. Helix / cyclic are stored as
    # non-periodic: do not confuse ase-gui, we'll manually handle the
    # visualization.
    periodicity = {"P": True, "D": False, "C": False, "H": False}

    keywords = [bc_.upper() for bc_ in sparc_bc]
    ase_bc = []
    for keyword in keywords:
        if keyword not in periodicity:
            raise ValueError("Unknown BC keyword values!")
        if keyword in ("C", "H"):
            warn(
                (
                    "Parsing SPARC's helix or cyclic boundary conditions"
                    " into ASE atoms is only partially supported. "
                    "Saving the atoms object into other format may cause "
                    "data-loss of the SPARC-specific BC information."
                )
            )
        ase_bc.append(periodicity[keyword])

    # Only mutate atoms once every keyword has been validated
    atoms.info["sparc_bc"] = keywords
    if twist_angle != 0:
        atoms.info["twist_angle (rad/Bohr)"] = twist_angle
    atoms.pbc = ase_bc

368 

369 

def atoms_bc_to_sparc(atoms):
    """Use atoms' internal pbc and info to construct inpt blocks

    Arguments:
        atoms: object providing ``pbc`` (iterable of truthy / falsy flags)
            and an ``info`` dict.

    Returns:
        a dict containing 'BC' and, when
        ``atoms.info["twist_angle (rad/Bohr)"]`` is present, 'TWIST_ANGLE'

    Raises:
        ValueError: if ``atoms.info["sparc_bc"]`` contradicts ``atoms.pbc``.
    """
    sparc_bc = ["P" if bc_ else "D" for bc_ in atoms.pbc]

    # If "sparc_bc" info is stored in the atoms object, convert again
    if "sparc_bc" in atoms.info:
        converted_bc = []
        for bc1, bc2 in zip(sparc_bc, atoms.info["sparc_bc"]):
            # We store helix and cyclic BCs as non-periodic in ase-atoms, so
            # a non-periodic axis may carry any stored keyword except "P",
            # while a periodic axis must be stored as "P".
            consistent = (bc1 == "D" and bc2 != "P") or (bc1 == "P" and bc2 == "P")
            if not consistent:
                raise ValueError(
                    "Boundary conditions stored in ASE "
                    "atoms.pbc and atoms.info['sparc_bc'] "
                    "are different!"
                )
            converted_bc.append(bc2)
        sparc_bc = converted_bc

    block = {"BC": " ".join(sparc_bc)}
    # Look up the same key that is read below (and that modify_atoms_bc
    # writes); checking a different key would either skip the twist angle or
    # fail with KeyError.
    twist_key = "twist_angle (rad/Bohr)"
    if twist_key in atoms.info:
        block["TWIST_ANGLE"] = atoms.info[twist_key]
    return block

397 return block