Coverage for sparc/sparc_parsers/atoms.py: 97%
195 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-12 01:13 +0000
1"""Convert ase atoms to structured dict following SPARC format
2and vice versa
3"""
5from copy import deepcopy
6from warnings import warn
8import numpy as np
9from ase import Atom, Atoms
10from ase.constraints import FixAtoms, FixedLine, FixedPlane
11from ase.units import Bohr
13from .inpt import _inpt_cell_to_ase_cell
14from .ion import _ion_coord_to_ase_pos
15from .pseudopotential import find_pseudo_path
16from .utils import make_reverse_mapping
18# from .sparc_parsers.ion import read_ion, write_ion
def atoms_to_dict(
    atoms,
    sort=True,
    direct=False,
    wrap=False,
    ignore_constraints=False,
    psp_dir=None,
    pseudopotentials=None,
    comments="",
):
    """Given an ASE Atoms object, convert to SPARC ion and inpt data dict

    Arguments:
        atoms: ASE Atoms object to convert.
        sort: ``True`` computes a stable sort of the atoms by chemical symbol;
            a non-empty sequence of indices (list, tuple or ``np.ndarray``)
            re-uses an existing sorting; ``False`` or an empty sequence keeps
            the original atom order.
        direct: if True write fractional coordinates (``COORD_FRAC``),
            otherwise Cartesian coordinates divided by ``Bohr`` (``COORD``).
        wrap: forwarded as ``wrap=`` to ``get_scaled_positions`` /
            ``get_positions``.
        ignore_constraints: if True, ASE constraints are neither inspected
            nor written as ``RELAX`` blocks.
        psp_dir: search path for psp8 files.
        pseudopotentials: a mapping between symbol and psp file names,
            similar to QE, like ``{'Na': 'Na-pbe.psp8'}``. If the file name
            does not contain path information, use psp_dir / filename,
            otherwise use the file path. ``None`` means an empty mapping.
        comments: a string (split on newlines) or a list of comment lines.

    Returns:
        ``{"ion": {"atom_blocks", "comments", "sorting"},
        "inpt": {"params", "comments"}}``

    We don't do any env variable replacement for psp_dir, it should be handled
    by the explicit _write_ion_and_inpt() function.

    At this step, the copy_psp is not applied, since we don't yet know the
    location to write.
    """
    # Avoid a shared mutable default argument; None behaves like the old {}
    if pseudopotentials is None:
        pseudopotentials = {}

    # Step 1: should we sort the atoms?
    # sort = True re-calculates the sorting information
    # sort = sequence re-uses the sorting information
    if isinstance(sort, (list, tuple, np.ndarray)):
        # Explicit sorting; checked before truthiness because an ndarray has
        # no unambiguous truth value. An empty sequence means "do not sort".
        sort_ = np.array(sort) if len(sort) > 0 else []
    elif sort:
        sort_ = np.argsort(atoms.get_chemical_symbols(), kind="stable")
    else:
        sort_ = []

    if len(sort_) > 0:
        resort_ = make_reverse_mapping(sort_)
        # This is the sorted atoms object
        atoms = atoms[sort_]
    else:
        resort_ = []

    # Step 2: determine the counts of each element
    symbol_counts = count_symbols(atoms.get_chemical_symbols())
    write_spin = np.any(atoms.get_initial_magnetic_moments() != 0)
    has_charge = np.any(atoms.get_initial_charges() != 0)
    if has_charge:
        warn(
            "SPARC currently doesn't support changing total number of electrons! "
            "via nomimal charges. The initial charges in the structure will be ignored."
        )

    # Only convert constraints when they will actually be written; this
    # avoids spurious warnings/errors for constraints the caller ignores.
    if ignore_constraints:
        relax_mask = []
    else:
        relax_mask = relax_from_all_constraints(atoms.constraints, len(atoms))
    write_relax = len(relax_mask) > 0

    atom_blocks = []
    # Step 3: write one block per run of consecutive identical symbols
    for symbol, start, end in symbol_counts:
        block_dict = {}
        block_dict["ATOM_TYPE"] = symbol
        block_dict["N_TYPE_ATOM"] = end - start
        # TODO: write comment that psp file may not exist
        try:
            psp_file = find_pseudo_path(symbol, psp_dir, pseudopotentials)
            # TODO: add option to determine if psp file exists!
            block_dict["PSEUDO_POT"] = psp_file.resolve().as_posix()
        except Exception:
            # Deliberate best effort: fall back to a placeholder name
            warn(
                (
                    f"Failed to find pseudo potential file for symbol {symbol}. I will use a dummy file name"
                )
            )
            block_dict[
                "PSEUDO_POT"
            ] = f"{symbol}-dummy.psp8 # Please replace with real psp file name!"
        # TODO: atomic mass?
        p_atoms = atoms[start:end]
        if direct:
            pos = p_atoms.get_scaled_positions(wrap=wrap)
            block_dict["COORD_FRAC"] = pos
        else:
            # TODO: should we use default converter?
            pos = p_atoms.get_positions(wrap=wrap) / Bohr
            block_dict["COORD"] = pos
        if write_spin:
            # TODO: should we process atoms with already calculated magmoms?
            n_atom = len(p_atoms)
            block_dict["SPIN"] = p_atoms.get_initial_magnetic_moments().reshape(
                n_atom, -1
            )
        if write_relax:
            block_dict["RELAX"] = relax_mask[start:end]
        atom_blocks.append(block_dict)

    # Step 4: inpt part
    # TODO: what if atoms does not have cell?
    cell_au = atoms.cell / Bohr
    inpt_blocks = {"LATVEC": cell_au, "LATVEC_SCALE": [1.0, 1.0, 1.0]}

    # Step 5: retrieve boundary condition information
    inpt_blocks.update(atoms_bc_to_sparc(atoms))

    if not isinstance(comments, list):
        comments = comments.split("\n")
    ion_data = {
        "atom_blocks": atom_blocks,
        "comments": comments,
        "sorting": {"sort": sort_, "resort": resort_},
    }
    inpt_data = {"params": inpt_blocks, "comments": []}
    return {"ion": ion_data, "inpt": inpt_data}
