Coverage for sparc/sparc_parsers/atoms.py: 97%
198 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-18 16:19 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-18 16:19 +0000
1"""Convert ase atoms to structured dict following SPARC format
2and vice versa
3"""
5from copy import deepcopy
6from warnings import warn
8import numpy as np
9from ase import Atom, Atoms
10from ase.constraints import FixAtoms, FixedLine, FixedPlane
11from ase.units import Bohr
13from .inpt import _inpt_cell_to_ase_cell
14from .ion import _ion_coord_to_ase_pos
15from .pseudopotential import find_pseudo_path
16from .utils import make_reverse_mapping
18# from .sparc_parsers.ion import read_ion, write_ion
def atoms_to_dict(
    atoms,
    sort=True,
    direct=False,
    wrap=False,
    ignore_constraints=False,
    psp_dir=None,
    pseudopotentials={},
    comments="",
):
    """Convert an ASE Atoms object into the SPARC ion / inpt data dict

    Arguments
    atoms: the ASE Atoms object to convert
    sort: True re-calculates a stable ordering by chemical symbol,
          a list is re-used as the sorting indices, and a falsy value
          keeps the atoms in their current order
    direct: if True write fractional coordinates (COORD_FRAC), otherwise
            Cartesian positions divided by Bohr (COORD)
    wrap: forwarded to the ASE position getters
    ignore_constraints: if True no RELAX field is written
    psp_dir: search path for psp8 files
    pseudopotentials: a mapping between symbol and psp file names,
                      similar to QE like 'Na': 'Na-pbe.psp8'. If the file
                      name does not contain path information, use
                      psp_dir / filename, otherwise use the file path.
    comments: a string (split on newlines) or a list of comment lines

    Returns
    dict with the keys "ion" (atom blocks, comments, sorting, extra)
    and "inpt" (params, comments)

    We don't do any env variable replacement for psp_dir, it should be
    handled by the explicit _write_ion_and_inpt() function

    At this step, the copy_psp is not applied, since we don't yet know
    the location to write
    """
    # Step 1: settle the ordering of the atoms
    if not sort:
        sort_ = []
        resort_ = []
    else:
        if isinstance(sort, list):
            # Re-use the sorting information handed in by the caller
            sort_ = np.array(sort)
        else:
            # Stable sort keeps the relative order inside each element
            sort_ = np.argsort(atoms.get_chemical_symbols(), kind="stable")
        resort_ = make_reverse_mapping(sort_)
        # From here on `atoms` is the sorted atoms object
        atoms = atoms[sort_]

    # Step 2: consecutive runs of each element and the optional fields
    symbol_counts = count_symbols(atoms.get_chemical_symbols())
    write_spin = np.any(atoms.get_initial_magnetic_moments() != 0)
    if np.any(atoms.get_initial_charges() != 0):
        warn(
            "SPARC currently doesn't support changing total number of electrons! "
            "via nomimal charges. The initial charges in the structure will be ignored."
        )

    relax_mask = relax_from_all_constraints(atoms.constraints, len(atoms))
    write_relax = (not ignore_constraints) and (len(relax_mask) > 0)

    # Step 3: one block per consecutive run of the same element
    atom_blocks = []
    for symbol, start, end in symbol_counts:
        # TODO: add option to determine if psp file exists!
        try:
            psp_file = find_pseudo_path(symbol, psp_dir, pseudopotentials)
            pseudo_pot = psp_file.resolve().as_posix()
        except Exception:
            # Best effort: fall back to a placeholder the user must replace
            warn(
                f"Failed to find pseudo potential file for symbol {symbol}. I will use a dummy file name"
            )
            pseudo_pot = (
                f"{symbol}-dummy.psp8 # Please replace with real psp file name!"
            )
        block_dict = {
            "ATOM_TYPE": symbol,
            "N_TYPE_ATOM": end - start,
            "PSEUDO_POT": pseudo_pot,
        }

        # TODO: atomic mass?
        p_atoms = atoms[start:end]
        if direct:
            block_dict["COORD_FRAC"] = p_atoms.get_scaled_positions(wrap=wrap)
        else:
            # TODO: should we use default converter?
            block_dict["COORD"] = p_atoms.get_positions(wrap=wrap) / Bohr

        if write_spin:
            # TODO: should we process atoms with already calculated magmoms?
            magmoms = p_atoms.get_initial_magnetic_moments()
            block_dict["SPIN"] = magmoms.reshape(len(p_atoms), -1)

        if write_relax:
            block_dict["RELAX"] = relax_mask[start:end]

        atom_blocks.append(block_dict)

    # Step 4: inpt part, lattice vectors divided by Bohr
    # TODO: what if atoms does not have cell?
    inpt_blocks = {
        "LATVEC": atoms.cell / Bohr,
        "LATVEC_SCALE": [1.0, 1.0, 1.0],
    }

    # Step 5: boundary condition information
    inpt_blocks.update(atoms_bc_to_sparc(atoms))

    comment_lines = comments if isinstance(comments, list) else comments.split("\n")
    return {
        "ion": {
            "atom_blocks": atom_blocks,
            "comments": comment_lines,
            "sorting": {"sort": sort_, "resort": resort_},
            "extra": {},
        },
        "inpt": {"params": inpt_blocks, "comments": []},
    }
