Coverage for sparc/calculator.py: 62%
681 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-12 01:13 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-12 01:13 +0000
1import datetime
2import os
3import signal
4import subprocess
5import tempfile
6from pathlib import Path
7from warnings import warn, warn_explicit
9import numpy as np
10import psutil
11from ase.atoms import Atoms
12from ase.calculators.calculator import Calculator, FileIOCalculator, all_changes
14# 2024-11-28: @alchem0x2a add support for ase.config
15# In the first we only use cfg as parser for configurations
16from ase.config import cfg as _cfg
17from ase.parallel import world
18from ase.stress import full_3x3_to_voigt_6_stress
19from ase.units import Bohr, GPa, Hartree, eV
20from ase.utils import IOContext
22from .api import SparcAPI
23from .io import SparcBundle
24from .socketio import (
25 SPARCProtocol,
26 SPARCSocketClient,
27 SPARCSocketServer,
28 generate_random_socket_name,
29)
30from .utils import (
31 _find_default_sparc,
32 _find_mpi_process,
33 _get_slurm_jobid,
34 _locate_slurm_step,
35 _slurm_signal,
36 compare_dict,
37 deprecated,
38 h2gpts,
39 locate_api,
40 monitor_process,
41 time_limit,
42)
# Below are a list of ASE-compatible calculator input parameters that are
# in Angstrom/eV units (i.e. they are converted before being written to the
# native SPARC .inpt file). Ideas are taken from GPAW calculator.
# Exposed on the calculator as SPARC.special_inputs.
sparc_python_inputs = [
    "xc",
    "h",
    "kpts",
    "convergence",
    "gpts",
    "nbands",
]
# The socket mode in SPARC calculator uses a relay-based mechanism
# Several scenarios:
# 1) use_socket = False --> Turn off all socket communications. SPARC runs from cold-start
# 2) use_socket = True, port < 0 --> Only connect the sparc binary using ephemeral unix socket. Interface appears as if it is a normal calculator
# 3) use_socket = True, port > 0 --> Use an out-going socket to relay information
# 4) use_socket = True, server_only = True --> Act as a SocketServer
# We do not support outgoing unix socket because the limited user cases
default_socket_params = {
    "use_socket": False,  # Main switch to use socket or not
    "host": "localhost",  # Name of the socket host (only outgoing)
    "port": -1,  # Port number of the outgoing socket
    "allow_restart": True,  # If True, allow the socket server to restart
    "server_only": False,  # Start the calculator as a server
}
class SPARC(FileIOCalculator, IOContext):
    """Calculator interface to the SPARC codes via the FileIOCalculator.

    Supports both classic file-IO calculations and a socket (i-PI protocol)
    mode controlled by ``socket_params`` (see ``default_socket_params``).
    """

    implemented_properties = ["energy", "forces", "fermi", "stress"]
    name = "sparc"
    ase_objtype = "sparc_calculator"  # For JSON storage
    # ASE-style parameters that are converted to native SPARC inputs
    special_inputs = sparc_python_inputs
    default_params = {
        "xc": "pbe",
        "kpts": (1, 1, 1),
        "h": 0.25,  # Angstrom equivalent to MESH_SPACING = 0.47
    }
    # TODO: ASE 3.23 compatibility. should use profile
    # TODO: remove the legacy command check for future releases
    # Sentinel: ase>=3.23 falls back to this when no command is configured
    _legacy_default_command = "sparc not initialized"