def dict_to_atoms(data_dict):
    """Given a SPARC struct dict, construct the ASE atoms object

    Note: this method supports only 1 Atoms at a time

    Arguments:
        data_dict: dict with ``data_dict["ion"]["atom_blocks"]`` and
            ``data_dict["inpt"]["params"]``; optionally
            ``data_dict["ion"]["sorting"]`` holding a ``"resort"`` mapping.
            It is not modified (a deep copy is processed).

    Returns:
        ase.Atoms with cell, positions, initial magnetic moments,
        constraints (from RELAX) and boundary conditions (from BC).

    Raises:
        ValueError: if the resort mapping length differs from the number of
            atoms.
    """
    ase_cell = _inpt_cell_to_ase_cell(data_dict)
    new_data_dict = deepcopy(data_dict)
    # Presumably fills "_ase_positions" in each atom block of the copy
    # (that key is read below) — defined in .ion
    _ion_coord_to_ase_pos(new_data_dict, ase_cell)
    atoms = Atoms()
    atoms.cell = ase_cell
    relax_dict = {}

    atoms_count = 0
    for block in new_data_dict["ion"]["atom_blocks"]:
        element = block["ATOM_TYPE"]
        positions = block["_ase_positions"]
        if positions.ndim == 1:
            positions = positions.reshape(1, -1)
        # Consider moving spins to another function
        spins = block.get("SPIN", None)
        if spins is None:
            spins = np.zeros(len(positions))
        for pos, spin in zip(positions, spins):
            # TODO: What about charge?
            atoms.append(Atom(symbol=element, position=pos, magmom=spin))
        # RELAX may be a list or an array; normalize to shape (n, 3)
        relax = np.asarray(block.get("RELAX", np.array([]))).reshape((-1, 3))
        for i, r in enumerate(relax, start=atoms_count):
            relax_dict[i] = r
        atoms_count += len(positions)

    # The stored resort mapping may be absent, None or empty: use identity
    sorting = data_dict["ion"].get("sorting") or {}
    resort = sorting.get("resort")
    if resort is None or len(resort) == 0:
        resort = np.arange(len(atoms))

    if len(resort) != len(atoms):
        # TODO: new exception
        raise ValueError(
            "Length of resort mapping is different from the number of atoms!"
        )
    # Atoms read from .ion are in `sort` order; atoms[resort] restores the
    # original order, so file index i moves to sort[i] (assuming
    # make_reverse_mapping returns the inverse permutation).
    sort = make_reverse_mapping(resort)
    sorted_relax_dict = {sort[i]: r for i, r in relax_dict.items()}
    # TODO: should we store the sorting information in SparcBundle?
    atoms = atoms[resort]
    atoms.constraints = constraints_from_relax(sorted_relax_dict)

    # @2023.08.31 add support for PBC
    # BC may be stored as a single string ("P P D") or as a sequence
    params = new_data_dict["inpt"]["params"]
    sparc_bc = params.get("BC", "P P P")
    if isinstance(sparc_bc, str):
        sparc_bc = sparc_bc.split()
    twist_angle = float(params.get("TWIST_ANGLE", 0))
    modify_atoms_bc(atoms, sparc_bc, twist_angle)

    return atoms
