Coverage for sparc/calculator.py: 62%
687 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-18 16:19 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-18 16:19 +0000
1import datetime
2import os
3import signal
4import subprocess
5import tempfile
6from pathlib import Path
7from warnings import warn, warn_explicit
9import numpy as np
10import psutil
11from ase.atoms import Atoms
12from ase.calculators.calculator import Calculator, FileIOCalculator, all_changes
14# 2024-11-28: @alchem0x2a add support for ase.config
15# In the first we only use cfg as parser for configurations
16from ase.config import cfg as _cfg
17from ase.parallel import world
18from ase.stress import full_3x3_to_voigt_6_stress
19from ase.units import Bohr, GPa, Hartree, eV
20from ase.utils import IOContext
22from .api import SparcAPI
23from .io import SparcBundle
24from .socketio import (
25 SPARCProtocol,
26 SPARCSocketClient,
27 SPARCSocketServer,
28 generate_random_socket_name,
29)
30from .utils import (
31 _find_default_sparc,
32 _find_mpi_process,
33 _get_slurm_jobid,
34 _locate_slurm_step,
35 _slurm_signal,
36 compare_dict,
37 deprecated,
38 h2gpts,
39 locate_api,
40 monitor_process,
41 parse_hubbard_string,
42 time_limit,
43)
# Below are a list of ASE-compatible calculator input parameters that are
# in Angstrom/eV units
# Ideas are taken from GPAW calculator
# These keys are intercepted by the SPARC calculator and converted to
# native SPARC (Bohr/Hartree) input keywords instead of being passed through.
sparc_python_inputs = [
    "xc",
    "h",
    "kpts",
    "convergence",
    "gpts",
    "nbands",
    "setups",
]
# The socket mode in SPARC calculator uses a relay-based mechanism
# Several scenarios:
# 1) use_socket = False --> Turn off all socket communications. SPARC runs from cold-start
# 2) use_socket = True, port < 0 --> Only connect the sparc binary using ephemeral unix socket. Interface appears as if it is a normal calculator
# 3) use_socket = True, port > 0 --> Use an out-going socket to relay information
# 4) use_socket = True, server_only = True --> Act as a SocketServer
# We do not support outgoing unix socket because the limited user cases
default_socket_params = {
    "use_socket": False,  # Main switch to use socket or not
    "host": "localhost",  # Name of the socket host (only outgoing)
    "port": -1,  # Port number of the outgoing socket
    "allow_restart": True,  # If True, allow the socket server to restart
    "server_only": False,  # Start the calculator as a server
}
class SPARC(FileIOCalculator, IOContext):
    """Calculator interface to the SPARC codes via the FileIOCalculator"""

    # Properties that read_results()/calculate() can populate
    implemented_properties = ["energy", "forces", "fermi", "stress"]
    name = "sparc"
    ase_objtype = "sparc_calculator"  # For JSON storage
    # ASE-unit (Angstrom/eV) keywords intercepted before validation
    special_inputs = sparc_python_inputs
    default_params = {
        "xc": "pbe",
        "kpts": (1, 1, 1),
        "h": 0.25,  # Angstrom equivalent to MESH_SPACING = 0.47
    }
    # TODO: ASE 3.23 compatibility. should use profile
    # TODO: remove the legacy command check for future releases
    _legacy_default_command = "sparc not initialized"