def dict_to_atoms(data_dict):
    """Given a SPARC struct dict, construct the ASE atoms object

    Arguments
    data_dict: dict with an "ion" section (containing "atom_blocks" and
               optionally "sorting" / "extra") and an "inpt" section
               (containing "params")

    Returns
    an ASE Atoms object with cell, positions, initial magnetic moments,
    constraints, boundary conditions and (if HUBBARD_FLAG > 0) the
    hubbard parameters stored in atoms.info

    Raises
    ValueError if the resort mapping length differs from the number of atoms

    Note: this method supports only 1 Atoms at a time
    """
    ase_cell = _inpt_cell_to_ase_cell(data_dict)
    # Work on a copy so that the caller's dict is not modified by the
    # position conversion below.
    # TODO: is the new_data_dict really needed?
    new_data_dict = deepcopy(data_dict)
    # Expected to provide the "_ase_positions" entry read from each block
    _ion_coord_to_ase_pos(new_data_dict, ase_cell)

    # Now the real thing to construct an atom object
    atoms = Atoms()
    atoms.cell = ase_cell
    relax_dict = {}

    atoms_count = 0
    for block in new_data_dict["ion"]["atom_blocks"]:
        element = block["ATOM_TYPE"]
        positions = block["_ase_positions"]
        if positions.ndim == 1:
            # A single atom may come as a flat vector
            positions = positions.reshape(1, -1)
        # Consider moving spins to another function
        spins = block.get("SPIN", None)
        if spins is None:
            spins = np.zeros(len(positions))
        for pos, spin in zip(positions, spins):
            # TODO: What about charge?
            atoms.append(Atom(symbol=element, position=pos, magmom=spin))
        # Reshape relax into 2d array, one row per atom of this block
        relax = block.get("RELAX", np.array([])).reshape((-1, 3))
        for i, r in enumerate(relax, start=atoms_count):
            relax_dict[i] = r
        atoms_count += len(positions)

    # The resort mapping may be missing, None or empty; all of these mean
    # "keep the order found in the .ion data".
    sorting = data_dict["ion"].get("sorting") or {}
    resort = sorting.get("resort", None)
    if resort is None or len(resort) == 0:
        resort = np.arange(len(atoms))

    if len(resort) != len(atoms):
        # TODO: new exception
        raise ValueError(
            "Length of resort mapping is different from the number of atoms!"
        )
    # The relax entries are keyed by the index in the .ion ordering, so
    # they are moved to the index each atom has after `atoms[resort]`.
    # TODO: check if this mapping is correct
    sort = make_reverse_mapping(resort)
    sorted_relax_dict = {sort[i]: r for i, r in relax_dict.items()}
    # Now we do a sort on the atom indices. The atom positions read from
    # .ion correspond to the `sort` and we use `resort` to transform
    # TODO: should we store the sorting information in SparcBundle?
    atoms = atoms[resort]
    atoms.constraints = constraints_from_relax(sorted_relax_dict)

    # @2023.08.31 add support for PBC
    # TODO: move to a more modular function
    # TODO: Datatype for BC in the API, should it be string, or string array?
    params = new_data_dict["inpt"]["params"]
    sparc_bc = params.get("BC", "P P P").split()
    twist_angle = float(params.get("TWIST_ANGLE", 0))
    modify_atoms_bc(atoms, sparc_bc, twist_angle)

    # @2025.06.04 HUBBARD parameters
    # The hubbard parameters are only recorded in the atoms
    # for the last calculator results. Users should be careful
    # when reusing them
    hubbard_flag = params.get("HUBBARD_FLAG", 0)
    if hubbard_flag > 0:
        # A missing "extra" section is treated like an empty one
        extra = new_data_dict["ion"].get("extra") or {}
        u_pairs = extra.get("hubbard", {})
        # TODO: make sure it makes sense for gpaw-like setups
        # we should keep whatever value SPARC uses when recording
        # the HUBBARD-U, and use eV-Angstrom only when using
        # the SPARC calculator instance
        # TODO: may need consistent naming for info
        atoms.info["hubbard_u (hartree)"] = u_pairs
    return atoms