88 def __init__(
89 self,
90 restart=None,
91 directory=".",
92 *,
93 label=None,
94 atoms=None,
95 command=None,
96 psp_dir=None,
97 log="sparc.log",
98 sparc_json_file=None,
99 sparc_doc_path=None,
100 check_version=False,
101 keep_old_files=True,
102 use_socket=False,
103 socket_params={},
104 **kwargs,
105 ):
106 """
107 Initialize the SPARC calculator similar to FileIOCalculator. The validator uses the JSON API guessed
108 from sparc_json_file or sparc_doc_path.
110 Arguments:
111 restart (str or None): Path to the directory for restarting a calculation. If None, starts a new calculation.
112 directory (str or Path): Directory for SPARC calculation files.
113 label (str, optional): Custom label for identifying calculation files.
114 atoms (Atoms, optional): ASE Atoms object representing the system to be calculated.
115 command (str, optional): Command to execute SPARC. If None, it will be determined automatically.
116 psp_dir (str or Path, optional): Directory containing pseudopotentials.
117 log (str, optional): Name of the log file.
118 sparc_json_file (str, optional): Path to a JSON file with SPARC parameters.
119 sparc_doc_path (str, optional): Path to the SPARC doc LaTeX code for parsing parameters.
120 check_version (bool): Check if SPARC and document versions match
121 keep_old_files (bool): Whether older SPARC output files should be preserved.
122 If True, SPARC program will rewrite the output files
123 with suffix like .out_01, .out_02 etc
124 use_socket (bool): Main switch for the socket mode. Alias for socket_params["use_socket"]
125 socket_params (dict): Parameters to control the socket behavior. Please check default_socket_params
126 **kwargs: Additional keyword arguments to set up the calculator.
127 """
128 # 2024-11-28 @alchem0x2a added cfg as the default validator
129 self.validator = locate_api(
130 json_file=sparc_json_file, doc_path=sparc_doc_path, cfg=self.cfg
131 )
132 self.valid_params = {}
133 self.special_params = {}
134 self.inpt_state = {} # Store the inpt file states
135 self.system_state = {} # Store the system parameters (directory, bundle etc)
136 FileIOCalculator.__init__(
137 self,
138 restart=None,
139 label=None,
140 atoms=atoms,
141 command=command,
142 directory=directory,
143 **kwargs,
144 )
146 # sparc bundle will set the label. self.label will be available after the init
147 if label is None:
148 label = "SPARC" if restart is None else None
150 # Use psp dir from user input or env
151 self.sparc_bundle = SparcBundle(
152 directory=Path(self.directory),
153 mode="w",
154 atoms=self.atoms,
155 label=label, # The order is tricky here. Use label not self.label
156 psp_dir=psp_dir,
157 validator=self.validator,
158 cfg=self.cfg,
159 )
161 # Try restarting from an old calculation and set results
162 self._restart(restart=restart)
164 # self.log = self.directory / log if log is not None else None
165 self.log = log
166 self.keep_old_files = keep_old_files
167 if check_version:
168 self.sparc_version = self.detect_sparc_version()
169 else:
170 self.sparc_version = None
172 # Partially update the socket params, so that when setting use_socket = True,
173 # User can directly use the socket client
174 self.socket_params = default_socket_params.copy()
175 # Everything in argument socket_params will overwrite
176 self.socket_params.update(use_socket=use_socket)
177 self.socket_params.update(**socket_params)
179 # TODO: check parameter compatibility with socket params
180 self.process = None
181 # self.pid = None
183 # Initialize the socket settings
184 self.in_socket = None
185 self.out_socket = None
186 self.ensure_socket()
188 def _compare_system_state(self):
189 """Check if system parameters like command etc have changed
191 Returns:
192 bool: True if all parameters are the same otherwise False
193 """
194 old_state = self.system_state.copy()
195 new_state = self._dump_system_state()
196 for key, val in old_state.items():
197 new_val = new_state.pop(key, None)
198 if isinstance(new_val, dict):
199 if not compare_dict(val, new_val):
200 return False
201 else:
202 if not val == new_val:
203 return False
204 if new_state == {}:
205 return True
206 else:
207 return False
209 def _compare_calc_parameters(self, atoms, properties):
210 """Check if SPARC calculator parameters have changed
212 Returns:
213 bool: True if no change, otherwise False
214 """
215 _old_inpt_state = self.inpt_state.copy()
216 _new_inpt_state = self._generate_inpt_state(atoms, properties)
217 result = True
218 if set(_new_inpt_state.keys()) != set(_old_inpt_state.keys()):
219 result = False
220 else:
221 for key, old_val in _old_inpt_state.items():
222 new_val = _new_inpt_state[key]
223 # TODO: clean up bool
224 if isinstance(new_val, (str, bool)):
225 if new_val != old_val:
226 result = False
227 break
228 elif isinstance(new_val, (int, float)):
229 if not np.isclose(new_val, old_val):
230 result = False
231 break
232 elif isinstance(new_val, (list, np.ndarray)):
233 if not np.isclose(new_val, old_val).all():
234 result = False
235 break
236 return result
238 def _dump_system_state(self):
239 """Returns a dict with current system parameters
241 changing these parameters will cause the calculator to reload
242 especially in the use_socket = True case
243 """
244 system_state = {
245 "label": self.label,
246 "directory": self.directory,
247 "command": self.command,
248 "log": self.log,
249 "socket_params": self.socket_params,
250 }
251 return system_state
    def ensure_socket(self):
        """Create the inbound socket server and (in client mode) the outgoing
        socket client, if socket mode is enabled and they do not exist yet.

        Also makes sure the calculation directory exists. No-op when
        use_socket is False.
        """
        # TODO: more ensure directory to other place?
        if not self.directory.is_dir():
            os.makedirs(self.directory, exist_ok=True)
        if not self.use_socket:
            return
        if self.in_socket is None:
            if self.socket_mode == "server":
                # Bind an INET socket on the configured port
                # TODO: Exception for wrong port
                self.in_socket = SPARCSocketServer(
                    port=self.socket_params["port"],
                    log=self.openfile(
                        file=self._indir(ext=".log", label="socket"),
                        comm=world,
                        mode="w",
                    ),
                    parent=self,
                )
            else:
                # "local" mode: ephemeral unix socket with a random name
                socket_name = generate_random_socket_name()
                print(f"Creating a socket server with name {socket_name}")
                self.in_socket = SPARCSocketServer(
                    unixsocket=socket_name,
                    # TODO: make the log fd persistent
                    log=self.openfile(
                        file=self._indir(ext=".log", label="socket"),
                        comm=world,
                        mode="w",
                    ),
                    parent=self,
                )
        # TODO: add the outbound socket client
        # TODO: we may need to check an actual socket server at host:port?!
        # At this stage, we will need to wait the actual client to join
        if self.out_socket is None:
            if self.socket_mode == "client":
                self.out_socket = SPARCSocketClient(
                    host=self.socket_params["host"],
                    port=self.socket_params["port"],
                    # TODO: change later
                    log=self.openfile(file="out_socket.log", comm=world),
                    # TODO: add the log and timeout part
                    parent_calc=self,
                )