90 def __init__(
91 self,
92 restart=None,
93 directory=".",
94 *,
95 label=None,
96 atoms=None,
97 command=None,
98 psp_dir=None,
99 log="sparc.log",
100 sparc_json_file=None,
101 sparc_doc_path=None,
102 check_version=False,
103 keep_old_files=True,
104 use_socket=False,
105 socket_params={},
106 **kwargs,
107 ):
108 """
109 Initialize the SPARC calculator similar to FileIOCalculator. The validator uses the JSON API guessed
110 from sparc_json_file or sparc_doc_path.
112 Arguments:
113 restart (str or None): Path to the directory for restarting a calculation. If None, starts a new calculation.
114 directory (str or Path): Directory for SPARC calculation files.
115 label (str, optional): Custom label for identifying calculation files.
116 atoms (Atoms, optional): ASE Atoms object representing the system to be calculated.
117 command (str, optional): Command to execute SPARC. If None, it will be determined automatically.
118 psp_dir (str or Path, optional): Directory containing pseudopotentials.
119 log (str, optional): Name of the log file.
120 sparc_json_file (str, optional): Path to a JSON file with SPARC parameters.
121 sparc_doc_path (str, optional): Path to the SPARC doc LaTeX code for parsing parameters.
122 check_version (bool): Check if SPARC and document versions match
123 keep_old_files (bool): Whether older SPARC output files should be preserved.
124 If True, SPARC program will rewrite the output files
125 with suffix like .out_01, .out_02 etc
126 use_socket (bool): Main switch for the socket mode. Alias for socket_params["use_socket"]
127 socket_params (dict): Parameters to control the socket behavior. Please check default_socket_params
128 **kwargs: Additional keyword arguments to set up the calculator.
129 """
130 # 2024-11-28 @alchem0x2a added cfg as the default validator
131 self.validator = locate_api(
132 json_file=sparc_json_file, doc_path=sparc_doc_path, cfg=self.cfg
133 )
134 self.valid_params = {}
135 self.special_params = {}
136 self.inpt_state = {} # Store the inpt file states
137 self.system_state = {} # Store the system parameters (directory, bundle etc)
138 FileIOCalculator.__init__(
139 self,
140 restart=None,
141 label=None,
142 atoms=atoms,
143 command=command,
144 directory=directory,
145 **kwargs,
146 )
148 # sparc bundle will set the label. self.label will be available after the init
149 if label is None:
150 label = "SPARC" if restart is None else None
152 # Use psp dir from user input or env
153 self.sparc_bundle = SparcBundle(
154 directory=Path(self.directory),
155 mode="w",
156 atoms=self.atoms,
157 label=label, # The order is tricky here. Use label not self.label
158 psp_dir=psp_dir,
159 validator=self.validator,
160 cfg=self.cfg,
161 )
163 # Try restarting from an old calculation and set results
164 self._restart(restart=restart)
166 # self.log = self.directory / log if log is not None else None
167 self.log = log
168 self.keep_old_files = keep_old_files
169 if check_version:
170 self.sparc_version = self.detect_sparc_version()
171 else:
172 self.sparc_version = None
174 # Partially update the socket params, so that when setting use_socket = True,
175 # User can directly use the socket client
176 self.socket_params = default_socket_params.copy()
177 # Everything in argument socket_params will overwrite
178 self.socket_params.update(use_socket=use_socket)
179 self.socket_params.update(**socket_params)
181 # TODO: check parameter compatibility with socket params
182 self.process = None
183 # self.pid = None
185 # Initialize the socket settings
186 self.in_socket = None
187 self.out_socket = None
188 self.ensure_socket()
190 def _compare_system_state(self):
191 """Check if system parameters like command etc have changed
193 Returns:
194 bool: True if all parameters are the same otherwise False
195 """
196 old_state = self.system_state.copy()
197 new_state = self._dump_system_state()
198 for key, val in old_state.items():
199 new_val = new_state.pop(key, None)
200 if isinstance(new_val, dict):
201 if not compare_dict(val, new_val):
202 return False
203 else:
204 if not val == new_val:
205 return False
206 if new_state == {}:
207 return True
208 else:
209 return False
211 def _compare_calc_parameters(self, atoms, properties):
212 """Check if SPARC calculator parameters have changed
214 Returns:
215 bool: True if no change, otherwise False
216 """
217 _old_inpt_state = self.inpt_state.copy()
218 _new_inpt_state = self._generate_inpt_state(atoms, properties)
219 result = True
220 if set(_new_inpt_state.keys()) != set(_old_inpt_state.keys()):
221 result = False
222 else:
223 for key, old_val in _old_inpt_state.items():
224 new_val = _new_inpt_state[key]
225 # TODO: clean up bool
226 if isinstance(new_val, (str, bool)):
227 if new_val != old_val:
228 result = False
229 break
230 elif isinstance(new_val, (int, float)):
231 if not np.isclose(new_val, old_val):
232 result = False
233 break
234 elif isinstance(new_val, (list, np.ndarray)):
235 if not np.isclose(new_val, old_val).all():
236 result = False
237 break
238 return result
240 def _dump_system_state(self):
241 """Returns a dict with current system parameters
243 changing these parameters will cause the calculator to reload
244 especially in the use_socket = True case
245 """
246 system_state = {
247 "label": self.label,
248 "directory": self.directory,
249 "command": self.command,
250 "log": self.log,
251 "socket_params": self.socket_params,
252 }
253 return system_state
    def ensure_socket(self):
        """Create the inbound/outbound socket objects if socket mode is enabled.

        Also makes sure the calculation directory exists. The inbound socket
        is a server (either on a TCP port in "server" mode or an ephemeral
        unix socket otherwise); the outbound socket is a client used in
        "client" mode only.
        """
        # TODO: more ensure directory to other place?
        if not self.directory.is_dir():
            os.makedirs(self.directory, exist_ok=True)
        if not self.use_socket:
            return
        if self.in_socket is None:
            if self.socket_mode == "server":
                # TODO: Exception for wrong port
                self.in_socket = SPARCSocketServer(
                    port=self.socket_params["port"],
                    log=self.openfile(
                        file=self._indir(ext=".log", label="socket"),
                        comm=world,
                        mode="w",
                    ),
                    parent=self,
                )
            else:
                # Local/client mode: talk to the sparc binary over an
                # ephemeral unix socket with a randomized name
                socket_name = generate_random_socket_name()
                print(f"Creating a socket server with name {socket_name}")
                self.in_socket = SPARCSocketServer(
                    unixsocket=socket_name,
                    # TODO: make the log fd persistent
                    log=self.openfile(
                        file=self._indir(ext=".log", label="socket"),
                        comm=world,
                        mode="w",
                    ),
                    parent=self,
                )
        # TODO: add the outbound socket client
        # TODO: we may need to check an actual socket server at host:port?!
        # At this stage, we will need to wait the actual client to join
        if self.out_socket is None:
            if self.socket_mode == "client":
                self.out_socket = SPARCSocketClient(
                    host=self.socket_params["host"],
                    port=self.socket_params["port"],
                    # TODO: change later
                    log=self.openfile(file="out_socket.log", comm=world),
                    # TODO: add the log and timeout part
                    parent_calc=self,
                )
    def __enter__(self):
        """Reset upon entering the context.

        Clears cached results and closes any running SPARC/socket process so
        the context starts from a clean state.
        """
        IOContext.__enter__(self)
        self.reset()
        self.close()
        return self
    def __exit__(self, type, value, traceback):
        """Exiting the context manager and reset process

        Ensures sockets and the SPARC subprocess are shut down even when the
        context exits through an exception.
        """
        IOContext.__exit__(self, type, value, traceback)
        self.close()
        return
    @property
    def use_socket(self):
        # Main switch for socket mode; mirrors socket_params["use_socket"]
        return self.socket_params["use_socket"]
317 @property
318 def socket_mode(self):
319 """The mode of the socket calculator:
321 disabled: pure SPARC file IO interface
322 local: Serves as a local SPARC calculator with socket support
323 client: Relay SPARC calculation
324 server: Remote server
325 """
326 if self.use_socket:
327 if self.socket_params["port"] > 0:
328 if self.socket_params["server_only"]:
329 return "server"
330 else:
331 return "client"
332 else:
333 return "local"
334 else:
335 return "disabled"
    def _indir(self, ext, label=None, occur=0, d_format="{:02d}"):
        # Thin wrapper around SparcBundle._indir: build a file path inside
        # the calculation directory for the given extension/label/occurrence
        return self.sparc_bundle._indir(
            ext=ext, label=label, occur=occur, d_format=d_format
        )
    @property
    def log(self):
        # Full path of the log file; self._log stores only the file name
        # (see the setter), so it always resolves inside self.directory
        return self.directory / self._log