def count_symbols(symbols):
    """Count runs of consecutive identical elements.

    Arguments:
        symbols: an iterable of chemical symbols (e.g. the output of
            ``atoms.get_chemical_symbols()``, or a plain string).

    Returns:
        list of ``(element, start, end)`` tuples with half-open ranges
        ``[start, end)``. An empty input gives an empty list.

    For example, "CHCHHO" --> [('C', 0, 1), ('H', 1, 2), ('C', 2, 3), ('H', 3, 5), ('O', 5, 6)]
    """
    from itertools import groupby

    counts = []
    start = 0
    # groupby yields one group per run of equal consecutive symbols
    for symbol, run in groupby(symbols):
        end = start + sum(1 for _ in run)
        counts.append((symbol, start, end))
        start = end
    return counts
def constraints_from_relax(relax_dict):
    """
    Convert the SPARC RELAX fields to ASE's constraints

    Arguments
        relax_dict: mapping ``{atom_index: relax_mask}`` where each mask is a
            3-element boolean-like vector, i.e. ``{0: [True, True, True],
            1: [True, False, False]}``; True means free along that axis.

    Returns
        a list of ASE constraints (empty if nothing is constrained).

    Supported ase constraints will be FixAtoms, FixedLine and FixedPlane.
    Atoms sharing the same mask are gathered: one FixAtoms for all fully
    fixed atoms, and one FixedLine / FixedPlane per atom otherwise.

    Note: ase>=3.22 will have FixedLine and FixedPlane accepting only 1 index at a time!

    The relax indices must be already sorted!
    """
    if len(relax_dict) == 0:
        return []

    # mask tuple (not all True) -> atom indices sharing it, in first-seen order
    groups = {}
    for atom_index, raw_mask in relax_dict.items():
        mask = tuple(np.ndarray.tolist(np.array(raw_mask).astype(bool)))
        if np.all(mask):
            # Fully free atom: no constraint needed
            continue
        groups.setdefault(mask, []).append(atom_index)

    constraints = []
    for mask, indices in groups.items():
        n_free = 3 - mask.count(False)
        if n_free == 0:
            # No free direction: fix the atoms entirely
            constraints.append(FixAtoms(indices=indices))
        elif n_free == 1:
            # One free direction: move along the line given by the mask
            constraints.extend(
                FixedLine(atom_index, np.array(mask).astype(int))
                for atom_index in indices
            )
        elif n_free == 2:
            # Two free directions: move in the plane normal to the fixed axis
            constraints.extend(
                FixedPlane(atom_index, (~np.array(mask)).astype(int))
                for atom_index in indices
            )
    return constraints
def relax_from_constraint(constraint):
    """Return ``{atom_index: relax_dimensions}`` for the given constraint.

    ``relax_dimensions`` is a list of 3 bools; True means the atom may move
    along that Cartesian axis (SPARC's RELAX convention).

    Supported constraints:
      - FixAtoms: all dimensions fixed.
      - FixedLine: free only along its direction, which must be parallel to a
        Cartesian axis (either sign, e.g. [0, 0, 1] or [0, 0, -1]).
      - FixedPlane: free within the plane, whose normal must be parallel to a
        Cartesian axis (either sign).

    Any other constraint, or a line/plane that is not axis-aligned, triggers
    a warning and yields an empty dict (the constraint is ignored).
    """
    # Class name works even for constraints that do not implement todict()
    type_name = type(constraint).__name__
    if isinstance(constraint, FixAtoms):
        dimensions = [False] * 3
        expected_free = 0
    elif isinstance(constraint, FixedLine):
        # Only supports orthogonal basis! Assumes ASE stores a normalized
        # direction (as the original `== 1` test did); abs() accepts both signs.
        dimensions = [bool(np.isclose(abs(d), 1.0)) for d in constraint.dir]
        expected_free = 1
    elif isinstance(constraint, FixedPlane):
        # `dir` is the plane normal: the only fixed axis
        dimensions = [not np.isclose(abs(d), 1.0) for d in constraint.dir]
        expected_free = 2
    else:
        warn(
            f"The constraint type {type_name} is not supported by"
            " SPARC's .ion format. This constraint will be"
            " ignored"
        )
        return {}
    if dimensions.count(True) != expected_free:
        warn(
            "SPARC's .ion filetype can only support freezing entire "
            f"dimensions (x,y,z). The {type_name} constraint will be ignored"
        )
        return {}
    return {i: dimensions for i in constraint.get_indices()}  # atom indices