def count_symbols(symbols):
    """Count the number of consecutive elements.

    Arguments
    symbols: a sequence of symbols (list of str or a plain string)

    Returns
    list of tuples (element, start, end), one per run of identical
    consecutive symbols, where `end` is exclusive. An empty input gives
    an empty list.

    For example, "CHCHHO" --> [('C', 0, 1), ('H', 1, 2), ('C', 2, 3), ('H', 3, 5), ('O', 5, 6)]
    """
    counts = []
    # len() instead of truthiness so that array-like inputs also work
    if len(symbols) == 0:
        return counts

    run_start = 0
    current_symbol = symbols[0]
    for i, symbol in enumerate(symbols[1:], start=1):
        if symbol != current_symbol:
            # The run of `current_symbol` ended right before index i
            counts.append((current_symbol, run_start, i))
            run_start = i
            current_symbol = symbol
    # Close the last run
    counts.append((current_symbol, run_start, len(symbols)))
    return counts
def constraints_from_relax(relax_dict):
    """
    Convert the SPARC RELAX fields to ASE's constraints

    Arguments
    relax_dict: mapping {atom index: relax mask of 3 entries}, e.g.
                {0: [True, True, True], 1: [True, False, False]}

    Returns
    list of ASE constraints. Supported ase constraints will be FixAtoms,
    FixedLine and FixedPlane. Atoms that are free in all directions do
    not generate a constraint.

    Fully fixed atoms that share the same mask are gathered into one
    FixAtoms. FixedLine and FixedPlane are created for one index at a
    time, since ase>=3.22 accepts only 1 index for these.

    The relax vector must be already sorted!
    """
    if not len(relax_dict):
        return []

    # Group the atom indices by their (hashable) boolean mask; masks that
    # are True in every direction are skipped.
    indices_by_mask = {}
    for atom_index, raw_mask in relax_dict.items():
        mask = tuple(np.array(raw_mask).astype(bool).tolist())
        if np.all(mask):
            continue
        indices_by_mask.setdefault(mask, []).append(atom_index)

    cons_list = []
    for mask, atom_indices in indices_by_mask.items():
        n_free = 3 - mask.count(False)
        if n_free == 0:
            # No free direction --> fix the atoms
            cons_list.append(FixAtoms(indices=atom_indices))
        elif n_free == 1:
            # One free direction --> the atom moves along a line
            cons_list.extend(
                FixedLine(ind, np.array(mask).astype(int)) for ind in atom_indices
            )
        elif n_free == 2:
            # Two free directions --> the atom moves in a plane whose
            # normal is the frozen direction
            cons_list.extend(
                FixedPlane(ind, (~np.array(mask)).astype(int))
                for ind in atom_indices
            )
    return cons_list
def relax_from_constraint(constraint):
    """returns dict of {atom_index: relax_dimensions} for the given constraint

    Arguments
    constraint: an ASE constraint object

    Returns
    dict mapping each constrained atom index to a list of 3 bools, where
    True means the atom may move along that Cartesian direction. An empty
    dict is returned (with a warning) if the constraint type is not
    FixAtoms / FixedLine / FixedPlane, or if its direction is not along
    one of the x, y, z axes.
    """
    # The direction stored on the constraint is compared against the unit
    # axes with a small tolerance, and the sign is ignored: a line along
    # -z is the same line as one along +z, same for a plane normal.
    axis_tol = 1e-12

    def _along_axis(component):
        return abs(abs(component) - 1.0) < axis_tol

    type_name = constraint.todict()["name"]
    if isinstance(constraint, FixAtoms):
        dimensions = [False] * 3
        expected_free = 0
    elif isinstance(constraint, FixedLine):
        # Only supports orthogonal basis!
        # The only free dimension is the one the line points along
        dimensions = [_along_axis(d) for d in constraint.dir]
        expected_free = 1
    elif isinstance(constraint, FixedPlane):
        # The only frozen dimension is the plane normal
        dimensions = [not _along_axis(d) for d in constraint.dir]
        expected_free = 2
    else:
        warn(
            f"The constraint type {type_name} is not supported by"
            " SPARC's .ion format. This constraint will be"
            " ignored"
        )
        return {}
    # A direction that is not axis-aligned gives the wrong number of free
    # dimensions and cannot be expressed as a RELAX mask
    if dimensions.count(True) != expected_free:
        warn(
            "SPARC's .ion filetype can only support freezing entire "
            f"dimensions (x,y,z). The {type_name} constraint will be ignored"
        )
        return {}
    return {i: dimensions for i in constraint.get_indices()}  # atom indices