346 @log.setter
347 def log(self, log):
348 # Stripe the parent direcoty information
349 if log is not None:
350 self._log = Path(log).name
351 else:
352 self._log = "sparc.log"
353 return
355 @property
356 def in_socket_filename(self):
357 # The actual socket name for inbound socket
358 # Return name as /tmp/ipi_sparc_<hex>
359 if self.in_socket is None:
360 return ""
361 else:
362 return self.in_socket.socket_filename
364 @property
365 def directory(self):
366 if hasattr(self, "sparc_bundle"):
367 return Path(self.sparc_bundle.directory)
368 else:
369 return Path(self._directory)
371 @directory.setter
372 def directory(self, directory):
373 if hasattr(self, "sparc_bundle"):
374 self.sparc_bundle.directory = Path(directory)
375 else:
376 self._directory = Path(directory)
377 return
379 @property
380 def label(self):
381 """Rewrite the label from Calculator class, since we don't want to contain pathsep"""
382 if hasattr(self, "sparc_bundle"):
383 return self.sparc_bundle.label
384 else:
385 return getattr(self, "_label", None)
387 @label.setter
388 def label(self, label):
389 """Rewrite the label from Calculator class,
390 since we don't want to contain pathsep
391 """
392 label = str(label)
393 if hasattr(self, "sparc_bundle"):
394 self.sparc_bundle.label = self.sparc_bundle._make_label(label)
395 else:
396 self._label = label
398 @property
399 def sort(self):
400 """Like Vasp calculator
401 ASE atoms --> sort --> SPARC
402 """
403 if self.sparc_bundle.sorting is None:
404 return None
405 else:
406 return self.sparc_bundle.sorting["sort"]
408 @property
409 def resort(self):
410 """Like Vasp calculator
411 SPARC --> resort --> ASE atoms
412 """
413 if self.sparc_bundle.sorting is None:
414 return None
415 else:
416 return self.sparc_bundle.sorting["resort"]
    def check_state(self, atoms, tol=1e-8):
        """Updated check_state method.
        By default self.atoms (cached from output files) contains the initial_magmoms,
        so we add a zero magmoms to the atoms for comparison if it does not exist.

        reading a result from the .out file has only precision up to 10 digits
        """
        atoms_copy = atoms.copy()
        # Add zero magmoms so FileIOCalculator.check_state does not flag a
        # spurious "initial_magmoms" difference against the cached atoms
        if "initial_magmoms" not in atoms_copy.arrays:
            atoms_copy.set_initial_magnetic_moments(
                [
                    0,
                ]
                * len(atoms_copy)
            )
        system_changes = FileIOCalculator.check_state(self, atoms_copy, tol=tol)
        # A few hard-written rules. Wrapping should only affect the position
        if "positions" in system_changes:
            atoms_copy.wrap(eps=tol)
            new_system_changes = FileIOCalculator.check_state(self, atoms_copy, tol=tol)
            # If wrapping removes the difference, the positions only differed
            # by a lattice translation -- not a real change
            if "positions" not in new_system_changes:
                system_changes.remove("positions")

        # Also flag changes in system-level parameters (directory, command,
        # log, socket settings) recorded by _dump_system_state
        system_state_changed = not self._compare_system_state()
        if system_state_changed:
            system_changes.append("system_state")
        return system_changes
448 def _make_command(self, extras=""):
449 """Use $ASE_SPARC_COMMAND or self.command to determine the command
450 as a last resort, if `sparc` exists in the PATH, use that information
452 Extras will add additional arguments to the self.command,
453 e.g. -name, -socket etc
455 2024.09.05 @alchem0x2a
456 Note in ase>=3.23 the FileIOCalculator.command will fallback
457 to self._legacy_default_command, which we should set to invalid value for now.
459 2024.11.28 @alchem0x2a
460 Make use of the ase.config to set up the command
461 """
462 if isinstance(extras, (list, tuple)):
463 extras = " ".join(extras)
464 else:
465 extras = extras.strip()
467 print(self.command)
469 # User-provided command (and properly initialized) should have
470 # highest priority
471 if (self.command is not None) and (
472 self.command != SPARC._legacy_default_command
473 ):
474 return f"{self.command} {extras}"
476 parser = self.cfg.parser["sparc"] if "sparc" in self.cfg.parser else {}
477 # Get sparc command from either env variable or ini
478 command_env = self.cfg.get("ASE_SPARC_COMMAND", None) or parser.get(
479 "command", None
480 )
482 # Get sparc binary and mpi-prefix (alternative)
483 sparc_exe = parser.get("sparc_exe", None)
484 mpi_prefix = parser.get("mpi_prefix", None)
485 if (sparc_exe is None) != (mpi_prefix is None):
486 raise ValueError(
487 "Both 'sparc_exe' and 'mpi_prefix' must be specified together, "
488 "or neither should be set in the configuration."
489 )
490 if command_env and sparc_exe:
491 raise ValueError(
492 "Cannot set both sparc_command and sparc_exe in the config ini file!"
493 )
495 if sparc_exe:
496 command_env = f"{mpi_prefix} {sparc_exe}"
498 # Fallback
499 if command_env is None:
500 sparc_exe, mpi_exe, num_cores = _find_default_sparc()
501 if sparc_exe is None:
502 raise EnvironmentError(
503 "Cannot find your sparc setup via $ASE_SPARC_COMMAND, SPARC.command, or "
504 "infer from your $PATH. Please refer to the dmanual!"
505 )
506 if mpi_exe is not None:
507 command_env = f"{mpi_exe} -n {num_cores} {sparc_exe}"
508 else:
509 command_env = str(sparc_exe)
510 warn(
511 f"Your sparc command is inferred to be {command_env}, "
512 "If this is not correct, "
513 "please manually set $ASE_SPARC_COMMAND or SPARC.command!"
514 )
515 self.command = command_env
516 return f"{self.command} {extras}"
518 def check_input_atoms(self, atoms):
519 """Check if input atoms are valid for SPARC inputs.
520 Raises:
521 ValueError: if the atoms structure is not suitable for SPARC input file
522 """
523 # Check if the user accidentally provides atoms unit cell without vacuum
524 if atoms and np.any(atoms.cell.cellpar()[:3] == 0):
525 msg = "Cannot setup SPARC calculation because at least one of the lattice dimension is zero!"
526 if any([not bc_ for bc_ in atoms.pbc]):
527 msg += " Please add a vacuum in the non-periodic direction of your input structure."
528 raise ValueError(msg)
529 # SPARC only supports orthogonal lattice when Dirichlet BC is used
530 if any([not bc_ for bc_ in atoms.pbc]):
531 if not np.isclose(atoms.cell.angles(), [90.0, 90.0, 90.0], 1.0e-4).all():
532 raise ValueError(
533 (
534 "SPARC only supports orthogonal lattice when Dirichlet BC is used! "
535 "Please modify your atoms structures"
536 )
537 )
538 for i, bc_ in enumerate(atoms.pbc):
539 if bc_:
540 continue
541 direction = "xyz"[i]
542 min_pos, max_pos = atoms.positions[:, i].min(), atoms.positions[:, i].max()
543 cell_len = atoms.cell.lengths()[i]
544 if (min_pos < 0) or (max_pos > cell_len):
545 raise ValueError(
546 (
547 f"You have Dirichlet BC enabled for {direction}-direction, "
548 "but atoms positions are out of domain. "
549 "SPARC calculator cannot continue. "
550 "Please consider using atoms.center() to reposition your atoms."
551 )
552 )
553 # Additionally, we should not allow use to calculate pbc=False with CALC_STRESS=1
554 if all([not bc_ for bc_ in atoms.pbc]): # All Dirichlet
555 calc_stress = self.parameters.get("calc_stress", False)
556 if calc_stress:
557 raise ValueError(
558 "Cannot set CALC_STRESS=1 for non-periodic system in SPARC!"
559 )
560 return
    def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes):
        """Perform a calculation step

        Dispatches to the socket paths ("local"/"client" or "server") when
        socket mode is active, otherwise runs the normal file-IO cycle:
        write_input -> execute -> read_results.
        """

        self.check_input_atoms(atoms)
        Calculator.calculate(self, atoms, properties, system_changes)

        # Extra check for inpt parameters since check_state won't accept properties
        # inpt should only change when write_inpt is actually called
        param_changed = not self._compare_calc_parameters(atoms, properties)
        if param_changed:
            system_changes.append("parameters")

        if self.socket_mode in ("local", "client"):
            self._calculate_with_socket(
                atoms=atoms, properties=properties, system_changes=system_changes
            )
            return

        if self.socket_mode == "server":
            self._calculate_as_server(
                atoms=atoms, properties=properties, system_changes=system_changes
            )
            return
        # Plain file-IO path
        self.write_input(self.atoms, properties, system_changes)
        self.execute()
        self.read_results()
        # Extra step, copy the atoms back to original atoms, if it's an
        # geopt or aimd calculation
        # This will not occur for socket calculator because it's using the static files
        if ("geopt" in self.raw_results) or ("aimd" in self.raw_results):
            # Update the parent atoms
            atoms.set_positions(self.atoms.positions, apply_constraint=False)
            atoms.cell = self.atoms.cell
            atoms.constraints = self.atoms.constraints
            atoms.pbc = self.atoms.pbc
            # copy init magmom just to avoid check_state issue
            if "initial_magmoms" in self.atoms.arrays:
                atoms.set_initial_magnetic_moments(
                    self.atoms.get_initial_magnetic_moments()
                )