298 def __enter__(self):
299 """Reset upon entering the context."""
300 IOContext.__enter__(self)
301 self.reset()
302 self.close()
303 return self
305 def __exit__(self, type, value, traceback):
306 """Exiting the context manager and reset process"""
307 IOContext.__exit__(self, type, value, traceback)
308 self.close()
309 return
311 @property
312 def use_socket(self):
313 return self.socket_params["use_socket"]
315 @property
316 def socket_mode(self):
317 """The mode of the socket calculator:
319 disabled: pure SPARC file IO interface
320 local: Serves as a local SPARC calculator with socket support
321 client: Relay SPARC calculation
322 server: Remote server
323 """
324 if self.use_socket:
325 if self.socket_params["port"] > 0:
326 if self.socket_params["server_only"]:
327 return "server"
328 else:
329 return "client"
330 else:
331 return "local"
332 else:
333 return "disabled"
335 def _indir(self, ext, label=None, occur=0, d_format="{:02d}"):
336 return self.sparc_bundle._indir(
337 ext=ext, label=label, occur=occur, d_format=d_format
338 )
340 @property
341 def log(self):
342 return self.directory / self._log
344 @log.setter
345 def log(self, log):
346 # Stripe the parent direcoty information
347 if log is not None:
348 self._log = Path(log).name
349 else:
350 self._log = "sparc.log"
351 return
353 @property
354 def in_socket_filename(self):
355 # The actual socket name for inbound socket
356 # Return name as /tmp/ipi_sparc_<hex>
357 if self.in_socket is None:
358 return ""
359 else:
360 return self.in_socket.socket_filename
362 @property
363 def directory(self):
364 if hasattr(self, "sparc_bundle"):
365 return Path(self.sparc_bundle.directory)
366 else:
367 return Path(self._directory)
369 @directory.setter
370 def directory(self, directory):
371 if hasattr(self, "sparc_bundle"):
372 self.sparc_bundle.directory = Path(directory)
373 else:
374 self._directory = Path(directory)
375 return
377 @property
378 def label(self):
379 """Rewrite the label from Calculator class, since we don't want to contain pathsep"""
380 if hasattr(self, "sparc_bundle"):
381 return self.sparc_bundle.label
382 else:
383 return getattr(self, "_label", None)
385 @label.setter
386 def label(self, label):
387 """Rewrite the label from Calculator class,
388 since we don't want to contain pathsep
389 """
390 label = str(label)
391 if hasattr(self, "sparc_bundle"):
392 self.sparc_bundle.label = self.sparc_bundle._make_label(label)
393 else:
394 self._label = label
396 @property
397 def sort(self):
398 """Like Vasp calculator
399 ASE atoms --> sort --> SPARC
400 """
401 if self.sparc_bundle.sorting is None:
402 return None
403 else:
404 return self.sparc_bundle.sorting["sort"]
406 @property
407 def resort(self):
408 """Like Vasp calculator
409 SPARC --> resort --> ASE atoms
410 """
411 if self.sparc_bundle.sorting is None:
412 return None
413 else:
414 return self.sparc_bundle.sorting["resort"]
416 def check_state(self, atoms, tol=1e-8):
417 """Updated check_state method.
418 By default self.atoms (cached from output files) contains the initial_magmoms,
419 so we add a zero magmoms to the atoms for comparison if it does not exist.
421 reading a result from the .out file has only precision up to 10 digits
424 """
425 atoms_copy = atoms.copy()
426 if "initial_magmoms" not in atoms_copy.arrays:
427 atoms_copy.set_initial_magnetic_moments(
428 [
429 0,
430 ]
431 * len(atoms_copy)
432 )
433 system_changes = FileIOCalculator.check_state(self, atoms_copy, tol=tol)
434 # A few hard-written rules. Wrapping should only affect the position
435 if "positions" in system_changes:
436 atoms_copy.wrap(eps=tol)
437 new_system_changes = FileIOCalculator.check_state(self, atoms_copy, tol=tol)
438 if "positions" not in new_system_changes:
439 system_changes.remove("positions")
441 system_state_changed = not self._compare_system_state()
442 if system_state_changed:
443 system_changes.append("system_state")
444 return system_changes
446 def _make_command(self, extras=""):
447 """Use $ASE_SPARC_COMMAND or self.command to determine the command
448 as a last resort, if `sparc` exists in the PATH, use that information
450 Extras will add additional arguments to the self.command,
451 e.g. -name, -socket etc
453 2024.09.05 @alchem0x2a
454 Note in ase>=3.23 the FileIOCalculator.command will fallback
455 to self._legacy_default_command, which we should set to invalid value for now.
457 2024.11.28 @alchem0x2a
458 Make use of the ase.config to set up the command
459 """
460 if isinstance(extras, (list, tuple)):
461 extras = " ".join(extras)
462 else:
463 extras = extras.strip()
465 print(self.command)
467 # User-provided command (and properly initialized) should have
468 # highest priority
469 if (self.command is not None) and (
470 self.command != SPARC._legacy_default_command
471 ):
472 return f"{self.command} {extras}"
474 parser = self.cfg.parser["sparc"] if "sparc" in self.cfg.parser else {}
475 # Get sparc command from either env variable or ini
476 command_env = self.cfg.get("ASE_SPARC_COMMAND", None) or parser.get(
477 "command", None
478 )
480 # Get sparc binary and mpi-prefix (alternative)
481 sparc_exe = parser.get("sparc_exe", None)
482 mpi_prefix = parser.get("mpi_prefix", None)
483 if (sparc_exe is None) != (mpi_prefix is None):
484 raise ValueError(
485 "Both 'sparc_exe' and 'mpi_prefix' must be specified together, "
486 "or neither should be set in the configuration."
487 )
488 if command_env and sparc_exe:
489 raise ValueError(
490 "Cannot set both sparc_command and sparc_exe in the config ini file!"
491 )
493 if sparc_exe:
494 command_env = f"{mpi_prefix} {sparc_exe}"
496 # Fallback
497 if command_env is None:
498 sparc_exe, mpi_exe, num_cores = _find_default_sparc()
499 if sparc_exe is None:
500 raise EnvironmentError(
501 "Cannot find your sparc setup via $ASE_SPARC_COMMAND, SPARC.command, or "
502 "infer from your $PATH. Please refer to the dmanual!"
503 )
504 if mpi_exe is not None:
505 command_env = f"{mpi_exe} -n {num_cores} {sparc_exe}"
506 else:
507 command_env = str(sparc_exe)
508 warn(
509 f"Your sparc command is inferred to be {command_env}, "
510 "If this is not correct, "
511 "please manually set $ASE_SPARC_COMMAND or SPARC.command!"
512 )
513 self.command = command_env
514 return f"{self.command} {extras}"
516 def check_input_atoms(self, atoms):
517 """Check if input atoms are valid for SPARC inputs.
518 Raises:
519 ValueError: if the atoms structure is not suitable for SPARC input file
520 """
521 # Check if the user accidentally provides atoms unit cell without vacuum
522 if atoms and np.any(atoms.cell.cellpar()[:3] == 0):
523 msg = "Cannot setup SPARC calculation because at least one of the lattice dimension is zero!"
524 if any([not bc_ for bc_ in atoms.pbc]):
525 msg += " Please add a vacuum in the non-periodic direction of your input structure."
526 raise ValueError(msg)
527 # SPARC only supports orthogonal lattice when Dirichlet BC is used
528 if any([not bc_ for bc_ in atoms.pbc]):
529 if not np.isclose(atoms.cell.angles(), [90.0, 90.0, 90.0], 1.0e-4).all():
530 raise ValueError(
531 (
532 "SPARC only supports orthogonal lattice when Dirichlet BC is used! "
533 "Please modify your atoms structures"
534 )
535 )
536 for i, bc_ in enumerate(atoms.pbc):
537 if bc_:
538 continue
539 direction = "xyz"[i]
540 min_pos, max_pos = atoms.positions[:, i].min(), atoms.positions[:, i].max()
541 cell_len = atoms.cell.lengths()[i]
542 if (min_pos < 0) or (max_pos > cell_len):
543 raise ValueError(
544 (
545 f"You have Dirichlet BC enabled for {direction}-direction, "
546 "but atoms positions are out of domain. "
547 "SPARC calculator cannot continue. "
548 "Please consider using atoms.center() to reposition your atoms."
549 )
550 )
551 # Additionally, we should not allow use to calculate pbc=False with CALC_STRESS=1
552 if all([not bc_ for bc_ in atoms.pbc]): # All Dirichlet
553 calc_stress = self.parameters.get("calc_stress", False)
554 if calc_stress:
555 raise ValueError(
556 "Cannot set CALC_STRESS=1 for non-periodic system in SPARC!"
557 )
558 return
    def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes):
        """Perform a calculation step.

        Dispatches to the socket implementations when socket mode is active
        ("local"/"client" or "server"); otherwise runs a classic file-IO
        cycle (write_input -> execute -> read_results). For geopt/aimd runs
        the relaxed/propagated structure is copied back into the caller's
        atoms object.
        """

        self.check_input_atoms(atoms)
        Calculator.calculate(self, atoms, properties, system_changes)

        # Extra check for inpt parameters since check_state won't accept properties
        # inpt should only change when write_inpt is actually called
        param_changed = not self._compare_calc_parameters(atoms, properties)
        if param_changed:
            system_changes.append("parameters")

        if self.socket_mode in ("local", "client"):
            self._calculate_with_socket(
                atoms=atoms, properties=properties, system_changes=system_changes
            )
            return

        if self.socket_mode == "server":
            self._calculate_as_server(
                atoms=atoms, properties=properties, system_changes=system_changes
            )
            return
        self.write_input(self.atoms, properties, system_changes)
        self.execute()
        self.read_results()
        # Extra step, copy the atoms back to original atoms, if it's an
        # geopt or aimd calculation
        # This will not occur for socket calculator because it's using the static files
        if ("geopt" in self.raw_results) or ("aimd" in self.raw_results):
            # Update the parent atoms
            atoms.set_positions(self.atoms.positions, apply_constraint=False)
            atoms.cell = self.atoms.cell
            atoms.constraints = self.atoms.constraints
            atoms.pbc = self.atoms.pbc
            # copy init magmom just to avoid check_state issue
            if "initial_magmoms" in self.atoms.arrays:
                atoms.set_initial_magnetic_moments(
                    self.atoms.get_initial_magnetic_moments()
                )
