Coverage for sparc/utils.py: 41%
236 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-12 01:13 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-12 01:13 +0000
1"""Utilities that are loosely related to core sparc functionalities
2"""
3import _thread
4import io
5import os
6import re
7import shutil
8import signal
9import subprocess
10import sys
11import tempfile
12import threading
13import time
14from contextlib import contextmanager
15from pathlib import Path
16from typing import List, Optional, Union
17from warnings import warn
19import numpy as np
20import psutil
22# 2024-11-28 @alchem0x2a add config
23from ase.config import cfg as _cfg
25from .api import SparcAPI
26from .docparser import SparcDocParser
def deprecated(message):
    """Decorator factory that marks a function as deprecated.

    Every call to the decorated function emits a ``DeprecationWarning``
    reading ``"Function <name> is deprecated! <message>"`` and then calls the
    original function with the same arguments, returning its result.

    Args:
        message (str): extra text appended to the warning (e.g. what to use
            instead).

    Returns:
        callable: a decorator that wraps a function.
    """
    import functools

    def decorator(func):
        # Keep __name__/__doc__ of the wrapped function.
        @functools.wraps(func)
        def new_func(*args, **kwargs):
            warn(
                "Function {} is deprecated! {}".format(func.__name__, message),
                category=DeprecationWarning,
                # Attribute the warning to the caller, not to this wrapper.
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return new_func

    return decorator
def compare_dict(d1, d2):
    """Return True if two dictionaries have the same keys and equal values.

    Values are compared with ``!=`` and reduced with ``np.any`` so numpy
    arrays are compared elementwise. Values whose comparison cannot be
    broadcast (e.g. arrays of different shapes) are treated as different
    instead of raising.

    Args:
        d1 (dict): first dictionary.
        d2 (dict): second dictionary.

    Returns:
        bool: True when keys and all values match.
    """
    # Symmetric difference: any key present in only one of the dicts.
    if set(d1.keys()) ^ set(d2.keys()):
        return False

    for key, value in d1.items():
        try:
            differs = np.any(value != d2[key])
        except ValueError:
            # Recent numpy raises on shape-mismatched elementwise comparison
            # (older versions returned a scalar True); either way they differ.
            return False
        if differs:
            return False
    return True
def string2index(string: str) -> Union[int, slice, str]:
    """Turn an index specification string into an int, a slice or itself.

    Adapted from ``ase.io.formats.string2index``.

    * A list or ``slice`` is returned unchanged.
    * A string without ``:`` becomes an ``int`` if it parses as one;
      otherwise it is returned as-is (it may be a database accessor).
    * A string with ``:`` becomes a ``slice``; empty fields map to ``None``
      and missing trailing fields are padded with ``None``.
    """
    if isinstance(string, (list, slice)):
        return string

    if ":" in string:
        parts: List[Optional[int]] = [
            None if field == "" else int(field) for field in string.split(":")
        ]
        parts.extend([None] * (3 - len(parts)))
        return slice(*parts)

    try:
        return int(string)
    except ValueError:
        # Not numeric: leave it to the caller (e.g. database accessor).
        return string
def _find_default_sparc():
    """Find the default sparc executable, MPI launcher and core count.

    ``sparc`` is looked up on ``$PATH``. The launcher is ``mpirun`` if it is
    on ``$PATH``, otherwise ``srun``, otherwise none. The core count is read
    from launcher-specific environment variables and falls back to 1 when
    they are unset or cannot be parsed.

    Returns:
        tuple: ``(sparc_exe, mpi_exe, num_cores)``; the two executables are
        paths or ``None`` when not found, ``num_cores`` is an ``int``.
    """
    sparc_exe = shutil.which("sparc")

    mpi_exe = shutil.which("mpirun")
    # TODO: more examples on pbs / lsf
    if mpi_exe is not None:
        # NOTE(review): the last fallback, MPICH_RANK_REORDER_METHOD, does not
        # obviously hold a process count -- confirm it belongs in this chain.
        raw_cores = os.environ.get(
            "OMPI_COMM_WORLD_SIZE",
            os.environ.get(
                "OMPI_UNIVERSE_SIZE",
                os.environ.get("MPICH_RANK_REORDER_METHOD", ""),
            ).split(":")[-1],
        )
        try:
            num_cores = int(raw_cores)
        except ValueError:
            num_cores = 1
        return sparc_exe, mpi_exe, num_cores

    mpi_exe = shutil.which("srun")
    if mpi_exe is not None:
        # SLURM_JOB_CPUS_PER_NODE may be "16", "16(x2)" or "72(x2),36".
        # Use the leading count instead of letting int() crash on the
        # compressed multi-node form; unset/unparsable falls back to 1.
        cpus = re.match(r"\s*(\d+)", os.environ.get("SLURM_JOB_CPUS_PER_NODE", ""))
        num_cores = int(cpus.group(1)) if cpus else 1
        return sparc_exe, mpi_exe, num_cores

    return sparc_exe, None, 1
def h2gpts(h, cell_cv, idiv=4):
    """Convert a grid spacing ``h`` (Angstrom) into grid point counts.

    For each cell vector, the number of points is ``ceil(|a_i| / h)``,
    clamped from below to ``idiv``.

    Returns:
        list[int]: one grid count per cell vector.
    """
    lengths = np.linalg.norm(np.array(cell_cv), axis=1)
    npoints = np.maximum(idiv, np.ceil(lengths / h))
    return list(map(int, npoints))
def cprint(content, color=None, bold=False, underline=False, **kwargs):
    """Print ``content`` wrapped in ANSI color / style escape codes.

    Only a few color names are provided.

    Args:
        content: object to print; when a color or style is applied it is
            formatted with ``str()`` (so non-string objects work too).
        color (str, optional): case-insensitive name among HEADER, COMMENT,
            OKBLUE, OKGREEN, OKCYAN, WARNING, FAIL. ``None`` means no color.
        bold (bool): additionally wrap the output in the bold code.
        underline (bool): additionally wrap the output in the underline code.
        **kwargs: forwarded to ``print`` (e.g. ``end``, ``file``).

    Raises:
        ValueError: if ``color`` is not one of the names above.
    """
    ansi_color = dict(
        HEADER="\033[95m",
        COMMENT="\033[90m",
        OKBLUE="\033[94m",
        OKGREEN="\033[92m",
        OKCYAN="\033[96m",
        WARNING="\033[93m",
        FAIL="\033[91m",
        ENDC="\033[0m",
    )
    style_codes = {"BOLD": "\033[1m", "UNDERLINE": "\033[4m"}
    reset = ansi_color["ENDC"]

    output = content
    if color is not None:
        key = color.upper()
        # ENDC is the reset code, not a selectable color.
        if key not in ansi_color or key == "ENDC":
            allowed = [name for name in ansi_color if name != "ENDC"]
            raise ValueError(f"Unknown ANSI color name. Allowed values are {allowed}")
        # f-strings (not +) so non-string content can be colored.
        output = f"{ansi_color[key]}{output}{reset}"

    if bold:
        output = f"{style_codes['BOLD']}{output}{reset}"

    if underline:
        output = f"{style_codes['UNDERLINE']}{output}{reset}"

    print(output, **kwargs)
def sanitize_path(path_string):
    """Normalize a user-supplied path into an absolute ``Path``.

    For ``str`` input, ``~`` and ``$VARIABLES`` are expanded first; any
    other path-like input is used as-is. The result is always ``resolve()``-d.

    Use this on path-like strings read from the config parser before
    handing them to other code.
    """
    target = path_string
    if isinstance(target, str):
        target = os.path.expandvars(os.path.expanduser(target))
    return Path(target).resolve()
def locate_api(json_file=None, doc_path=None, cfg=_cfg):
    """Locate and load the SPARC API setup.

    Sources are resolved with the following priority:

    1) ``json_file``: the function argument, else ``json_schema`` from the
       ``[sparc]`` section of the ini configuration.
    2) ``doc_path`` (a documentation directory to be parsed):
       a) the function argument;
       b) the ``SPARC_DOC_PATH`` entry from ``cfg.get`` (environment variable);
       c) ``doc_path`` from the ``[sparc]`` section of the ini file.
    3) If both a json file and a documentation path end up set, raise.
    4) If neither is set, fall back to the default ``SparcAPI()``.

    Both paths are passed through ``sanitize_path`` (``~``/``$VAR`` expansion
    and ``resolve()``).

    Raises:
        ValueError: both a json file and a documentation path are set.
        FileNotFoundError: the json file is not a file, or the documentation
            path is not a directory.
        RuntimeError: building the API from the documentation path failed.
    """
    parser = cfg.parser["sparc"] if "sparc" in cfg.parser else {}
    if not json_file:
        json_file = parser.get("json_schema") if parser else None

    # Environment variable SPARC_DOC_PATH can overwrite user settings
    if not doc_path:
        doc_path = cfg.get("SPARC_DOC_PATH")

    if not doc_path:
        doc_path = parser.get("doc_path") if parser else None

    json_file = sanitize_path(json_file) if json_file else None
    doc_path = sanitize_path(doc_path) if doc_path else None

    # Step 3: ensure mutual exclusivity
    if json_file and doc_path:
        raise ValueError(
            "Cannot set both the path of json file and documentation "
            "at the same time!"
        )

    # Step 1: explicit json file
    if json_file:
        if not json_file.is_file():
            raise FileNotFoundError(f"JSON file '{json_file}' does not exist.")
        return SparcAPI(json_file)

    # Step 2: parse the documentation directory into a temporary json file
    if doc_path:
        if not doc_path.is_dir():
            raise FileNotFoundError(
                f"Documentation path '{doc_path}' does not exist or is not a directory."
            )
        try:
            with tempfile.TemporaryDirectory() as tmpdir:
                tmpfile = Path(tmpdir) / "parameters.json"
                with open(tmpfile, "w") as fd:
                    fd.write(
                        SparcDocParser.json_from_directory(
                            doc_path, include_subdirs=True
                        )
                    )
                api = SparcAPI(tmpfile)
                api.source = {"path": str(doc_path.resolve()), "type": "latex"}
            return api
        except Exception as e:
            # Keep the original traceback attached for debugging.
            raise RuntimeError(
                f"Failed to load API from documentation path '{doc_path}': {e}"
            ) from e

    # Step 4: fallback to default API
    return SparcAPI()
233# Utilities taken from vasp_interactive project
class TimeoutException(Exception):
    """Signals that a time limit expired (raised by the ``time_limit`` SIGALRM handler)."""
@contextmanager
def time_limit(seconds):
    """Raise ``TimeoutException`` if the ``with`` body exceeds ``seconds``.

    Uses SIGALRM, so it only works in the main thread on POSIX systems and
    ``seconds`` must be an int (``signal.alarm`` granularity). The alarm is
    cancelled and the previously installed SIGALRM handler is restored when
    the block exits.

    Usage:
        try:
            with time_limit(60):
                do_something()
        except TimeoutException:
            raise
    """

    def signal_handler(signum, frame):
        raise TimeoutException("Timed out closing sparc process.")

    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    try:
        signal.alarm(seconds)
        yield
    finally:
        signal.alarm(0)
        # getsignal/signal return None for handlers not installed from
        # Python; such a handler cannot be reinstalled from here.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
class ProcessReturned(Exception):
    """Signals that a monitored process has already exited (see ``monitor_process``)."""
@contextmanager
def monitor_process(self, interval=1.0):
    """Raise ``ProcessReturned`` in the main thread if ``self.process`` exits.

    While the ``with`` body runs, a background thread polls
    ``self.process.poll()`` every ``interval`` seconds. Once the process has
    exited it closes ``self.in_socket`` (if set) and sends SIGALRM to the
    main thread, whose handler raises ``ProcessReturned``. When the body
    finishes, the poller is stopped (without waiting for the process) and
    the previous SIGALRM handler is restored. POSIX / main thread only.

    Usage:
        try:
            with monitor_process(obj):  # obj has .process and .in_socket
                do_something()
        except ProcessReturned:
            raise

    Args:
        self: object exposing ``process`` (with ``pid`` and ``poll()``) and
            optionally ``in_socket`` (closed when the process exits).
        interval (float): polling period in seconds.

    Raises:
        RuntimeError: if ``self.process`` is None.
        ProcessReturned: from the SIGALRM handler once the process exited.
    """
    if self.process is None:
        raise RuntimeError("No process selected!")

    def signal_handler(signum, frame):
        raise ProcessReturned(
            f"Process {self.process.pid} has returned with exit code {self.process.poll()}!"
        )

    stop = threading.Event()
    main_ident = threading.main_thread().ident

    def check_process():
        while True:
            if self.process.poll() is not None:
                print("The process has exited")
                if getattr(self, "in_socket", None) is not None:
                    self.in_socket.close()
                # An exception raised in this thread would never reach the
                # caller; deliver SIGALRM to the main thread instead so that
                # signal_handler raises ProcessReturned there (and blocking
                # syscalls in the main thread get interrupted).
                signal.pthread_kill(main_ident, signal.SIGALRM)
                return
            # Interruptible sleep: returns True as soon as the body ends.
            if stop.wait(interval):
                return

    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    monitor = threading.Thread(target=check_process, daemon=True)
    monitor.start()
    try:
        yield
    finally:
        try:
            stop.set()
            monitor.join()
        finally:
            # None means the old handler was not installed from Python.
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)
def _find_mpi_process(pid, mpi_program="mpirun", sparc_program="sparc"):
    """Search ``pid`` and its descendants for the launcher running sparc.

    Two kinds of match are reported:

    * an ``srun`` process anywhere in the tree ->
      ``{"type": "slurm", "process": <step id from _locate_slurm_step>}``;
      such a job step has to be paused / resumed with ``scancel``;
    * an MPI launcher (``mpirun``, ``mpiexec``, ``orterun``, ``oshrun``,
      ``shmemrun`` or ``mpi_program``) whose first direct child is named
      ``sparc`` or ``sparc_program`` ->
      ``{"type": "mpi", "process": <psutil.Process>}``. If several qualify
      the last one found is used (with a warning), and an MPI match takes
      precedence over an srun match.

    Processes that disappear or cannot be inspected during the scan are
    skipped. Returns ``{"type": None, "process": None}`` when nothing
    matches or ``pid`` no longer exists.
    """
    allowed_names = {"mpirun", "mpiexec", "orterun", "oshrun", "shmemrun"}
    allowed_sparc_names = {"sparc"}
    if mpi_program:
        allowed_names.add(mpi_program)
    if sparc_program:
        allowed_sparc_names.add(sparc_program)

    match = {"type": None, "process": None}
    try:
        root = psutil.Process(pid)
        # The root may exit between lookup and listing children.
        process_list = [root] + root.children(recursive=True)
    except psutil.NoSuchProcess:
        warn(
            "Psutil cannot locate the pid. Your sparc program may have already exited."
        )
        return match

    mpi_candidates = []
    for proc in process_list:
        try:
            name = proc.name()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # Vanished or unreadable process: it cannot be matched.
            continue
        if name == "srun":
            match["type"] = "slurm"
            match["process"] = _locate_slurm_step(program=sparc_program)
            break
        if name in allowed_names:
            # Is the mpi process's direct child a sparc binary?
            try:
                children = proc.children()
                first_child = children[0].name() if children else None
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue
            if first_child in allowed_sparc_names:
                mpi_candidates.append(proc)

    if len(mpi_candidates) > 1:
        warn(
            "More than 1 mpi processes are created. This may be a bug. I'll use the last one"
        )
    if mpi_candidates:
        match["type"] = "mpi"
        match["process"] = mpi_candidates[-1]

    return match