603 def _calculate_as_server(
604 self, atoms=None, properties=["energy"], system_changes=all_changes
605 ):
606 """Use the server component to send instructions to socket"""
607 ret, raw_results = self.in_socket.calculate_new_protocol(
608 atoms=atoms, params=self.parameters
609 )
610 self.raw_results = raw_results
611 if "stress" not in self.results:
612 virial_from_socket = ret.get("virial", np.zeros(6))
613 stress_from_socket = (
614 -full_3x3_to_voigt_6_stress(virial_from_socket) / atoms.get_volume()
615 )
616 self.results["stress"] = stress_from_socket
617 # Energy and forces returned in this case do not need
618 # resorting, since they are already in the same format
619 self.results["energy"] = ret["energy"]
620 self.results["forces"] = ret["forces"]
621 return
    def _calculate_with_socket(
        self, atoms=None, properties=["energy"], system_changes=all_changes
    ):
        """Perform one socket single point calculation

        Starts (or restarts) the SPARC subprocess bound to the inbound unix
        socket when needed, relays the atoms through the socket protocol,
        then cross-checks the socket results against the parsed output files.
        """
        # TODO: merge this part
        if self.process is None:
            if self.detect_socket_compatibility() is not True:
                raise RuntimeError(
                    "Your sparc binary is not compiled with socket support!"
                )

        # These changes invalidate the running SPARC instance and require
        # a restart of the socket process
        if any(
            [
                p in system_changes
                for p in ("numbers", "pbc", "parameters", "system_state")
            ]
        ):
            if self.process is not None:
                if not self.socket_params["allow_restart"]:
                    raise RuntimeError(
                        (
                            f"System has changed {system_changes} and the "
                            "calculator needs to be restarted!\n"
                            "Please set socket_params['allow_restart'] = True "
                            "if you want to continue"
                        )
                    )
                else:
                    print(
                        f"{system_changes} have changed since last calculation. Restart the socket process."
                    )
                    # Keep the outbound socket alive; only the SPARC process
                    # and inbound socket are torn down
                    self.close(keep_out_socket=True)

        if self.process is None:
            self.ensure_socket()
            self.write_input(atoms)
            cmds = self._make_command(
                extras=f"-socket {self.in_socket_filename}:unix -name {self.label}"
            )
            # Use the IOContext class's lazy context manager
            # TODO what if self.log is None
            fd_log = self.openfile(file=self.log, comm=world)
            self.process = subprocess.Popen(
                cmds,
                shell=True,
                stdout=fd_log,
                stderr=fd_log,
                cwd=self.directory,
                universal_newlines=True,
                bufsize=0,
            )
        # in_socket is a server
        ret = self.in_socket.calculate_origin_protocol(atoms[self.sort])
        # The results are parsed from file outputs (.static + .out)
        # Except for stress, they should be exactly the same as socket returned results
        self.read_results()  #
        assert np.isclose(
            ret["energy"], self.results["energy"]
        ), "Energy values from socket communication and output file are different! Please contact the developers."
        try:
            assert np.isclose(
                ret["forces"][self.resort], self.results["forces"]
            ).all(), "Force values from socket communication and output file are different! Please contact the developers."
        except KeyError:
            print(
                "Force values cannot be accessed via the results dictionary. They may not be available in the output file. Ensure PRINT_FORCES: 1\nResults:\n",
                self.results,
            )
        # For stress information, we make sure that the stress is always present
        if "stress" not in self.results:
            virial_from_socket = ret.get("virial", np.zeros(6))
            stress_from_socket = (
                -full_3x3_to_voigt_6_stress(virial_from_socket) / atoms.get_volume()
            )
            self.results["stress"] = stress_from_socket
        # Record the system snapshot so check_state can detect later changes
        self.system_state = self._dump_system_state()
        return
