Coverage for sparc/utils.py: 44%
250 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-18 16:19 +0000
1"""Utilities that are loosely related to core sparc functionalities
2"""
3import _thread
4import io
5import os
6import re
7import shutil
8import signal
9import subprocess
10import sys
11import tempfile
12import threading
13import time
14from contextlib import contextmanager
15from pathlib import Path
16from typing import List, Optional, Union
17from warnings import warn
19import numpy as np
20import psutil
22# 2024-11-28 @alchem0x2a add config
23from ase.config import cfg as _cfg
24from ase.units import Hartree
26from .api import SparcAPI
27from .docparser import SparcDocParser
def deprecated(message):
    """Decorator factory that marks a function as deprecated.

    The wrapped function emits a DeprecationWarning on every call and
    then executes normally.

    Args:
        message: extra guidance appended to the warning text
            (e.g. what to use instead).
    """
    from functools import wraps

    def decorator(func):
        @wraps(func)  # preserve the wrapped function's name and docstring
        def new_func(*args, **kwargs):
            warn(
                "Function {} is deprecated! {}".format(func.__name__, message),
                category=DeprecationWarning,
                stacklevel=2,  # point the warning at the caller, not this wrapper
            )
            return func(*args, **kwargs)

        return new_func

    return decorator
def compare_dict(d1, d2):
    """Return True if two dicts have identical keys and equal values.

    Values are compared through numpy, so array-valued entries are
    handled element-wise.
    """
    # The key sets must match exactly (empty symmetric difference)
    if set(d1) ^ set(d2):
        return False
    # Every shared key must map to an equal value
    return all(not np.any(val != d2[key]) for key, val in d1.items())
def string2index(string: str) -> Union[int, slice, str]:
    """Convert index string to either int or slice
    This method is a copy of ase.io.formats.string2index
    """
    # Already an index-like object: pass it through unchanged
    if isinstance(string, (list, slice)):
        return string
    if ":" not in string:
        # Either a plain integer index or a database accessor string
        try:
            return int(string)
        except ValueError:
            return string
    # Parse the slice fields; empty fields become None
    fields: List[Optional[int]] = [
        int(part) if part else None for part in string.split(":")
    ]
    # Pad out to the three slice arguments (start, stop, step)
    fields.extend([None] * (3 - len(fields)))
    return slice(*fields)
def _find_default_sparc():
    """Find the default sparc by $PATH and mpi location

    Returns:
        (sparc_exe, mpi_exe, num_cores): paths (or None) to the sparc
        binary and the mpi launcher, plus the detected core count.
    """
    sparc_exe = shutil.which("sparc")

    mpi_exe = shutil.which("mpirun")
    # TODO: more examples on pbs / lsf
    if mpi_exe is not None:
        # Try the OpenMPI / MPICH environment hints, in that order
        try:
            num_cores = int(
                os.environ.get(
                    "OMPI_COMM_WORLD_SIZE",
                    os.environ.get(
                        "OMPI_UNIVERSE_SIZE",
                        os.environ.get("MPICH_RANK_REORDER_METHOD", ""),
                    ).split(":")[-1],
                )
            )
        except Exception:
            num_cores = 1
        return sparc_exe, mpi_exe, num_cores

    mpi_exe = shutil.which("srun")
    if mpi_exe is not None:
        # SLURM_JOB_CPUS_PER_NODE may be a compressed list such as
        # "72(x2),36"; a bare int() would raise ValueError, so only
        # parse the leading digits and fall back to 1.
        raw = os.environ.get("SLURM_JOB_CPUS_PER_NODE", "1")
        digits = re.match(r"\d+", raw)
        num_cores = int(digits.group()) if digits else 1
        return sparc_exe, mpi_exe, num_cores

    # No known mpi launcher found: assume serial execution
    return sparc_exe, None, 1
def h2gpts(h, cell_cv, idiv=4):
    """Convert a h-parameter (Angstrom) to gpts"""
    # Lengths of the three lattice vectors (rows of the cell matrix)
    lengths = np.linalg.norm(np.asarray(cell_cv), axis=1)
    # Enough points to keep spacing <= h, but never fewer than idiv
    counts = np.maximum(idiv, np.ceil(lengths / h))
    return [int(c) for c in counts]