601 def _calculate_as_server(
602 self, atoms=None, properties=["energy"], system_changes=all_changes
603 ):
604 """Use the server component to send instructions to socket"""
605 ret, raw_results = self.in_socket.calculate_new_protocol(
606 atoms=atoms, params=self.parameters
607 )
608 self.raw_results = raw_results
609 if "stress" not in self.results:
610 virial_from_socket = ret.get("virial", np.zeros(6))
611 stress_from_socket = (
612 -full_3x3_to_voigt_6_stress(virial_from_socket) / atoms.get_volume()
613 )
614 self.results["stress"] = stress_from_socket
615 # Energy and forces returned in this case do not need
616 # resorting, since they are already in the same format
617 self.results["energy"] = ret["energy"]
618 self.results["forces"] = ret["forces"]
619 return
    def _calculate_with_socket(
        self, atoms=None, properties=["energy"], system_changes=all_changes
    ):
        """Perform one socket single point calculation.

        Starts (or restarts) the SPARC binary connected to the inbound unix
        socket if needed, then exchanges one position/result round-trip. The
        parsed file results are cross-checked against the socket results.
        """
        # TODO: merge this part
        if self.process is None:
            if self.detect_socket_compatibility() is not True:
                raise RuntimeError(
                    "Your sparc binary is not compiled with socket support!"
                )

        # These changes invalidate the running SPARC process and require a
        # restart (only allowed when socket_params["allow_restart"] is True)
        if any(
            [
                p in system_changes
                for p in ("numbers", "pbc", "parameters", "system_state")
            ]
        ):
            if self.process is not None:
                if not self.socket_params["allow_restart"]:
                    raise RuntimeError(
                        (
                            f"System has changed {system_changes} and the "
                            "calculator needs to be restarted!\n"
                            "Please set socket_params['allow_restart'] = True "
                            "if you want to continue"
                        )
                    )
                else:
                    print(
                        f"{system_changes} have changed since last calculation. Restart the socket process."
                    )
                    self.close(keep_out_socket=True)

        # (Re)spawn the SPARC binary attached to our unix socket
        if self.process is None:
            self.ensure_socket()
            self.write_input(atoms)
            cmds = self._make_command(
                extras=f"-socket {self.in_socket_filename}:unix -name {self.label}"
            )
            # Use the IOContext class's lazy context manager
            # TODO what if self.log is None
            fd_log = self.openfile(file=self.log, comm=world)
            self.process = subprocess.Popen(
                cmds,
                shell=True,
                stdout=fd_log,
                stderr=fd_log,
                cwd=self.directory,
                universal_newlines=True,
                bufsize=0,
            )
        # in_socket is a server; atoms are sent in SPARC (sorted) order
        ret = self.in_socket.calculate_origin_protocol(atoms[self.sort])
        # The results are parsed from file outputs (.static + .out)
        # Except for stress, they should be exactly the same as socket returned results
        self.read_results()  #
        assert np.isclose(
            ret["energy"], self.results["energy"]
        ), "Energy values from socket communication and output file are different! Please contact the developers."
        try:
            assert np.isclose(
                ret["forces"][self.resort], self.results["forces"]
            ).all(), "Force values from socket communication and output file are different! Please contact the developers."
        except KeyError:
            print(
                "Force values cannot be accessed via the results dictionary. They may not be available in the output file. Ensure PRINT_FORCES: 1\nResults:\n",
                self.results,
            )
        # For stress information, we make sure that the stress is always present
        if "stress" not in self.results:
            virial_from_socket = ret.get("virial", np.zeros(6))
            stress_from_socket = (
                -full_3x3_to_voigt_6_stress(virial_from_socket) / atoms.get_volume()
            )
            self.results["stress"] = stress_from_socket
        self.system_state = self._dump_system_state()
        return