701 def get_stress(self, atoms=None):
702 """Warn user the dimensionality change when using stress"""
703 if "stress_equiv" in self.results:
704 raise NotImplementedError(
705 "You're requesting stress in a low-dimensional system. Please use `calc.results['stress_equiv']` instead!"
706 )
707 return super().get_stress(atoms)
709 def _check_input_exclusion(self, input_parameters, atoms=None):
710 """Check if mutually exclusive parameters are provided
712 The exclusion rules are taken from the SPARC manual and currently hard-coded.
713 We may need to have a clever way to do the automatic rule conversion in API
714 """
715 # Rule 1: ECUT, MESH_SPACING, FD_GRID
716 count = 0
717 for key in ["ECUT", "MESH_SPACING", "FD_GRID"]:
718 if key in input_parameters:
719 count += 1
720 if count > 1:
721 raise ValueError(
722 "ECUT, MESH_SPACING, FD_GRID cannot be specified simultaneously!"
723 )
725 # Rule 2: LATVEC_SCALE, CELL
726 if ("LATVEC_SCALE" in input_parameters) and ("CELL" in input_parameters):
727 raise ValueError(
728 "LATVEC_SCALE and CELL cannot be specified simultaneously!"
729 )
731 # When the cell is provided via ase object, we will forbid user to provide
732 # LATVEC, LATVEC_SCALE or CELL
733 if atoms is not None:
734 if any([p in input_parameters for p in ["LATVEC", "LATVEC_SCALE", "CELL"]]):
735 raise ValueError(
736 (
737 "When passing an ase atoms object, LATVEC, LATVEC_SCALE or CELL cannot be set simultaneously! \n"
738 "Please set atoms.cell instead"
739 )
740 )
742 def _check_minimal_input(self, input_parameters):
743 """Check if the minimal input set is satisfied"""
744 for param in ["EXCHANGE_CORRELATION", "KPOINT_GRID"]:
745 if param not in input_parameters:
746 raise ValueError(f"Parameter {param} is not provided.")
747 # At least one from ECUT, MESH_SPACING and FD_GRID must be provided
748 if not any(
749 [param in input_parameters for param in ("ECUT", "MESH_SPACING", "FD_GRID")]
750 ):
751 raise ValueError(
752 "You should provide at least one of ECUT, MESH_SPACING or FD_GRID."
753 )
755 def _generate_inpt_state(self, atoms, properties=[]):
756 """Return a key:value pair to be written to inpt file
757 This is an immutable dict as the ground truth
758 """
759 converted_params = self._convert_special_params(atoms=atoms)
760 input_parameters = converted_params.copy()
761 input_parameters.update(self.valid_params)
763 # Make sure desired properties are always ensured, but we don't modify the user inputs
764 if "forces" in properties:
765 input_parameters["PRINT_FORCES"] = True
767 if "stress" in properties:
768 input_parameters["CALC_STRESS"] = True
770 self._check_input_exclusion(input_parameters, atoms=atoms)
771 self._check_minimal_input(input_parameters)
772 return input_parameters
    def write_input(self, atoms, properties=[], system_changes=[]):
        """Create input files via SparcBundle
        Will use the self.keep_old_files option to keep old output files
        like .out_01, .out_02 etc
        """
        FileIOCalculator.write_input(self, atoms, properties, system_changes)
        input_parameters = self._generate_inpt_state(atoms, properties=properties)

        # TODO: make sure the sorting reset is justified (i.e. what about restarting?)
        self.sparc_bundle.sorting = None
        self.sparc_bundle._write_ion_and_inpt(
            atoms=atoms,
            label=self.label,
            # Pass the rest parameters from calculator!
            direct=False,
            sort=True,
            ignore_constraints=False,
            wrap=False,
            # Below are the parameters from v1
            # scaled -> direct, ignore_constraints --> not add_constraints
            scaled=False,
            add_constraints=True,
            copy_psp=True,
            comment="",
            input_parameters=input_parameters,
        )

        # NOTE(review): "geopt" lacks the leading dot, so
        # f.suffix.startswith("geopt") can never match (suffixes start
        # with "."); .geopt files are therefore never removed -- confirm
        # whether ".geopt" was intended
        output_patterns = [".out", ".static", ".eigen", ".aimd", "geopt"]
        # We just remove the output files, in case the user has psp files manually copied
        if self.keep_old_files is False:
            for f in self.directory.glob("*"):
                if (f.is_file()) and any(
                    [f.suffix.startswith(p) for p in output_patterns]
                ):
                    os.remove(f)
        # Record the exact inputs and system snapshot used for this write
        self.inpt_state = input_parameters
        self.system_state = self._dump_system_state()
        return
814 def execute(self):
815 """Make a normal SPARC calculation without socket. Note we probably need to use a better handling of background process!"""
816 extras = f"-name {self.label}"
817 command = self._make_command(extras=extras)
818 self.print_sysinfo(command)
820 try:
821 if self.log is not None:
822 with open(self.log, "a") as fd:
823 self.process = subprocess.run(
824 command, shell=True, cwd=self.directory, stdout=fd
825 )
826 else:
827 self.process = subprocess.run(
828 command, shell=True, cwd=self.directory, stdout=None
829 )
830 except OSError as err:
831 msg = 'Failed to execute "{}"'.format(command)
832 raise EnvironmentError(msg) from err
834 # We probably don't want to wait the
835 errorcode = self.process.returncode
837 if errorcode > 0:
838 msg = f"SPARC failed with command {command}" f"with error code {errorcode}"
839 raise RuntimeError(msg)
841 return
    def close(self, keep_out_socket=False):
        """Close the socket communication, the SPARC process etc

        Args:
            keep_out_socket (bool): if True, the outbound client socket is
                left open (used when restarting only the SPARC process).
        """
        if not self.use_socket:
            return
        if self.in_socket is not None:
            self.in_socket.close()

        if (self.out_socket is not None) and (not keep_out_socket):
            self.out_socket.close()

        # In most cases if in_socket is closed, the SPARC process should also exit
        if self.process:
            # Give the process up to 5 s to exit on its own before killing it
            with time_limit(5):
                ret = self.process.poll()
                if ret is None:
                    print("Force terminate the sparc process!")
                    self._send_mpi_signal(signal.SIGKILL)
                else:
                    print(f"SPARC process exists with code {ret}")

        # TODO: check if in_socket should be merged
        self.in_socket = None
        if not keep_out_socket:
            self.out_socket = None
        self._reset_process()
    def _send_mpi_signal(self, sig):
        """Send signal to the mpi process within self.process
        If the process cannot be found, return without affecting the state

        This is a method taken from the vasp_interactive project

        Args:
            sig: signal number (e.g. signal.SIGKILL) to deliver to the
                mpi launcher or slurm step wrapping the SPARC binary.
        """
        try:
            pid = self.process.pid
            psutil_proc = psutil.Process(pid)
        except Exception as e:
            warn("SPARC process no longer exists. Will reset the calculator.")
            self._reset_process()
            return

        # Reuse the cached mpi process match when the pid has not changed
        if (self.pid == pid) and getattr(self, "mpi_match", None) is not None:
            match = self.mpi_match
        else:
            # self.pid = pid
            match = _find_mpi_process(pid)
            self.mpi_match = match
        if (match["type"] is None) or (match["process"] is None):
            warn(
                "Cannot find the mpi process or you're using different ompi wrapper. Will not send stop signal to mpi."
            )
            return
        elif match["type"] == "mpi":
            # Plain mpirun/mpiexec launcher: signal it directly
            mpi_process = match["process"]
            mpi_process.send_signal(sig)
        elif match["type"] == "slurm":
            # Slurm job step: use scancel-style signalling
            slurm_step = match["process"]
            _slurm_signal(slurm_step, sig)
        else:
            raise ValueError("Unsupported process type!")
        return