def _get_slurm_jobid():
    """Return the current SLURM job id as a string, or None.

    ``SLURM_JOB_ID`` is preferred; the legacy ``SLURM_JOBID`` is the fallback.
    """
    for var_name in ("SLURM_JOB_ID", "SLURM_JOBID"):
        job_id = os.environ.get(var_name)
        if job_id is not None:
            return job_id
    return None
def _locate_slurm_step(program="sparc"):
    """Return the id of the SLURM job step running sparc, or None.

    Steps:
    1. Read the job id from ``SLURM_JOB_ID`` / ``SLURM_JOBID``.
    2. List that job's steps with ``squeue -s`` (two columns: id and name,
       each up to 30 characters) and keep the steps whose name contains
       ``sparc`` or ``program``.

    Returns:
        str or None: the first matching step id in the squeue output, or
        None when no job id or no matching step is found.

    Raises:
        RuntimeError: propagated from ``_run_process`` if squeue fails.
    """
    allowed_names = {"sparc"}
    if program:
        allowed_names.add(program)
    jobid = _get_slurm_jobid()
    if jobid is None:
        # TODO: convert warn to logger
        warn("Cannot locate the SLURM job id.")
        return None
    # Only 2 column output (step id and step name)
    cmds = ["squeue", "-s", "--job", str(jobid), "-o", "%.30i %.30j"]
    result = _run_process(cmds, capture_output=True)
    output = result.stdout.decode("utf8").split("\n")

    candidates = []
    # output[0] is the squeue header line.
    for line in output[1:]:
        fields = line.split()
        if len(fields) != 2:
            # Blank or malformed line.
            continue
        stepid, name = fields
        if any(v in name for v in allowed_names):
            candidates.append(stepid)

    if len(candidates) > 1:
        # NOTE(review): assumes squeue lists the most recent step first -- confirm.
        warn("More than 1 slurm steps are found. I'll use the most recent one")
    return candidates[0] if candidates else None