699 def get_stress(self, atoms=None):
700 """Warn user the dimensionality change when using stress"""
701 if "stress_equiv" in self.results:
702 raise NotImplementedError(
703 "You're requesting stress in a low-dimensional system. Please use `calc.results['stress_equiv']` instead!"
704 )
705 return super().get_stress(atoms)
707 def _check_input_exclusion(self, input_parameters, atoms=None):
708 """Check if mutually exclusive parameters are provided
710 The exclusion rules are taken from the SPARC manual and currently hard-coded.
711 We may need to have a clever way to do the automatic rule conversion in API
712 """
713 # Rule 1: ECUT, MESH_SPACING, FD_GRID
714 count = 0
715 for key in ["ECUT", "MESH_SPACING", "FD_GRID"]:
716 if key in input_parameters:
717 count += 1
718 if count > 1:
719 raise ValueError(
720 "ECUT, MESH_SPACING, FD_GRID cannot be specified simultaneously!"
721 )
723 # Rule 2: LATVEC_SCALE, CELL
724 if ("LATVEC_SCALE" in input_parameters) and ("CELL" in input_parameters):
725 raise ValueError(
726 "LATVEC_SCALE and CELL cannot be specified simultaneously!"
727 )
729 # When the cell is provided via ase object, we will forbid user to provide
730 # LATVEC, LATVEC_SCALE or CELL
731 if atoms is not None:
732 if any([p in input_parameters for p in ["LATVEC", "LATVEC_SCALE", "CELL"]]):
733 raise ValueError(
734 (
735 "When passing an ase atoms object, LATVEC, LATVEC_SCALE or CELL cannot be set simultaneously! \n"
736 "Please set atoms.cell instead"
737 )
738 )
740 def _check_minimal_input(self, input_parameters):
741 """Check if the minimal input set is satisfied"""
742 for param in ["EXCHANGE_CORRELATION", "KPOINT_GRID"]:
743 if param not in input_parameters:
744 raise ValueError(f"Parameter {param} is not provided.")
745 # At least one from ECUT, MESH_SPACING and FD_GRID must be provided
746 if not any(
747 [param in input_parameters for param in ("ECUT", "MESH_SPACING", "FD_GRID")]
748 ):
749 raise ValueError(
750 "You should provide at least one of ECUT, MESH_SPACING or FD_GRID."
751 )
753 def _generate_inpt_state(self, atoms, properties=[]):
754 """Return a key:value pair to be written to inpt file
755 This is an immutable dict as the ground truth
756 """
757 converted_params = self._convert_special_params(atoms=atoms)
758 input_parameters = converted_params.copy()
759 input_parameters.update(self.valid_params)
761 # Make sure desired properties are always ensured, but we don't modify the user inputs
762 if "forces" in properties:
763 input_parameters["PRINT_FORCES"] = True
765 if "stress" in properties:
766 input_parameters["CALC_STRESS"] = True
768 self._check_input_exclusion(input_parameters, atoms=atoms)
769 self._check_minimal_input(input_parameters)
770 return input_parameters
    def write_input(self, atoms, properties=[], system_changes=[]):
        """Create input files via SparcBundle.

        Honors self.keep_old_files: when False, existing output files
        (.out, .static, .eigen, .aimd, geopt and their numbered variants
        like .out_01) are removed before writing. Also records the new
        inpt_state / system_state snapshots used for change detection.
        """
        FileIOCalculator.write_input(self, atoms, properties, system_changes)
        input_parameters = self._generate_inpt_state(atoms, properties=properties)

        # TODO: make sure the sorting reset is justified (i.e. what about restarting?)
        self.sparc_bundle.sorting = None
        self.sparc_bundle._write_ion_and_inpt(
            atoms=atoms,
            label=self.label,
            # Pass the rest parameters from calculator!
            direct=False,
            sort=True,
            ignore_constraints=False,
            wrap=False,
            # Below are the parameters from v1
            # scaled -> direct, ignore_constraints --> not add_constraints
            scaled=False,
            add_constraints=True,
            copy_psp=True,
            comment="",
            input_parameters=input_parameters,
        )

        output_patterns = [".out", ".static", ".eigen", ".aimd", "geopt"]
        # We just remove the output files, in case the user has psp files manually copied
        if self.keep_old_files is False:
            for f in self.directory.glob("*"):
                if (f.is_file()) and any(
                    [f.suffix.startswith(p) for p in output_patterns]
                ):
                    os.remove(f)
        self.inpt_state = input_parameters
        self.system_state = self._dump_system_state()
        return
812 def execute(self):
813 """Make a normal SPARC calculation without socket. Note we probably need to use a better handling of background process!"""
814 extras = f"-name {self.label}"
815 command = self._make_command(extras=extras)
816 self.print_sysinfo(command)
818 try:
819 if self.log is not None:
820 with open(self.log, "a") as fd:
821 self.process = subprocess.run(
822 command, shell=True, cwd=self.directory, stdout=fd
823 )
824 else:
825 self.process = subprocess.run(
826 command, shell=True, cwd=self.directory, stdout=None
827 )
828 except OSError as err:
829 msg = 'Failed to execute "{}"'.format(command)
830 raise EnvironmentError(msg) from err
832 # We probably don't want to wait the
833 errorcode = self.process.returncode
835 if errorcode > 0:
836 msg = f"SPARC failed with command {command}" f"with error code {errorcode}"
837 raise RuntimeError(msg)
839 return
    def close(self, keep_out_socket=False):
        """Close the socket communication, the SPARC process etc.

        Arguments:
            keep_out_socket (bool): if True, leave the outgoing socket client
                open (used when only the local SPARC process is restarted)
        """
        if not self.use_socket:
            return
        if self.in_socket is not None:
            self.in_socket.close()

        if (self.out_socket is not None) and (not keep_out_socket):
            self.out_socket.close()

        # In most cases if in_socket is closed, the SPARC process should also exit
        if self.process:
            with time_limit(5):
                ret = self.process.poll()
                if ret is None:
                    # Still running after socket close --> kill the mpi process
                    print("Force terminate the sparc process!")
                    self._send_mpi_signal(signal.SIGKILL)
                else:
                    print(f"SPARC process exists with code {ret}")

        # TODO: check if in_socket should be merged
        self.in_socket = None
        if not keep_out_socket:
            self.out_socket = None
        self._reset_process()
    def _send_mpi_signal(self, sig):
        """Send signal to the mpi process within self.process.
        If the process cannot be found, return without affecting the state.

        This is a method taken from the vasp_interactive project.

        Arguments:
            sig: the signal to deliver (e.g. signal.SIGKILL)
        """
        try:
            pid = self.process.pid
            # psutil.Process raises if the pid no longer exists; the object
            # itself is only used as an existence probe here
            psutil_proc = psutil.Process(pid)
        except Exception as e:
            warn("SPARC process no longer exists. Will reset the calculator.")
            self._reset_process()
            return

        # Reuse the cached mpi/slurm match when the pid has not changed
        if (self.pid == pid) and getattr(self, "mpi_match", None) is not None:
            match = self.mpi_match
        else:
            # self.pid = pid
            match = _find_mpi_process(pid)
            self.mpi_match = match
        if (match["type"] is None) or (match["process"] is None):
            warn(
                "Cannot find the mpi process or you're using different ompi wrapper. Will not send stop signal to mpi."
            )
            return
        elif match["type"] == "mpi":
            mpi_process = match["process"]
            mpi_process.send_signal(sig)
        elif match["type"] == "slurm":
            # Slurm steps need the scancel-based signalling helper
            slurm_step = match["process"]
            _slurm_signal(slurm_step, sig)
        else:
            raise ValueError("Unsupported process type!")
        return