904 def _reset_process(self):
905 """Reset the record for process in the calculator.
906 Useful if the process is missing or reset the calculator.
907 """
908 # Reset process tracker
909 self.process = None
910 # self.pid = None
911 if hasattr(self, "mpi_match"):
912 self.mpi_match = None
913 self.mpi_state = None
915 @property
916 def pid(self):
917 """The pid for the stored process"""
918 if self.process is None:
919 return None
920 else:
921 return self.process.pid
    @property
    def raw_results(self):
        # The raw parsed SPARC outputs live on the bundle; fall back to {}
        # before any results have been read
        return getattr(self.sparc_bundle, "raw_results", {})

    @raw_results.setter
    def raw_results(self, value):
        # Store directly on the bundle so bundle and calculator stay in sync
        self.sparc_bundle.raw_results = value
        return
932 def read_results(self):
933 """Parse from the SparcBundle"""
934 # self.sparc_bundle.read_raw_results()
935 last = self.sparc_bundle.convert_to_ase(indices=-1, include_all_files=False)
936 self.atoms = last.copy()
937 self.results.update(last.calc.results)
    def _restart(self, restart=None):
        """Reload the input parameters and atoms from previous calculation.

        If self.parameters is already set, the parameters will not be loaded
        If self.atoms is already set, the atoms will be not be read
        """
        if restart is None:
            return
        # Only reload what the user has not supplied explicitly
        reload_atoms = self.atoms is None
        reload_parameters = len(self.parameters) == 0

        self.read_results()
        # NOTE(review): read_results overwrites self.atoms; when the user
        # supplied atoms (reload_atoms is False) the freshly-read atoms are
        # discarded here -- confirm this precedence is intended
        if not reload_atoms:
            self.atoms = None
        if reload_parameters:
            self.parameters = self.raw_results["inpt"]["params"]

        if (not reload_parameters) or (not reload_atoms):
            warn(
                "Extra parameters or atoms are provided when restarting the SPARC calculator, "
                "previous results will be cleared."
            )
            self.results.clear()
            self.sparc_bundle.raw_results.clear()
        return
    def get_fermi_level(self):
        """Extra get-method for Fermi level, if calculated

        Returns None when no Fermi level is present in self.results.
        """
        return self.results.get("fermi", None)