def relax_from_all_constraints(constraints, natoms):
    """converts ASE atom constraints to SPARC relaxed dimensions for the atoms

    Arguments
    constraints: sequence of ASE constraints
    natoms: total number of atoms

    Returns
    an empty list if there are no constraints, otherwise a list with one
    3-entry mask per atom, where True means the atom is relaxed along
    that dimension

    Raises
    ValueError if a constraint refers to an atom index >= natoms
    """
    if len(constraints) == 0:
        return []

    # Assume relaxed in all dimensions for all atoms. Every row is its own
    # list so that changing one row can never leak into another one.
    relax = [[True, True, True] for _ in range(natoms)]
    for c in constraints:
        for atom_index, rdims in relax_from_constraint(c).items():
            if atom_index >= natoms:
                raise ValueError(
                    (
                        "Number of total atoms smaller than the constraint indices!\n"
                        "Please check your input"
                    )
                )
            # There might be multiple constraints applied on one index,
            # always make it more constrained
            relax[atom_index] = list(np.bitwise_and(relax[atom_index], rdims))
    return relax
def modify_atoms_bc(atoms, sparc_bc, twist_angle=0):
    """Set the boundary conditions of atoms in-place from the SPARC BC keywords

    Arguments
    atoms: object whose `pbc` and `info` are updated
    sparc_bc: iterable of BC keywords from inpt (case-insensitive)
    twist_angle: the helix twist angle in inpt, stored in
                 atoms.info["twist_angle (rad/Bohr)"] only when non-zero

    conversion rules:
    BC: P --> pbc=True
    BC: D, H, C --> pbc=False (H and C additionally emit a warning)

    The upper-cased keywords are kept in atoms.info["sparc_bc"].

    Raises
    ValueError for any other keyword
    """
    partial_support_msg = (
        "Parsing SPARC's helix or cyclic boundary conditions"
        " into ASE atoms is only partially supported. "
        "Saving the atoms object into other format may cause "
        "data-loss of the SPARC-specific BC information."
    )
    pbc_flags = []
    for keyword in sparc_bc:
        kind = keyword.upper()
        if kind == "P":
            periodic = True
        elif kind == "D":
            periodic = False
        elif kind in ("C", "H"):
            warn(partial_support_msg)
            # Do not confuse ase-gui, we'll manually handle the visualization
            periodic = False
        else:
            raise ValueError("Unknown BC keyword values!")
        pbc_flags.append(periodic)

    atoms.info["sparc_bc"] = [keyword.upper() for keyword in sparc_bc]
    if twist_angle != 0:
        atoms.info["twist_angle (rad/Bohr)"] = twist_angle
    atoms.pbc = pbc_flags
def atoms_bc_to_sparc(atoms):
    """Use atoms' internal pbc and info to construct inpt blocks

    Arguments
    atoms: object providing `pbc` and `info`

    Returns:
        a dict containing 'BC' and, if atoms.info has the entry
        "twist_angle (rad/Bohr)", also 'TWIST_ANGLE'

    Raises
    ValueError if atoms.info["sparc_bc"] disagrees with atoms.pbc
    """
    sparc_bc = ["P" if bc_ else "D" for bc_ in atoms.pbc]

    # If "sparc_bc" info is stored in the atoms object, convert again
    if "sparc_bc" in atoms.info:
        converted_bc = []
        for bc1, bc2 in zip(sparc_bc, atoms.info["sparc_bc"]):
            # We store helix and cyclic BCs as non-periodic in ase-atoms,
            # so a non-periodic direction may carry any keyword except P,
            # while a periodic direction must be stored as P
            if bc1 == "P":
                consistent = bc2 == "P"
            else:
                consistent = bc2 != "P"
            if not consistent:
                raise ValueError(
                    "Boundary conditions stored in ASE "
                    "atoms.pbc and atoms.info['sparc_bc'] "
                    "are different!"
                )
            converted_bc.append(bc2)
        sparc_bc = converted_bc

    block = {"BC": " ".join(sparc_bc)}
    # The key tested here must be the very same key that is read, which
    # is also the key under which modify_atoms_bc records the twist angle
    twist_key = "twist_angle (rad/Bohr)"
    if twist_key in atoms.info:
        block["TWIST_ANGLE"] = atoms.info[twist_key]
    return block