902 def _reset_process(self):
903 """Reset the record for process in the calculator.
904 Useful if the process is missing or reset the calculator.
905 """
906 # Reset process tracker
907 self.process = None
908 # self.pid = None
909 if hasattr(self, "mpi_match"):
910 self.mpi_match = None
911 self.mpi_state = None
913 @property
914 def pid(self):
915 """The pid for the stored process"""
916 if self.process is None:
917 return None
918 else:
919 return self.process.pid
921 @property
922 def raw_results(self):
923 return getattr(self.sparc_bundle, "raw_results", {})
925 @raw_results.setter
926 def raw_results(self, value):
927 self.sparc_bundle.raw_results = value
928 return
930 def read_results(self):
931 """Parse from the SparcBundle"""
932 # self.sparc_bundle.read_raw_results()
933 last = self.sparc_bundle.convert_to_ase(indices=-1, include_all_files=False)
934 self.atoms = last.copy()
935 self.results.update(last.calc.results)
    def _restart(self, restart=None):
        """Reload the input parameters and atoms from previous calculation.

        If self.parameters is already set, the parameters will not be loaded
        If self.atoms is already set, the atoms will be not be read
        """
        if restart is None:
            return
        # Record what the user provided BEFORE read_results overwrites it
        reload_atoms = self.atoms is None
        reload_parameters = len(self.parameters) == 0

        self.read_results()
        if not reload_atoms:
            # NOTE(review): this discards the atoms loaded by read_results,
            # but the user-provided atoms were already overwritten above, so
            # self.atoms ends up None here — confirm this is intentional
            self.atoms = None
        if reload_parameters:
            self.parameters = self.raw_results["inpt"]["params"]

        if (not reload_parameters) or (not reload_atoms):
            warn(
                "Extra parameters or atoms are provided when restarting the SPARC calculator, "
                "previous results will be cleared."
            )
            self.results.clear()
            self.sparc_bundle.raw_results.clear()
        return
963 def get_fermi_level(self):
964 """Extra get-method for Fermi level, if calculated"""
965 return self.results.get("fermi", None)
967 def detect_sparc_version(self):
968 """Run a short sparc test to determine which sparc is used"""
969 try:
970 cmd = self._make_command()
971 except EnvironmentError:
972 return None
973 print("Running a short calculation to determine SPARC version....")
974 # check_version must be set to False to avoid recursive calling
975 new_calc = SPARC(
976 command=self.command, psp_dir=self.sparc_bundle.psp_dir, check_version=False
977 )
978 with tempfile.TemporaryDirectory() as tmpdir:
979 new_calc.set(xc="pbe", h=0.3, kpts=(1, 1, 1), maxit_scf=1, directory=tmpdir)
980 atoms = Atoms(["H"], positions=[[0.0, 0.0, 0.0]], cell=[2, 2, 2], pbc=False)
981 try:
982 new_calc.calculate(atoms)
983 version = new_calc.raw_results["out"]["sparc_version"]
984 except Exception as e:
985 print("Error handling simple calculation: ", e)
986 version = None
987 # Warning information about version mismatch between binary and JSON API
988 # only when both are not None
989 if (version is None) and (self.validator.sparc_version is not None):
990 if version != self.validator.sparc_version:
991 warn(
992 (
993 f"SPARC binary version {version} does not match JSON API version {self.validator.sparc_version}. "
994 "You can set $SPARC_DOC_PATH to the SPARC documentation location."
995 )
996 )
997 return version
999 def run_client(self, atoms=None, use_stress=False):
1000 """Main method to start the client code"""
1001 if not self.socket_mode == "client":
1002 raise RuntimeError(
1003 "Cannot use SPARC.run_client if the calculator is not configured in client mode!"
1004 )
1006 self.out_socket.run(atoms, use_stress)
1008 def detect_socket_compatibility(self):
1009 """Test if the sparc binary supports socket mode"""
1010 try:
1011 cmd = self._make_command()
1012 except EnvironmentError:
1013 return False
1014 with tempfile.TemporaryDirectory() as tmpdir:
1015 proc = subprocess.run(cmd, shell=True, cwd=tmpdir, capture_output=True)
1016 output = proc.stdout.decode("ascii")
1017 if "USAGE:" not in output:
1018 raise EnvironmentError(
1019 "Cannot find the sparc executable! Please make sure you have the correct setup"
1020 )
1021 compatibility = "-socket" in output
1022 return compatibility
    def set(self, **kwargs):
        """Overwrite the initial parameters.

        Calculator-level options (label, directory, log, check_version,
        keep_old_files, atoms, command) are consumed here directly; the
        remaining keyword arguments are split and validated against the
        SPARC JSON API via _sanitize_kwargs and then handed to the parent
        Calculator.set. Returns the dict of changed parameters; any change
        resets cached results.
        """
        # Do not use JSON Schema for these arguments
        if "label" in kwargs:
            self.label = kwargs.pop("label")

        if "directory" in kwargs:
            # str() call to deal with pathlib objects
            self.directory = str(kwargs.pop("directory"))

        if "log" in kwargs:
            self.log = kwargs.pop("log")

        if "check_version" in kwargs:
            self.check_version = bool(kwargs.pop("check_version"))

        if "keep_old_files" in kwargs:
            self.keep_old_files = kwargs.pop("keep_old_files")

        if "atoms" in kwargs:
            self.atoms = kwargs.pop("atoms")  # Resets results

        if "command" in kwargs:
            self.command = kwargs.pop("command")

        # For now we don't let the user to hot-swap socket
        if ("use_socket" in kwargs) or ("socket_params" in kwargs):
            raise NotImplementedError("Hot swapping socket parameter is not supported!")

        # Remaining kwargs are SPARC/ASE calculation parameters; they are
        # accumulated into self.special_params / self.valid_params
        self._sanitize_kwargs(**kwargs)
        set_params = {}
        set_params.update(self.special_params)
        set_params.update(self.valid_params)
        changed = super().set(**set_params)
        if changed != {}:
            self.reset()

        return changed