969 def detect_sparc_version(self):
970 """Run a short sparc test to determine which sparc is used"""
971 try:
972 cmd = self._make_command()
973 except EnvironmentError:
974 return None
975 print("Running a short calculation to determine SPARC version....")
976 # check_version must be set to False to avoid recursive calling
977 new_calc = SPARC(
978 command=self.command, psp_dir=self.sparc_bundle.psp_dir, check_version=False
979 )
980 with tempfile.TemporaryDirectory() as tmpdir:
981 new_calc.set(xc="pbe", h=0.3, kpts=(1, 1, 1), maxit_scf=1, directory=tmpdir)
982 atoms = Atoms(["H"], positions=[[0.0, 0.0, 0.0]], cell=[2, 2, 2], pbc=False)
983 try:
984 new_calc.calculate(atoms)
985 version = new_calc.raw_results["out"]["sparc_version"]
986 except Exception as e:
987 print("Error handling simple calculation: ", e)
988 version = None
989 # Warning information about version mismatch between binary and JSON API
990 # only when both are not None
991 if (version is None) and (self.validator.sparc_version is not None):
992 if version != self.validator.sparc_version:
993 warn(
994 (
995 f"SPARC binary version {version} does not match JSON API version {self.validator.sparc_version}. "
996 "You can set $SPARC_DOC_PATH to the SPARC documentation location."
997 )
998 )
999 return version
1001 def run_client(self, atoms=None, use_stress=False):
1002 """Main method to start the client code"""
1003 if not self.socket_mode == "client":
1004 raise RuntimeError(
1005 "Cannot use SPARC.run_client if the calculator is not configured in client mode!"
1006 )
1008 self.out_socket.run(atoms, use_stress)
1010 def detect_socket_compatibility(self):
1011 """Test if the sparc binary supports socket mode"""
1012 try:
1013 cmd = self._make_command()
1014 except EnvironmentError:
1015 return False
1016 with tempfile.TemporaryDirectory() as tmpdir:
1017 proc = subprocess.run(cmd, shell=True, cwd=tmpdir, capture_output=True)
1018 output = proc.stdout.decode("ascii")
1019 if "USAGE:" not in output:
1020 raise EnvironmentError(
1021 "Cannot find the sparc executable! Please make sure you have the correct setup"
1022 )
1023 compatibility = "-socket" in output
1024 return compatibility
1026 def set(self, **kwargs):
1027 """Overwrite the initial parameters"""
1028 # Do not use JSON Schema for these arguments
1029 if "label" in kwargs:
1030 self.label = kwargs.pop("label")
1032 if "directory" in kwargs:
1033 # str() call to deal with pathlib objects
1034 self.directory = str(kwargs.pop("directory"))
1036 if "log" in kwargs:
1037 self.log = kwargs.pop("log")
1039 if "check_version" in kwargs:
1040 self.check_version = bool(kwargs.pop("check_version"))
1042 if "keep_old_files" in kwargs:
1043 self.keep_old_files = kwargs.pop("keep_old_files")
1045 if "atoms" in kwargs:
1046 self.atoms = kwargs.pop("atoms") # Resets results
1048 if "command" in kwargs:
1049 self.command = kwargs.pop("command")
1051 # For now we don't let the user to hot-swap socket
1052 if ("use_socket" in kwargs) or ("socket_params" in kwargs):
1053 raise NotImplementedError("Hot swapping socket parameter is not supported!")
1055 self._sanitize_kwargs(**kwargs)
1056 set_params = {}
1057 set_params.update(self.special_params)
1058 set_params.update(self.valid_params)
1059 changed = super().set(**set_params)
1060 if changed != {}:
1061 self.reset()
1063 return changed
1065 def _sanitize_kwargs(self, **kwargs):
1066 """Convert known parameters from JSON API"""
1067 validator = self.validator
1068 if self.special_params == {}:
1069 init = True
1070 self.special_params = self.default_params.copy()
1071 else:
1072 init = False
1073 # User input gpts will overwrite default h
1074 # but user cannot put h and gpts both
1075 if "gpts" in kwargs:
1076 h = self.special_params.pop("h", None)
1077 if (h is not None) and (not init):
1078 warn("Parameter gpts will overwrite previously set parameter h.")
1079 elif "h" in kwargs:
1080 gpts = self.special_params.pop("gpts", None)
1081 if (gpts is not None) and (not init):
1082 warn("Parameter h will overwrite previously set parameter gpts.")
1084 upper_valid_params = set() # Valid SPARC parameters in upper case
1085 # SPARC API is case insensitive
1086 for key, value in kwargs.items():
1087 if key in self.special_inputs:
1088 # Special case: ignore h when gpts provided
1090 self.special_params[key] = value
1091 else:
1092 key = key.upper()
1093 if key in upper_valid_params:
1094 warn(f"Parameter {key} (case-insentive) appears multiple times!")
1095 if validator.validate_input(key, value):
1096 self.valid_params[key] = value
1097 upper_valid_params.add(key)
1098 else:
1099 raise ValueError(
1100 f"Value {value} for parameter {key} (case-insensitive) is invalid!"
1101 )
1102 return
1104 def _convert_special_params(self, atoms=None):
1105 """Convert ASE-compatible parameters to SPARC compatible ones
1106 parameters like `h`, `nbands` may need atoms information
1108 Special rules:
1109 h <--> gpts <--> FD_GRID, only when None of FD_GRID / ECUT or MESH_SPACING is provided
1110 """
1111 converted_sparc_params = {}
1112 validator = self.validator
1113 params = self.special_params.copy()
1115 # xc --> EXCHANGE_CORRELATION
1116 if "xc" in params:
1117 xc = params.pop("xc")
1118 if xc.lower() == "pbe":
1119 converted_sparc_params["EXCHANGE_CORRELATION"] = "GGA_PBE"
1120 elif xc.lower() == "lda":
1121 converted_sparc_params["EXCHANGE_CORRELATION"] = "LDA_PZ"
1122 elif xc.lower() == "rpbe":
1123 converted_sparc_params["EXCHANGE_CORRELATION"] = "GGA_RPBE"
1124 elif xc.lower() == "pbesol":
1125 converted_sparc_params["EXCHANGE_CORRELATION"] = "GGA_PBEsol"
1126 elif xc.lower() == "pbe0":
1127 converted_sparc_params["EXCHANGE_CORRELATION"] = "PBE0"
1128 elif xc.lower() == "hf":
1129 converted_sparc_params["EXCHANGE_CORRELATION"] = "HF"
1130 # backward compatibility for HSE03. Note HSE06 is not supported yet
1131 elif (xc.lower() == "hse") or (xc.lower() == "hse03"):
1132 converted_sparc_params["EXCHANGE_CORRELATION"] = "HSE"
1133 # backward compatibility for VASP-style XCs
1134 elif (
1135 (xc.lower() == "vdwdf1")
1136 or (xc.lower() == "vdw-df")
1137 or (xc.lower() == "vdw-df1")
1138 ):
1139 converted_sparc_params["EXCHANGE_CORRELATION"] = "vdWDF1"
1140 elif (xc.lower() == "vdwdf2") or (xc.lower() == "vdw-df2"):
1141 converted_sparc_params["EXCHANGE_CORRELATION"] = "vdWDF2"
1142 elif xc.lower() == "scan":
1143 converted_sparc_params["EXCHANGE_CORRELATION"] = "SCAN"
1144 else:
1145 raise ValueError(f"xc keyword value {xc} is invalid!")
1147 # h --> gpts
1148 if "h" in params:
1149 if "gpts" in params:
1150 raise KeyError(
1151 "h and gpts cannot be provided together in SPARC calculator!"
1152 )
1153 h = params.pop("h")
1154 # if atoms is None:
1155 # raise ValueError(
1156 # "Must have an active atoms object to convert h --> gpts!"
1157 # )
1158 if any(
1159 [p in self.valid_params for p in ("FD_GRID", "ECUT", "MESH_SPACING")]
1160 ):
1161 warn(
1162 "You have specified one of FD_GRID, ECUT or MESH_SPACING, "
1163 "conversion of h to mesh grid is ignored."
1164 )
1165 else:
1166 # gpts = h2gpts(h, atoms.cell)
1167 # params["gpts"] = gpts
1168 # Use mesh_spacing instead of fd_grid to avoid parameters
1169 converted_sparc_params["MESH_SPACING"] = h / Bohr
1171 # gpts --> FD_GRID
1172 if "gpts" in params:
1173 gpts = params.pop("gpts")
1174 if validator.validate_input("FD_GRID", gpts):
1175 converted_sparc_params["FD_GRID"] = gpts
1176 else:
1177 raise ValueError(f"Input parameter gpts has invalid value {gpts}")
1179 # kpts
1180 if "kpts" in params:
1181 kpts = params.pop("kpts")
1182 if validator.validate_input("KPOINT_GRID", kpts):
1183 converted_sparc_params["KPOINT_GRID"] = kpts
1184 else:
1185 raise ValueError(f"Input parameter kpts has invalid value {kpts}")
1187 # nbands
1188 if "nbands" in params:
1189 # TODO: Check if the nbands are correct in current system
1190 # TODO: default $N_e/2 \\times 1.2 + 5$
1191 nbands = params.pop("nbands")
1192 if validator.validate_input("NSTATES", nbands):
1193 converted_sparc_params["NSTATES"] = nbands
1194 else:
1195 raise ValueError(f"Input parameter nbands has invalid value {nbands}")
1197 # convergence is a dict
1198 if "convergence" in params:
1199 convergence = params.pop("convergence")
1200 tol_e = convergence.get("energy", None)
1201 if tol_e:
1202 # TOL SCF: Ha / atom <--> energy tol: eV / atom
1203 converted_sparc_params["TOL_SCF"] = tol_e / Hartree
1205 tol_f = convergence.get("relax", None)
1206 if tol_f:
1207 # TOL SCF: Ha / Bohr <--> energy tol: Ha / Bohr
1208 converted_sparc_params["TOL_RELAX"] = tol_f / Hartree * Bohr
1210 tol_dens = convergence.get("density", None)
1211 if tol_dens:
1212 # TOL SCF: electrons / atom
1213 converted_sparc_params["TOL_PSEUDOCHARGE"] = tol_dens
1215 tol_stress = convergence.get("stress", None)
1216 if tol_stress:
1217 # TOL SCF: electrons / atom
1218 converted_sparc_params["TOL_RELAX_CELL"] = tol_stress / GPa
1220 # setup currently only handles HUBBARD U
1221 if "setups" in params:
1222 hubbard_u_pairs = []
1223 # TODO: Check element validity
1224 for elem, hubbard_string in params["setups"].items():
1225 u_val = parse_hubbard_string(hubbard_string)
1226 hubbard_u_pairs.append({"U_ATOM_TYPE": elem, "U_VAL": u_val})
1227 converted_sparc_params["HUBBARD"] = hubbard_u_pairs
1229 return converted_sparc_params
1231 def print_sysinfo(self, command=None):
1232 """Record current runtime information"""
1233 now = datetime.datetime.now().isoformat()
1234 if command is None:
1235 command = self.command
1236 msg = (
1237 "\n" + "*" * 80 + "\n"
1238 f"SPARC program started by SPARC-X-API at {now}\n"
1239 f"command: {command}\n"
1240 )
1241 if self.log is None:
1242 print(msg)
1243 else:
1244 with open(self.log, "a") as fd:
1245 print(msg, file=fd)
1247 ###############################################
1248 # Below are deprecated functions from v1
1249 ###############################################
    @deprecated("Please use SPARC.set instead for setting grid")
    def interpret_grid_input(self, atoms, **kwargs):
        # v1 API shim: grid settings are now handled by SPARC.set
        return None
    @deprecated("Please use SPARC.set instead for setting kpoints")
    def interpret_kpoint_input(self, atoms, **kwargs):
        # v1 API shim: kpoint settings are now handled by SPARC.set
        return None
    @deprecated("Please use SPARC.set instead for setting downsampling parameter")
    def interpret_downsampling_input(self, atoms, **kwargs):
        # v1 API shim: downsampling settings are now handled by SPARC.set
        return None
    @deprecated("Please use SPARC.set instead for setting kpoint shift")
    def interpret_kpoint_shift(self, atoms, **kwargs):
        # v1 API shim: kpoint shift settings are now handled by SPARC.set
        return None
    @deprecated("Please use SPARC.psp_dir instead")
    def get_pseudopotential_directory(self, pseudo_dir=None, **kwargs):
        # v1 API shim: forwards to the bundle's pseudopotential directory
        return self.sparc_bundle.psp_dir