def relax_from_all_constraints(constraints, natoms):
    """Convert ASE atom constraints to SPARC relaxed dimensions for the atoms.

    Arguments:
        constraints: sequence of ASE constraints (e.g. ``atoms.constraints``).
        natoms: total number of atoms.

    Returns:
        ``[]`` if there are no constraints; otherwise a list of ``natoms``
        independent 3-element rows of bools (True = free to relax along
        that axis).

    Raises:
        ValueError: if a constraint references an atom index >= natoms.
    """
    if len(constraints) == 0:
        return []

    # Assume relaxed in all dimensions for all atoms. Build one independent
    # row per atom: `[[...]] * natoms` would alias a single shared list.
    relax = [[True, True, True] for _ in range(natoms)]
    for c in constraints:
        for atom_index, rdims in relax_from_constraint(c).items():
            if atom_index >= natoms:
                raise ValueError(
                    (
                        "Number of total atoms smaller than the constraint indices!\n"
                        "Please check your input"
                    )
                )
            # There might be multiple constraints applied on one index,
            # always make it more constrained
            relax[atom_index] = list(np.logical_and(relax[atom_index], rdims))
    return relax
def modify_atoms_bc(atoms, sparc_bc, twist_angle=0):
    """Modify the atoms boundary conditions in-place from the bc information

    Arguments:
        atoms: object with ``pbc`` and ``info`` attributes (an ASE Atoms).
        sparc_bc: iterable of SPARC BC keywords from inpt (case-insensitive).
        twist_angle: helix twist angle from inpt; stored in ``atoms.info``
            under "twist_angle (rad/Bohr)" only when non-zero.

    Conversion rules:
        BC: P --> pbc=True
        BC: D, H, C --> pbc=False (H/C also emit a data-loss warning)

    The upper-cased keywords are stored in ``atoms.info["sparc_bc"]``.

    Raises:
        ValueError: for any keyword other than P, D, H or C.
    """
    pbc_flags = []
    for keyword in sparc_bc:
        key = keyword.upper()
        if key == "P":
            pbc_flags.append(True)
            continue
        if key == "D":
            pbc_flags.append(False)
            continue
        if key in ("C", "H"):
            warn(
                "Parsing SPARC's helix or cyclic boundary conditions"
                " into ASE atoms is only partially supported. "
                "Saving the atoms object into other format may cause "
                "data-loss of the SPARC-specific BC information."
            )
            # Do not confuse ase-gui, we'll manually handle the visualization
            pbc_flags.append(False)
            continue
        raise ValueError("Unknown BC keyword values!")

    atoms.info["sparc_bc"] = [keyword.upper() for keyword in sparc_bc]
    if twist_angle != 0:
        atoms.info["twist_angle (rad/Bohr)"] = twist_angle
    atoms.pbc = pbc_flags
def atoms_bc_to_sparc(atoms):
    """Use atoms' internal pbc and info to construct inpt blocks

    ``atoms.pbc`` gives the base conversion (True -> "P", False -> "D").
    If ``atoms.info["sparc_bc"]`` is present (as written by
    ``modify_atoms_bc``), its SPARC-specific keywords (e.g. "C"/"H", stored as
    non-periodic in ASE) are restored after checking consistency with
    ``atoms.pbc``. Stored keywords are compared case-insensitively.

    Returns:
        a dict containing 'BC' and, when ``atoms.info`` holds
        "twist_angle (rad/Bohr)", 'TWIST_ANGLE'

    Raises:
        ValueError: if atoms.pbc and atoms.info['sparc_bc'] disagree (including
            a different number of entries).
    """
    mismatch_msg = (
        "Boundary conditions stored in ASE "
        "atoms.pbc and atoms.info['sparc_bc'] "
        "are different!"
    )
    sparc_bc = ["P" if bc_ else "D" for bc_ in atoms.pbc]

    # If "sparc_bc" info is stored in the atoms object, convert again
    if "sparc_bc" in atoms.info:
        stored_sparc_bc = [str(bc_).upper() for bc_ in atoms.info["sparc_bc"]]
        # zip() would silently truncate a length mismatch
        if len(stored_sparc_bc) != len(sparc_bc):
            raise ValueError(mismatch_msg)
        converted_bc = []
        for bc1, bc2 in zip(sparc_bc, stored_sparc_bc):
            # We store helix and cyclic BCs as non-periodic in ase-atoms
            if ((bc1 == "D") and (bc2 != "P")) or ((bc1 == "P") and (bc2 == "P")):
                converted_bc.append(bc2)
            else:
                raise ValueError(mismatch_msg)
        sparc_bc = converted_bc
    block = {"BC": " ".join(sparc_bc)}
    # Must match the key modify_atoms_bc writes when reading SPARC input;
    # previously a different key was tested, so the angle was never written.
    twist_key = "twist_angle (rad/Bohr)"
    if twist_key in atoms.info:
        block["TWIST_ANGLE"] = atoms.info[twist_key]
    return block