def parse_hubbard_string(hubbard_str):
    """setups: {element: hubbard_string} parser from gpaw

    Parse DFT+U parameters from a type-string, e.g. "type:l,U" or
    "type:l,U,scale"; multiple corrections are separated by ";".
    The "type" prefix and the scale field are ignored, since SPARC is
    a real-space code.  Input U values are in eV.

    Returns:
        A 4-element list of U values [s, p, d, f] in Hartree.

    Raises:
        ValueError: if the angular-momentum label is unknown, or the
            same channel is specified more than once.
    """
    _, lus = hubbard_str.split(":")
    U = [0, 0, 0, 0]  # s, p, d, f
    for lu in lus.split(";"):  # Multiple U corrections
        # Pad with commas so a missing scale field unpacks cleanly;
        # scale_ is parsed but not used
        l_, u_, scale_ = (lu + ",,").split(",")[:3]
        l_ind = "spdf".find(l_)
        # find() returns -1 for unknown labels, which would silently
        # write into the f channel via U[-1] -- reject explicitly.
        if l_ind < 0:
            raise ValueError(f"Unknown angular momentum channel '{l_}'.")
        if U[l_ind] != 0:
            raise ValueError("Bad HUBBARD U value formatting. Multiple keys?")
        U[l_ind] = float(u_) / Hartree  # eV --> Hartree
    return U
def cprint(content, color=None, bold=False, underline=False, **kwargs):
    """Color print wrapper for ansi terminal.
    Only a few color names are provided
    """
    ansi_color = dict(
        HEADER="\033[95m",
        COMMENT="\033[90m",
        OKBLUE="\033[94m",
        OKGREEN="\033[92m",
        OKCYAN="\033[96m",
        WARNING="\033[93m",
        FAIL="\033[91m",
        ENDC="\033[0m",
    )

    style_codes = {"BOLD": "\033[1m", "UNDERLINE": "\033[4m"}

    reset = ansi_color["ENDC"]
    if color is None:
        output = content
    else:
        key = color.upper()
        # ENDC is the reset code, not a selectable color
        if key == "ENDC" or key not in ansi_color:
            raise ValueError(
                f"Unknown ANSI color name. Allowed values are {list(ansi_color.keys())}"
            )
        output = ansi_color[key] + content + reset

    # Styles stack on top of the (possibly colored) text
    if bold:
        output = style_codes["BOLD"] + output + reset
    if underline:
        output = style_codes["UNDERLINE"] + output + reset

    print(output, **kwargs)
    return
def sanitize_path(path_string):
    """Sanitize path containing string in UNIX systems
    Returns a PosixPath object

    It is recommended to use this sanitize function
    before passing any path-like strings from cfg parser
    """
    if isinstance(path_string, str):
        # Expand ~user first, then $VARS, before making it absolute
        expanded = os.path.expanduser(path_string)
        expanded = os.path.expandvars(expanded)
        return Path(expanded).resolve()
    # Already path-like: just normalize to an absolute path
    return Path(path_string).resolve()
