Coverage for sparc/sparc_parsers/atoms.py: 97%

198 statements  


1"""Convert ase atoms to structured dict following SPARC format 

2and vice versa 

3""" 

4 

5from copy import deepcopy 

6from warnings import warn 

7 

8import numpy as np 

9from ase import Atom, Atoms 

10from ase.constraints import FixAtoms, FixedLine, FixedPlane 

11from ase.units import Bohr 

12 

13from .inpt import _inpt_cell_to_ase_cell 

14from .ion import _ion_coord_to_ase_pos 

15from .pseudopotential import find_pseudo_path 

16from .utils import make_reverse_mapping 

17 

18# from .sparc_parsers.ion import read_ion, write_ion 

19 

20 

def atoms_to_dict(
    atoms,
    sort=True,
    direct=False,
    wrap=False,
    ignore_constraints=False,
    psp_dir=None,
    pseudopotentials={},
    comments="",
):
    """Given an ASE Atoms object, convert to SPARC ion and inpt data dict

    psp_dir: search path for psp8 files
    pseudopotentials: a mapping between symbol and psp file names,
        similar to QE, e.g. 'Na': 'Na-pbe.psp8'. If the file name does
        not contain path information, use psp_dir / filename; otherwise
        use the file path directly.

    We don't do any environment variable replacement for psp_dir; it
    should be handled by the explicit _write_ion_and_inpt() function.

    At this step, copy_psp is not applied, since we don't yet know the
    location to write to.
    """
    # Step 1: decide whether to sort the atoms
    # sort = True re-calculates the sorting information
    # sort = list re-uses the provided sorting information
    if sort:
        if isinstance(sort, list):
            sort_ = np.array(sort)
            resort_ = make_reverse_mapping(sort_)
        else:
            sort_ = np.argsort(atoms.get_chemical_symbols(), kind="stable")
            resort_ = make_reverse_mapping(sort_)
        # This is the sorted atoms object
        atoms = atoms[sort_]
    else:
        sort_ = []
        resort_ = []

    # Step 2: determine the counts of each element
    symbol_counts = count_symbols(atoms.get_chemical_symbols())
    write_spin = np.any(atoms.get_initial_magnetic_moments() != 0)
    has_charge = np.any(atoms.get_initial_charges() != 0)
    if has_charge:
        warn(
            "SPARC currently doesn't support changing the total number of "
            "electrons via nominal charges. The initial charges in the "
            "structure will be ignored."
        )

    relax_mask = relax_from_all_constraints(atoms.constraints, len(atoms))
    write_relax = (len(relax_mask) > 0) and (not ignore_constraints)

    atom_blocks = []
    # Step 3: write each block
    for symbol, start, end in symbol_counts:
        block_dict = {}
        block_dict["ATOM_TYPE"] = symbol
        block_dict["N_TYPE_ATOM"] = end - start
        # TODO: make pseudo finding work
        # TODO: write comment that psp file may not exist
        try:
            psp_file = find_pseudo_path(symbol, psp_dir, pseudopotentials)
            # TODO: add option to determine if psp file exists!
            block_dict["PSEUDO_POT"] = psp_file.resolve().as_posix()
        except Exception:
            warn(
                f"Failed to find pseudopotential file for symbol {symbol}. "
                "I will use a dummy file name."
            )
            block_dict["PSEUDO_POT"] = (
                f"{symbol}-dummy.psp8 # Please replace with real psp file name!"
            )
        # TODO: atomic mass?
        p_atoms = atoms[start:end]
        if direct:
            pos = p_atoms.get_scaled_positions(wrap=wrap)
            block_dict["COORD_FRAC"] = pos
        else:
            # TODO: should we use the default converter?
            pos = p_atoms.get_positions(wrap=wrap) / Bohr
            block_dict["COORD"] = pos
        if write_spin:
            # TODO: should we process atoms with already calculated magmoms?
            n_atom = len(p_atoms)
            block_dict["SPIN"] = p_atoms.get_initial_magnetic_moments().reshape(
                n_atom, -1
            )
        if write_relax:
            relax_this_block = relax_mask[start:end]
            block_dict["RELAX"] = relax_this_block
        # TODO: get write_relax
        atom_blocks.append(block_dict)

    # Step 4: inpt part
    # TODO: what if atoms does not have a cell?
    cell_au = atoms.cell / Bohr
    inpt_blocks = {"LATVEC": cell_au, "LATVEC_SCALE": [1.0, 1.0, 1.0]}

    # Step 5: retrieve boundary condition information
    # TODO: have to use space to join the single keywords
    inpt_blocks.update(atoms_bc_to_sparc(atoms))

    if not isinstance(comments, list):
        comments = comments.split("\n")
    ion_data = {
        "atom_blocks": atom_blocks,
        "comments": comments,
        "sorting": {"sort": sort_, "resort": resort_},
        "extra": {},
    }
    inpt_data = {"params": inpt_blocks, "comments": []}
    return {"ion": ion_data, "inpt": inpt_data}
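
# A minimal usage sketch (editor's illustration, not part of the upstream
# module). Assumes ASE is installed; with no psp_dir or pseudopotentials
# given, find_pseudo_path is expected to fail, so a dummy .psp8 name is
# written along with a warning:
#
#     from ase.build import bulk
#     data = atoms_to_dict(bulk("Cu"), direct=True)
#     data["ion"]["atom_blocks"][0]["ATOM_TYPE"]    # 'Cu'
#     data["inpt"]["params"]["BC"]                  # 'P P P'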


def dict_to_atoms(data_dict):
    """Given a SPARC struct dict, construct the ASE atoms object

    Note: this method supports only one Atoms object at a time
    """
    ase_cell = _inpt_cell_to_ase_cell(data_dict)
    # TODO: is the new_data_dict really needed?
    new_data_dict = deepcopy(data_dict)
    _ion_coord_to_ase_pos(new_data_dict, ase_cell)
    # Now the real work: construct the Atoms object
    atoms = Atoms()
    atoms.cell = ase_cell
    relax_dict = {}

    atoms_count = 0
    atom_blocks = new_data_dict["ion"]["atom_blocks"]
    for block in atom_blocks:
        element = block["ATOM_TYPE"]
        positions = block["_ase_positions"]
        if positions.ndim == 1:
            positions = positions.reshape(1, -1)
        # Consider moving spins to another function
        spins = block.get("SPIN", None)
        if spins is None:
            spins = np.zeros(len(positions))
        for pos, spin in zip(positions, spins):
            # TODO: What about charge?
            atoms.append(Atom(symbol=element, position=pos, magmom=spin))
        relax = block.get("RELAX", np.array([]))
        # Reshape relax into a 2D array
        relax = relax.reshape((-1, 3))
        for i, r in enumerate(relax, start=atoms_count):
            relax_dict[i] = r
        atoms_count += len(positions)

    if "sorting" in data_dict["ion"]:
        resort = data_dict["ion"]["sorting"].get("resort", np.arange(len(atoms)))
        # Resort may be empty
        if len(resort) == 0:
            resort = np.arange(len(atoms))
    else:
        resort = np.arange(len(atoms))

    if len(resort) != len(atoms):
        # TODO: new exception
        raise ValueError(
            "Length of resort mapping is different from the number of atoms!"
        )
    # TODO: check if this mapping is correct
    sort = make_reverse_mapping(resort)
    sorted_relax_dict = {sort[i]: r for i, r in relax_dict.items()}
    # Now we sort the atom indices. The atom positions read from
    # .ion correspond to `sort`, and we use `resort` to transform back

    # TODO: should we store the sorting information in SparcBundle?

    atoms = atoms[resort]
    constraints = constraints_from_relax(sorted_relax_dict)
    atoms.constraints = constraints

    # @2023.08.31 add support for PBC
    # TODO: move to a more modular function
    # TODO: Datatype for BC in the API, should it be a string or a string array?
    sparc_bc = new_data_dict["inpt"]["params"].get("BC", "P P P").split()
    twist_angle = float(new_data_dict["inpt"]["params"].get("TWIST_ANGLE", 0))
    modify_atoms_bc(atoms, sparc_bc, twist_angle)

    # @2025.06.04 HUBBARD parameters
    # The Hubbard parameters are only recorded in the atoms
    # for the last calculator results. Users should be careful
    # when reusing them
    hubbard_flag = new_data_dict["inpt"]["params"].get("HUBBARD_FLAG", 0)
    if hubbard_flag > 0:
        u_pairs = new_data_dict["ion"]["extra"].get("hubbard", {})
        # TODO: make sure it makes sense for gpaw-like setups;
        # we should keep whatever value SPARC uses when recording
        # the HUBBARD-U, and use eV-Angstrom only when using
        # the SPARC calculator instance
        # TODO: may need consistent naming for info
        atoms.info["hubbard_u (hartree)"] = u_pairs
    return atoms
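
# Round-trip sketch (editor's illustration): the output of atoms_to_dict()
# can be fed back through dict_to_atoms(), which restores the original atom
# order via the stored "sorting" information:
#
#     from ase.build import molecule
#     atoms = molecule("H2O")
#     atoms.cell = [8, 8, 8]
#     restored = dict_to_atoms(atoms_to_dict(atoms))
#     restored.get_chemical_symbols()    # ['O', 'H', 'H'], as in the input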


def count_symbols(symbols):
    """Count the number of consecutive elements.

    Each output tuple is (element, start, end).
    For example, "CHCHHO" -->
    [('C', 0, 1), ('H', 1, 2), ('C', 2, 3), ('H', 3, 5), ('O', 5, 6)]
    """
    counts = []
    current_count = 1
    current_symbol = symbols[0]
    for i, symbol in enumerate(symbols[1:], start=1):
        if symbol == current_symbol:
            current_count += 1
        else:
            counts.append((current_symbol, i - current_count, i))
            current_count = 1
            current_symbol = symbol
    end = len(symbols)
    counts.append((current_symbol, end - current_count, end))
    return counts
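
# For example (editor's note), on the symbol list of an already-sorted
# structure:
#
#     count_symbols(["H", "H", "O"])    # [('H', 0, 2), ('O', 2, 3)]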


def constraints_from_relax(relax_dict):
    """
    Convert the SPARC RELAX fields to ASE's constraints

    Arguments
        relax_dict: mapping of atom index --> relax mask, a bool vector of
            size 3, e.g. {0: [True, True, True], 1: [True, False, False]}

    Supported ase constraints will be FixAtoms, FixedLine and FixedPlane.
    For constraints in the same direction, all indices will be gathered.

    Note: ase>=3.22 will have FixedLine and FixedPlane accepting only 1 index at a time!

    The relax vector must be already sorted!
    """
    if len(relax_dict) == 0:
        return []

    cons_list = []
    # gathered_indices is an intermediate dict that contains
    #     key: relax mask, if not all True
    #     value: indices that share the same mask
    gathered_indices = {}

    for i, r in relax_dict.items():
        r = np.array(r)
        r = tuple(np.ndarray.tolist(r.astype(bool)))
        if np.all(r):
            continue

        if r not in gathered_indices:
            gathered_indices[r] = [i]
        else:
            gathered_indices[r].append(i)

    for relax_type, indices in gathered_indices.items():
        degree_freedom = 3 - relax_type.count(False)

        # DegreeF == 0 --> fix atom
        if degree_freedom == 0:
            cons_list.append(FixAtoms(indices=indices))
        # DegreeF == 1 --> move along line, fix the other two directions
        elif degree_freedom == 1:
            for ind in indices:
                cons_list.append(FixedLine(ind, np.array(relax_type).astype(int)))
        # DegreeF == 2 --> move within plane, fix its normal direction
        elif degree_freedom == 2:
            for ind in indices:
                cons_list.append(FixedPlane(ind, (~np.array(relax_type)).astype(int)))
    return cons_list
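
# Sketch (editor's illustration) of the mask --> constraint mapping:
#
#     constraints_from_relax({0: [0, 0, 0], 1: [0, 0, 1]})
#     # --> [FixAtoms(indices=[0]), FixedLine(1, [0, 0, 1])]
#
# Atom 0 is frozen completely; atom 1 may move only along z.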


def relax_from_constraint(constraint):
    """Returns a dict of {atom_index: relax_dimensions} for the given constraint"""
    type_name = constraint.todict()["name"]
    if isinstance(constraint, FixAtoms):
        dimensions = [False] * 3
        expected_free = 0
    elif isinstance(constraint, FixedLine):
        # Only supports orthogonal basis!
        dimensions = [d == 1 for d in constraint.dir]
        expected_free = 1
    elif isinstance(constraint, FixedPlane):
        dimensions = [d != 1 for d in constraint.dir]
        expected_free = 2
    else:
        warn(
            f"The constraint type {type_name} is not supported by"
            " SPARC's .ion format. This constraint will be"
            " ignored."
        )
        return {}
    if dimensions.count(True) != expected_free:
        warn(
            "SPARC's .ion filetype can only support freezing entire "
            f"dimensions (x,y,z). The {type_name} constraint will be ignored."
        )
        return {}
    return {i: dimensions for i in constraint.get_indices()}  # atom indices
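
# Sketch (editor's illustration): the inverse of constraints_from_relax()
# for a single constraint. Directions must align with x, y or z:
#
#     relax_from_constraint(FixAtoms(indices=[2]))
#     # --> {2: [False, False, False]}
#     relax_from_constraint(FixedLine(0, [0, 0, 1]))
#     # --> {0: [False, False, True]}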


def relax_from_all_constraints(constraints, natoms):
    """Converts ASE atom constraints to SPARC relaxed dimensions for the atoms"""
    if len(constraints) == 0:
        return []

    # Assume relaxed in all dimensions for all atoms (independent masks)
    relax = [[True, True, True] for _ in range(natoms)]
    for c in constraints:
        for atom_index, rdims in relax_from_constraint(c).items():
            if atom_index >= natoms:
                raise ValueError(
                    "Number of total atoms smaller than the constraint indices!\n"
                    "Please check your input"
                )
            # There might be multiple constraints applied on one index;
            # always make it more constrained
            relax[atom_index] = list(np.bitwise_and(relax[atom_index], rdims))
    return relax
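
# Sketch (editor's illustration): overlapping constraints are merged with
# a bitwise AND, so the most restrictive mask wins:
#
#     cons = [FixedPlane(0, [0, 0, 1]), FixedLine(0, [0, 0, 1])]
#     relax_from_all_constraints(cons, natoms=1)
#     # --> [[False, False, False]]  (x, y frozen by the line; z by the plane)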


def modify_atoms_bc(atoms, sparc_bc, twist_angle=0):
    """Modify the atoms boundary conditions in-place from the BC information

    sparc_bc is a keyword from inpt
    twist_angle is the helix twist angle in inpt

    Conversion rules:
        BC: P --> pbc=True
        BC: D, H, C --> pbc=False
    """
    ase_bc = []
    for bc_ in sparc_bc:
        if bc_.upper() in ["C", "H"]:
            warn(
                "Parsing SPARC's helix or cyclic boundary conditions"
                " into ASE atoms is only partially supported. "
                "Saving the atoms object into another format may cause "
                "loss of the SPARC-specific BC information."
            )
            # Do not confuse ase-gui; we'll manually handle the visualization
            pbc = False
        elif bc_.upper() == "D":
            pbc = False
        elif bc_.upper() == "P":
            pbc = True
        else:
            raise ValueError("Unknown BC keyword values!")
        ase_bc.append(pbc)
    atoms.info["sparc_bc"] = [bc_.upper() for bc_ in sparc_bc]
    if twist_angle != 0:
        atoms.info["twist_angle (rad/Bohr)"] = twist_angle
    atoms.pbc = ase_bc
    return
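
# Sketch (editor's illustration): a SPARC "BC: P P D" slab becomes periodic
# in x and y only, while helix/cyclic keywords are kept in atoms.info:
#
#     modify_atoms_bc(atoms, ["P", "P", "D"])
#     atoms.pbc                  # array([ True,  True, False])
#     atoms.info["sparc_bc"]     # ['P', 'P', 'D']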


def atoms_bc_to_sparc(atoms):
    """Use atoms' internal pbc and info to construct inpt blocks

    Returns:
        a dict containing 'BC' and optionally 'TWIST_ANGLE'
    """
    sparc_bc = ["P" if bc_ else "D" for bc_ in atoms.pbc]

    # If "sparc_bc" info is stored in the atoms object, convert again
    if "sparc_bc" in atoms.info.keys():
        converted_bc = []
        stored_sparc_bc = atoms.info["sparc_bc"]
        for bc1, bc2 in zip(sparc_bc, stored_sparc_bc):
            # We store helix and cyclic BCs as non-periodic in ase-atoms
            if ((bc1 == "D") and (bc2 != "P")) or ((bc1 == "P") and (bc2 == "P")):
                converted_bc.append(bc2)
            else:
                raise ValueError(
                    "Boundary conditions stored in ASE "
                    "atoms.pbc and atoms.info['sparc_bc'] "
                    "are different!"
                )
        sparc_bc = converted_bc
    block = {"BC": " ".join(sparc_bc)}
    # Use the same info key that modify_atoms_bc() writes
    if "twist_angle (rad/Bohr)" in atoms.info.keys():
        block["TWIST_ANGLE"] = atoms.info["twist_angle (rad/Bohr)"]
    return block
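
# Sketch (editor's illustration): the reverse direction of modify_atoms_bc().
# For a slab that is periodic in x and y only:
#
#     atoms.pbc = [True, True, False]
#     atoms_bc_to_sparc(atoms)    # {'BC': 'P P D'}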