1063 def _sanitize_kwargs(self, **kwargs):
1064 """Convert known parameters from JSON API"""
1065 validator = self.validator
1066 if self.special_params == {}:
1067 init = True
1068 self.special_params = self.default_params.copy()
1069 else:
1070 init = False
1071 # User input gpts will overwrite default h
1072 # but user cannot put h and gpts both
1073 if "gpts" in kwargs:
1074 h = self.special_params.pop("h", None)
1075 if (h is not None) and (not init):
1076 warn("Parameter gpts will overwrite previously set parameter h.")
1077 elif "h" in kwargs:
1078 gpts = self.special_params.pop("gpts", None)
1079 if (gpts is not None) and (not init):
1080 warn("Parameter h will overwrite previously set parameter gpts.")
1082 upper_valid_params = set() # Valid SPARC parameters in upper case
1083 # SPARC API is case insensitive
1084 for key, value in kwargs.items():
1085 if key in self.special_inputs:
1086 # Special case: ignore h when gpts provided
1088 self.special_params[key] = value
1089 else:
1090 key = key.upper()
1091 if key in upper_valid_params:
1092 warn(f"Parameter {key} (case-insentive) appears multiple times!")
1093 if validator.validate_input(key, value):
1094 self.valid_params[key] = value
1095 upper_valid_params.add(key)
1096 else:
1097 raise ValueError(
1098 f"Value {value} for parameter {key} (case-insensitive) is invalid!"
1099 )
1100 return
1102 def _convert_special_params(self, atoms=None):
1103 """Convert ASE-compatible parameters to SPARC compatible ones
1104 parameters like `h`, `nbands` may need atoms information
1106 Special rules:
1107 h <--> gpts <--> FD_GRID, only when None of FD_GRID / ECUT or MESH_SPACING is provided
1108 """
1109 converted_sparc_params = {}
1110 validator = self.validator
1111 params = self.special_params.copy()
1113 # xc --> EXCHANGE_CORRELATION
1114 if "xc" in params:
1115 xc = params.pop("xc")
1116 if xc.lower() == "pbe":
1117 converted_sparc_params["EXCHANGE_CORRELATION"] = "GGA_PBE"
1118 elif xc.lower() == "lda":
1119 converted_sparc_params["EXCHANGE_CORRELATION"] = "LDA_PZ"
1120 elif xc.lower() == "rpbe":
1121 converted_sparc_params["EXCHANGE_CORRELATION"] = "GGA_RPBE"
1122 elif xc.lower() == "pbesol":
1123 converted_sparc_params["EXCHANGE_CORRELATION"] = "GGA_PBEsol"
1124 elif xc.lower() == "pbe0":
1125 converted_sparc_params["EXCHANGE_CORRELATION"] = "PBE0"
1126 elif xc.lower() == "hf":
1127 converted_sparc_params["EXCHANGE_CORRELATION"] = "HF"
1128 # backward compatibility for HSE03. Note HSE06 is not supported yet
1129 elif (xc.lower() == "hse") or (xc.lower() == "hse03"):
1130 converted_sparc_params["EXCHANGE_CORRELATION"] = "HSE"
1131 # backward compatibility for VASP-style XCs
1132 elif (
1133 (xc.lower() == "vdwdf1")
1134 or (xc.lower() == "vdw-df")
1135 or (xc.lower() == "vdw-df1")
1136 ):
1137 converted_sparc_params["EXCHANGE_CORRELATION"] = "vdWDF1"
1138 elif (xc.lower() == "vdwdf2") or (xc.lower() == "vdw-df2"):
1139 converted_sparc_params["EXCHANGE_CORRELATION"] = "vdWDF2"
1140 elif xc.lower() == "scan":
1141 converted_sparc_params["EXCHANGE_CORRELATION"] = "SCAN"
1142 else:
1143 raise ValueError(f"xc keyword value {xc} is invalid!")
1145 # h --> gpts
1146 if "h" in params:
1147 if "gpts" in params:
1148 raise KeyError(
1149 "h and gpts cannot be provided together in SPARC calculator!"
1150 )
1151 h = params.pop("h")
1152 # if atoms is None:
1153 # raise ValueError(
1154 # "Must have an active atoms object to convert h --> gpts!"
1155 # )
1156 if any(
1157 [p in self.valid_params for p in ("FD_GRID", "ECUT", "MESH_SPACING")]
1158 ):
1159 warn(
1160 "You have specified one of FD_GRID, ECUT or MESH_SPACING, "
1161 "conversion of h to mesh grid is ignored."
1162 )
1163 else:
1164 # gpts = h2gpts(h, atoms.cell)
1165 # params["gpts"] = gpts
1166 # Use mesh_spacing instead of fd_grid to avoid parameters
1167 converted_sparc_params["MESH_SPACING"] = h / Bohr
1169 # gpts --> FD_GRID
1170 if "gpts" in params:
1171 gpts = params.pop("gpts")
1172 if validator.validate_input("FD_GRID", gpts):
1173 converted_sparc_params["FD_GRID"] = gpts
1174 else:
1175 raise ValueError(f"Input parameter gpts has invalid value {gpts}")
1177 # kpts
1178 if "kpts" in params:
1179 kpts = params.pop("kpts")
1180 if validator.validate_input("KPOINT_GRID", kpts):
1181 converted_sparc_params["KPOINT_GRID"] = kpts
1182 else:
1183 raise ValueError(f"Input parameter kpts has invalid value {kpts}")
1185 # nbands
1186 if "nbands" in params:
1187 # TODO: Check if the nbands are correct in current system
1188 # TODO: default $N_e/2 \\times 1.2 + 5$
1189 nbands = params.pop("nbands")
1190 if validator.validate_input("NSTATES", nbands):
1191 converted_sparc_params["NSTATES"] = nbands
1192 else:
1193 raise ValueError(f"Input parameter nbands has invalid value {nbands}")
1195 # convergence is a dict
1196 if "convergence" in params:
1197 convergence = params.pop("convergence")
1198 tol_e = convergence.get("energy", None)
1199 if tol_e:
1200 # TOL SCF: Ha / atom <--> energy tol: eV / atom
1201 converted_sparc_params["TOL_SCF"] = tol_e / Hartree
1203 tol_f = convergence.get("relax", None)
1204 if tol_f:
1205 # TOL SCF: Ha / Bohr <--> energy tol: Ha / Bohr
1206 converted_sparc_params["TOL_RELAX"] = tol_f / Hartree * Bohr
1208 tol_dens = convergence.get("density", None)
1209 if tol_dens:
1210 # TOL SCF: electrons / atom
1211 converted_sparc_params["TOL_PSEUDOCHARGE"] = tol_dens
1213 tol_stress = convergence.get("stress", None)
1214 if tol_stress:
1215 # TOL SCF: electrons / atom
1216 converted_sparc_params["TOL_RELAX_CELL"] = tol_stress / GPa
1218 return converted_sparc_params
1220 def print_sysinfo(self, command=None):
1221 """Record current runtime information"""
1222 now = datetime.datetime.now().isoformat()
1223 if command is None:
1224 command = self.command
1225 msg = (
1226 "\n" + "*" * 80 + "\n"
1227 f"SPARC program started by SPARC-X-API at {now}\n"
1228 f"command: {command}\n"
1229 )
1230 if self.log is None:
1231 print(msg)
1232 else:
1233 with open(self.log, "a") as fd:
1234 print(msg, file=fd)
1236 ###############################################
1237 # Below are deprecated functions from v1
1238 ###############################################
    @deprecated("Please use SPARC.set instead for setting grid")
    def interpret_grid_input(self, atoms, **kwargs):
        # Deprecated v1 API: no-op kept only for backward compatibility.
        return None
    @deprecated("Please use SPARC.set instead for setting kpoints")
    def interpret_kpoint_input(self, atoms, **kwargs):
        # Deprecated v1 API: no-op kept only for backward compatibility.
        return None
    @deprecated("Please use SPARC.set instead for setting downsampling parameter")
    def interpret_downsampling_input(self, atoms, **kwargs):
        # Deprecated v1 API: no-op kept only for backward compatibility.
        return None
    @deprecated("Please use SPARC.set instead for setting kpoint shift")
    def interpret_kpoint_shift(self, atoms, **kwargs):
        # Deprecated v1 API: no-op kept only for backward compatibility.
        return None
    @deprecated("Please use SPARC.psp_dir instead")
    def get_pseudopotential_directory(self, pseudo_dir=None, **kwargs):
        # Deprecated v1 API: forwards to the bundle's psp_dir attribute.
        return self.sparc_bundle.psp_dir
    def get_nstates(self):
        # NOTE(review): estimate_memory calls this as
        # self.get_nstates(atoms=..., **kwargs), which this zero-argument
        # signature cannot accept — confirm the intended signature.
        raise NotImplementedError("Parsing nstates is not yet implemented.")
    @deprecated("Please set the variables separatedly")
    def setup_parallel_env(self):
        # Deprecated v1 API: no-op kept only for backward compatibility.
        return None
    @deprecated("Please use SPARC._make_command instead")
    def generate_command(self):
        # Deprecated v1 API: delegates to the v2 command builder.
        return self._make_command(f"-name {self.label}")