def locate_api(json_file=None, doc_path=None, cfg=_cfg):
    """
    Locate the SPARC API setup file with the following priority:
    1) If `json_file` is provided (either from parameter or cfg), use it directly.
    2) If `doc_path` is provided:
       a) Function parameter takes precedence.
       b) Environment variable SPARC_DOC_PATH comes next.
       c) Configuration section [sparc] in the ini file is the last resort.
    3) If both `json_file` and `doc_path` are provided, raise an exception.
    4) Fallback to the default API setup if neither is provided.
    """
    section = cfg.parser["sparc"] if "sparc" in cfg.parser else {}
    # Function parameter beats the ini-file json_schema setting
    if not json_file:
        json_file = section.get("json_schema") if section else None

    # SPARC_DOC_PATH env var overrides the ini-file doc_path setting
    if not doc_path:
        doc_path = cfg.get("SPARC_DOC_PATH") or (
            section.get("doc_path") if section else None
        )

    json_file = sanitize_path(json_file) if json_file else None
    doc_path = sanitize_path(doc_path) if doc_path else None

    # The two sources are mutually exclusive
    if json_file and doc_path:
        raise ValueError(
            "Cannot set both the path of json file and documentation"
            "at the same time!"
        )

    if json_file:
        if not json_file.is_file():
            raise FileNotFoundError(f"JSON file '{json_file}' does not exist.")
        return SparcAPI(json_file)

    if doc_path:
        if not doc_path.is_dir():
            raise FileNotFoundError(
                f"Documentation path '{doc_path}' does not exist or is not a directory."
            )
        try:
            # Parse the LaTeX documentation into a temporary json schema
            with tempfile.TemporaryDirectory() as tmpdir:
                tmpfile = Path(tmpdir) / "parameters.json"
                json_text = SparcDocParser.json_from_directory(
                    doc_path, include_subdirs=True
                )
                with open(tmpfile, "w") as fd:
                    fd.write(json_text)
                api = SparcAPI(tmpfile)
                api.source = {"path": str(doc_path.resolve()), "type": "latex"}
                return api
        except Exception as e:
            raise RuntimeError(
                f"Failed to load API from documentation path '{doc_path}': {e}"
            )

    # Neither source given: use the bundled default API
    return SparcAPI()
258# Utilities taken from vasp_interactive project
class TimeoutException(Exception):
    """Raised when a sparc process does not finish within the time limit."""
@contextmanager
def time_limit(seconds):
    """Raise TimeoutException if the managed body exceeds `seconds`.

    Usage:
    try:
        with time_limit(60):
            do_something()
    except TimeoutException:
        raise
    """

    def _on_alarm(signum, frame):
        raise TimeoutException("Timed out closing sparc process.")

    # Arm a SIGALRM-based timer around the managed body
    signal.signal(signal.SIGALRM, _on_alarm)
    signal.alarm(seconds)
    try:
        yield
    finally:
        # Always disarm the alarm, even if the body raised
        signal.alarm(0)
class ProcessReturned(Exception):
    """Raised when the monitored sparc process has already returned."""
@contextmanager
def monitor_process(self, interval=1.0):
    """Watch a running calculator process from a background thread.

    A watcher thread polls ``self.process`` every ``interval`` seconds;
    once the process exits, SIGALRM is delivered so the main thread's
    handler raises ProcessReturned.

    Usage:
    try:
        with monitor_process(process):
            do_something()
    except TimeoutException:
        raise

    NOTE(review): `self` is expected to provide `.process` (a
    Popen-like object) and `.in_socket` -- confirm against the caller.
    """

    def signal_handler(signum, frame):
        raise ProcessReturned(
            f"Process {self.process.pid} has returned with exit code {self.process.poll()}!"
        )

    def check_process():
        while True:
            if self.process.poll() is not None:
                print("The process has exited")
                self.in_socket.close()
                print(self.in_socket)
                # BUGFIX: the original called `signal(signal.SIGALRM)`,
                # but `signal` is a module and not callable (TypeError).
                # Deliver SIGALRM instead so signal_handler runs in the
                # main thread (Python executes signal handlers there).
                signal.raise_signal(signal.SIGALRM)
                # Also terminate this watcher thread
                raise ProcessReturned(
                    f"Process {self.process.pid} has returned with exit code {self.process.poll()}!"
                )
            time.sleep(interval)

    if self.process is None:
        raise RuntimeError("No process selected!")

    signal.signal(signal.SIGALRM, signal_handler)
    monitor = threading.Thread(target=check_process)
    monitor.start()
    try:
        yield
    finally:
        # Wait for the watcher; it ends once the process has exited
        monitor.join()