1270 def get_nstates(self):
1271 raise NotImplementedError("Parsing nstates is not yet implemented.")
    @deprecated("Please set the variables separatedly")
    def setup_parallel_env(self):
        # v1 API shim: parallel environment variables are now user-managed
        return None
    @deprecated("Please use SPARC._make_command instead")
    def generate_command(self):
        # v1 API shim: delegate to _make_command with the label-based name flag
        return self._make_command(f"-name {self.label}")
1281 def estimate_memory(self, atoms=None, units="GB", **kwargs):
1282 """
1283 a function to estimate the amount of memory required to run
1284 the selected calculation. This function takes in **kwargs,
1285 but if none are passed in, it will fall back on the parameters
1286 input when the class was instantiated
1287 """
1288 conversion_dict = {
1289 "MB": 1e-6,
1290 "GB": 1e-9,
1291 "B": 1,
1292 "byte": 1,
1293 "KB": 1e-3,
1294 }
1295 if kwargs == {}:
1296 kwargs = self.parameters
1297 if atoms is None:
1298 atoms = self.atoms
1300 nstates = kwargs.get("NSTATES")
1301 if nstates is None:
1302 nstates = self.get_nstates(atoms=atoms, **kwargs)
1304 # some annoying code to figure out if it's a spin system
1305 spin_polarized = kwargs.get("nstates")
1306 if spin_polarized is not None:
1307 spin_polarized = int(spin_polarized)
1308 else:
1309 spin_polarized = 1
1310 if spin_polarized == 2:
1311 spin_factor = 2
1312 else:
1313 spin_factor = 1
1315 if "MESH_SPACING" in kwargs:
1316 # MESH_SPACING: Bohr; h: angstrom
1317 kwargs["h"] = kwargs.pop("MESH_SPACING") / Bohr
1318 npoints = np.product(self.interpret_grid_input(atoms, **kwargs))
1320 kpt_grid = self.interpret_kpoint_input(atoms, **kwargs)
1321 kpt_factor = np.ceil(np.product(kpt_grid) / 2)
1323 # this is a pretty generous over-estimate
1324 estimate = 5 * npoints * nstates * kpt_factor * spin_factor * 8 # bytes
1325 converted_estimate = estimate * conversion_dict[units]
1326 return converted_estimate
1328 def get_scf_steps(self, include_uncompleted_last_step=False):
1329 raise NotImplemented
1331 @deprecated("Use SPARC.get_number_of_ionic_steps instead")
1332 def get_geometric_steps(self, include_uncompleted_last_step=False):
1333 raise NotImplemented
1335 def get_runtime(self):
1336 raise NotImplemented
1338 def get_fermi_level(self):
1339 raise NotImplemented
    @deprecated
    def concatinate_output(self):
        # v1 API stub: always raises; functionality lives in sparc.SparcBundle
        raise DeprecationWarning("Functionality moved in sparc.SparcBundle.")
    @deprecated
    def read_line(self, **kwargs):
        # v1 API stub: always raises; per-file parsers now live in
        # sparc.sparc_parsers
        raise DeprecationWarning(
            "Parsers for individual files have been moved to sparc.sparc_parsers module"
        )
    @deprecated
    def parse_output(self, **kwargs):
        # v1 API stub: always raises; use SPARC.read_results instead
        raise DeprecationWarning("Use SPARC.read_results for parsing results!")
    @deprecated
    def parse_relax(self, *args, **kwargs):
        # v1 API stub: always raises; use SPARC.read_results instead
        raise DeprecationWarning("Use SPARC.read_results for parsing results!")
    @deprecated
    def parse_MD(self, *args, **kwargs):
        # v1 API stub: always raises; use SPARC.read_results instead
        raise DeprecationWarning("Use SPARC.read_results for parsing results!")
    @deprecated
    def parse_input_args(self, input_block):
        # v1 API stub: always raises; use SPARC.set instead
        raise DeprecationWarning("Use SPARC.set for argument handling!")
    @deprecated
    def recover_index_order_from_ion_file(self, label):
        # v1 API stub: always raises; sorting is handled by SPARC.sort/resort
        raise DeprecationWarning(
            "Use SPARC.sort and SPARC.resort for atomic index sorting!"
        )
    @deprecated
    def atoms_dict(self, *args, **kwargs):
        # v1 API stub: dict <-> Atoms conversion has been removed
        raise DeprecationWarning("")
    @deprecated
    def dict_atoms(self, *args, **kwargs):
        # v1 API stub: dict <-> Atoms conversion has been removed
        raise DeprecationWarning("")