def _slurm_signal(stepid, sig=signal.SIGTSTP):
    """Send ``sig`` to SLURM job step ``stepid`` with ``scancel -s``.

    ``sig`` may be a signal name (``str``, passed through unchanged), a
    signal number or ``signal.Signals`` member (converted to its name, e.g.
    ``"SIGTSTP"``), or any other object with a ``.name`` attribute.

    Raises:
        RuntimeError: propagated from ``_run_process`` if scancel fails.
    """
    if isinstance(sig, str):
        sig_name = sig
    elif isinstance(sig, int):
        sig_name = signal.Signals(sig).name
    else:
        sig_name = sig.name
    _run_process(["scancel", "-s", sig_name, str(stepid)], capture_output=True)
def _run_process(commands, shell=False, print_cmd=True, cwd=".", capture_output=False):
    """Run a command with ``subprocess.run`` and return the completed process.

    Args:
        commands: argument list, or a single command string.
        shell (bool): if true, run the command as one space-joined string
            through the shell; otherwise pass ``commands`` directly.
        print_cmd (bool): echo the command line before running it.
        cwd (str): working directory.
        capture_output (bool): capture stdout/stderr (as bytes) on the
            returned object.

    Returns:
        subprocess.CompletedProcess: the finished process (exit code 0).

    Raises:
        RuntimeError: if the command exits with a non-zero code; captured
            stderr, if any, is appended to the message.
    """
    if isinstance(commands, str):
        # Joining a str would space out its characters ("l s").
        full_cmd = commands
    else:
        full_cmd = " ".join(str(part) for part in commands)
    if print_cmd:
        print(full_cmd)
    proc = subprocess.run(
        full_cmd if shell else commands,
        shell=shell,
        cwd=cwd,
        capture_output=capture_output,
    )
    if proc.returncode != 0:
        message = f"Running {full_cmd} returned error code {proc.returncode}"
        if capture_output and proc.stderr:
            message += f": {proc.stderr.decode('utf8', errors='replace').strip()}"
        raise RuntimeError(message)
    return proc