def _find_mpi_process(pid, mpi_program="mpirun", sparc_program="sparc"):
    """Recursively search children processes with PID=pid and return the one
    that mpirun (or synonyms) are the main command.

    If srun is found as the process, need to use `scancel` to pause / resume the job step
    """
    launcher_names = {"mpirun", "mpiexec", "orterun", "oshrun", "shmemrun"}
    sparc_names = {"sparc"}
    if mpi_program:
        launcher_names.add(mpi_program)
    if sparc_program:
        sparc_names.add(sparc_program)

    match = {"type": None, "process": None}
    try:
        root = psutil.Process(pid)
    except psutil.NoSuchProcess:
        warn(
            "Psutil cannot locate the pid. Your sparc program may have already exited."
        )
        return match

    candidates = []
    for proc in [root] + root.children(recursive=True):
        name = proc.name()
        if name == "srun":
            # Slurm job step: it must be controlled through scancel
            match["type"] = "slurm"
            match["process"] = _locate_slurm_step(program=sparc_program)
            break
        if name in launcher_names:
            # Keep launchers whose direct children are sparc binaries
            children = proc.children()
            if children and children[0].name() in sparc_names:
                candidates.append(proc)

    if len(candidates) > 1:
        warn(
            "More than 1 mpi processes are created. This may be a bug. I'll use the last one"
        )
    if candidates:
        match["type"] = "mpi"
        match["process"] = candidates[-1]
    return match
def _get_slurm_jobid():
    """Return the SLURM job id from the environment, or None if unset."""
    # Both spellings are exported depending on the SLURM setup
    for var in ("SLURM_JOB_ID", "SLURM_JOBID"):
        jobid = os.environ.get(var)
        if jobid is not None:
            return jobid
    return None
def _locate_slurm_step(program="sparc"):
    """If slurm job system is involved, search for the slurm step id
    that matches vasp_std (or other vasp commands)

    Steps:
    1. Look for SLURM_JOB_ID in current env
    2. Use `squeue` to locate the sparc step (latest)

    squeue
    """
    allowed_names = {"sparc"}
    if program:
        allowed_names.add(program)
    jobid = _get_slurm_jobid()
    if jobid is None:
        # TODO: convert warn to logger
        warn("Cannot locate the SLURM job id.")
        return None
    # Ask squeue for a 2-column listing (step id, step name) of this job
    cmds = ["squeue", "-s", "--job", str(jobid), "-o", "%.30i %.30j"]
    result = _run_process(cmds, capture_output=True)
    lines = result.stdout.decode("utf8").split("\n")
    candidates = []
    # Skip the header row; keep steps whose name mentions the program
    for line in lines[1:]:
        fields = line.strip().split()
        if len(fields) != 2:
            continue
        stepid, name = fields
        if any(v in name for v in allowed_names):
            candidates.append(stepid)

    if len(candidates) > 1:
        warn("More than 1 slurm steps are found. I'll use the most recent one")
    return candidates[0] if candidates else None
def _slurm_signal(stepid, sig=signal.SIGTSTP):
    """Send a signal to a slurm step via `scancel -s`.

    `sig` may be a signal name (str), a signal number (int), or a
    signal.Signals member; it is normalized to the name scancel expects.
    """
    if isinstance(sig, str):
        sig_name = str(sig)
    elif isinstance(sig, int):
        sig_name = signal.Signals(sig).name
    else:
        sig_name = sig.name
    cmds = ["scancel", "-s", sig_name, str(stepid)]
    proc = _run_process(cmds, capture_output=True)
    proc.stdout.decode("utf8").split("\n")
    return
def _run_process(commands, shell=False, print_cmd=True, cwd=".", capture_output=False):
    """Wrap around subprocess.run
    Returns the process object

    Args:
        commands: list of command tokens, or a full command string
            (a string is only meaningful with shell=True)
        shell: run the command through the shell
        print_cmd: echo the command line before running
        cwd: working directory for the child process
        capture_output: capture stdout/stderr on the returned object

    Raises:
        RuntimeError: when the command exits with a non-zero code
    """
    # BUGFIX: " ".join(commands) on a string joins its characters;
    # accept a pre-joined command string as-is.
    full_cmd = commands if isinstance(commands, str) else " ".join(commands)
    if print_cmd:
        print(full_cmd)
    if shell is False:
        proc = subprocess.run(
            commands, shell=shell, cwd=cwd, capture_output=capture_output
        )
    else:
        proc = subprocess.run(
            full_cmd, shell=shell, cwd=cwd, capture_output=capture_output
        )
    if proc.returncode == 0:
        return proc
    raise RuntimeError(f"Running {full_cmd} returned error code {proc.returncode}")