1270 def estimate_memory(self, atoms=None, units="GB", **kwargs):
1271 """
1272 a function to estimate the amount of memory required to run
1273 the selected calculation. This function takes in **kwargs,
1274 but if none are passed in, it will fall back on the parameters
1275 input when the class was instantiated
1276 """
1277 conversion_dict = {
1278 "MB": 1e-6,
1279 "GB": 1e-9,
1280 "B": 1,
1281 "byte": 1,
1282 "KB": 1e-3,
1283 }
1284 if kwargs == {}:
1285 kwargs = self.parameters
1286 if atoms is None:
1287 atoms = self.atoms
1289 nstates = kwargs.get("NSTATES")
1290 if nstates is None:
1291 nstates = self.get_nstates(atoms=atoms, **kwargs)
1293 # some annoying code to figure out if it's a spin system
1294 spin_polarized = kwargs.get("nstates")
1295 if spin_polarized is not None:
1296 spin_polarized = int(spin_polarized)
1297 else:
1298 spin_polarized = 1
1299 if spin_polarized == 2:
1300 spin_factor = 2
1301 else:
1302 spin_factor = 1
1304 if "MESH_SPACING" in kwargs:
1305 # MESH_SPACING: Bohr; h: angstrom
1306 kwargs["h"] = kwargs.pop("MESH_SPACING") / Bohr
1307 npoints = np.product(self.interpret_grid_input(atoms, **kwargs))
1309 kpt_grid = self.interpret_kpoint_input(atoms, **kwargs)
1310 kpt_factor = np.ceil(np.product(kpt_grid) / 2)
1312 # this is a pretty generous over-estimate
1313 estimate = 5 * npoints * nstates * kpt_factor * spin_factor * 8 # bytes
1314 converted_estimate = estimate * conversion_dict[units]
1315 return converted_estimate
1317 def get_scf_steps(self, include_uncompleted_last_step=False):
1318 raise NotImplemented
1320 @deprecated("Use SPARC.get_number_of_ionic_steps instead")
1321 def get_geometric_steps(self, include_uncompleted_last_step=False):
1322 raise NotImplemented
1324 def get_runtime(self):
1325 raise NotImplemented
1327 def get_fermi_level(self):
1328 raise NotImplemented
    @deprecated
    def concatinate_output(self):
        # Removed in v2. (The historical typo in the method name is kept
        # because it is part of the public interface.)
        raise DeprecationWarning("Functionality moved in sparc.SparcBundle.")
    @deprecated
    def read_line(self, **kwargs):
        # Removed in v2; parsing now lives in sparc.sparc_parsers.
        raise DeprecationWarning(
            "Parsers for individual files have been moved to sparc.sparc_parsers module"
        )
    @deprecated
    def parse_output(self, **kwargs):
        # Removed in v2; superseded by SPARC.read_results.
        raise DeprecationWarning("Use SPARC.read_results for parsing results!")
    @deprecated
    def parse_relax(self, *args, **kwargs):
        # Removed in v2; superseded by SPARC.read_results.
        raise DeprecationWarning("Use SPARC.read_results for parsing results!")
    @deprecated
    def parse_MD(self, *args, **kwargs):
        # Removed in v2; superseded by SPARC.read_results.
        raise DeprecationWarning("Use SPARC.read_results for parsing results!")
    @deprecated
    def parse_input_args(self, input_block):
        # Removed in v2; superseded by SPARC.set.
        raise DeprecationWarning("Use SPARC.set for argument handling!")
    @deprecated
    def recover_index_order_from_ion_file(self, label):
        # Removed in v2; sorting is handled by SPARC.sort / SPARC.resort.
        raise DeprecationWarning(
            "Use SPARC.sort and SPARC.resort for atomic index sorting!"
        )
    @deprecated
    def atoms_dict(self, *args, **kwargs):
        # Removed in v2 without a direct replacement.
        raise DeprecationWarning("")
    @deprecated
    def dict_atoms(self, *args, **kwargs):
        # Removed in v2 without a direct replacement.
        raise DeprecationWarning("")