Source code for sparc.utils

"""Utilities that are loosely related to core sparc functionalities
"""
import _thread
import io
import os
import re
import shutil
import signal
import subprocess
import sys
import tempfile
import threading
import time
from contextlib import contextmanager
from pathlib import Path
from typing import List, Optional, Union
from warnings import warn

import numpy as np
import psutil

# 2024-11-28 @alchem0x2a add config
from ase.config import cfg as _cfg

from .api import SparcAPI
from .docparser import SparcDocParser


[docs] def deprecated(message): def decorator(func): def new_func(*args, **kwargs): warn( "Function {} is deprecated! {}".format(func.__name__, message), category=DeprecationWarning, ) return func(*args, **kwargs) return new_func return decorator
[docs] def compare_dict(d1, d2): """Helper function to compare dictionaries""" # Use symmetric difference to find keys which aren't shared # for python 2.7 compatibility if set(d1.keys()) ^ set(d2.keys()): return False # Check for differences in values for key, value in d1.items(): if np.any(value != d2[key]): return False return True
[docs] def string2index(string: str) -> Union[int, slice, str]: """Convert index string to either int or slice This method is a copy of ase.io.formats.string2index """ # A quick fix for slice if isinstance(string, (list, slice)): return string if ":" not in string: # may contain database accessor try: return int(string) except ValueError: return string i: List[Optional[int]] = [] for s in string.split(":"): if s == "": i.append(None) else: i.append(int(s)) i += (3 - len(i)) * [None] return slice(*i)
def _find_default_sparc(): """Find the default sparc by $PATH and mpi location""" sparc_exe = shutil.which("sparc") mpi_exe = shutil.which("mpirun") # TODO: more examples on pbs / lsf if mpi_exe is not None: try: num_cores = int( os.environ.get( "OMPI_COMM_WORLD_SIZE", os.environ.get( "OMPI_UNIVERSE_SIZE", os.environ.get("MPICH_RANK_REORDER_METHOD", ""), ).split(":")[-1], ) ) except Exception: num_cores = 1 return sparc_exe, mpi_exe, num_cores mpi_exe = shutil.which("srun") if mpi_exe is not None: # If srun is available, get the number of cores from the environment num_cores = int(os.environ.get("SLURM_JOB_CPUS_PER_NODE", 1)) return sparc_exe, mpi_exe, num_cores return sparc_exe, None, 1
[docs] def h2gpts(h, cell_cv, idiv=4): """Convert a h-parameter (Angstrom) to gpts""" cell_cv = np.array(cell_cv) cell_lengths = np.linalg.norm(cell_cv, axis=1) grid = np.ceil(cell_lengths / h) grid = np.maximum(idiv, grid) return [int(a) for a in grid]
[docs] def cprint(content, color=None, bold=False, underline=False, **kwargs): """Color print wrapper for ansi terminal. Only a few color names are provided """ ansi_color = dict( HEADER="\033[95m", COMMENT="\033[90m", OKBLUE="\033[94m", OKGREEN="\033[92m", OKCYAN="\033[96m", WARNING="\033[93m", FAIL="\033[91m", ENDC="\033[0m", ) style_codes = {"BOLD": "\033[1m", "UNDERLINE": "\033[4m"} if color is None: output = content elif color.upper() in ansi_color.keys() and color.upper() != "ENDC": output = ansi_color[color.upper()] + content + ansi_color["ENDC"] else: raise ValueError( f"Unknown ANSI color name. Allowed values are {list(ansi_color.keys())}" ) if bold: output = style_codes["BOLD"] + output + ansi_color["ENDC"] if underline: output = style_codes["UNDERLINE"] + output + ansi_color["ENDC"] print(output, **kwargs) return
[docs] def sanitize_path(path_string): """Sanitize path containing string in UNIX systems Returns a PosixPath object It is recommended to use this sanitize function before passing any path-like strings from cfg parser """ if isinstance(path_string, str): path = os.path.expandvars(os.path.expanduser(path_string)) path = Path(path).resolve() else: path = Path(path_string).resolve() return path
[docs] def locate_api(json_file=None, doc_path=None, cfg=_cfg): """ Locate the SPARC API setup file with the following priority: 1) If `json_file` is provided (either from parameter or cfg), use it directly. 2) If `doc_path` is provided: a) Function parameter takes precedence. b) Environment variable SPARC_DOC_PATH comes next. c) Configuration section [sparc] in the ini file is the last resort. 3) If both `json_file` and `doc_path` are provided, raise an exception. 4) Fallback to the default API setup if neither is provided. """ parser = cfg.parser["sparc"] if "sparc" in cfg.parser else {} if not json_file: json_file = parser.get("json_schema") if parser else None # Environment variable SPARC_DOC_PATH can overwrite user settings if not doc_path: doc_path = cfg.get("SPARC_DOC_PATH") if not doc_path: doc_path = parser.get("doc_path") if parser else None json_file = sanitize_path(json_file) if json_file else None doc_path = sanitize_path(doc_path) if doc_path else None # Step 4: Ensure mutual exclusivity if json_file and doc_path: raise ValueError( "Cannot set both the path of json file and documentation" "at the same time!" ) if json_file: if not json_file.is_file(): raise FileNotFoundError(f"JSON file '{json_file}' does not exist.") return SparcAPI(json_file) if doc_path: if not doc_path.is_dir(): raise FileNotFoundError( f"Documentation path '{doc_path}' does not exist or is not a directory." ) try: with tempfile.TemporaryDirectory() as tmpdir: tmpfile = Path(tmpdir) / "parameters.json" with open(tmpfile, "w") as fd: fd.write( SparcDocParser.json_from_directory( doc_path, include_subdirs=True ) ) api = SparcAPI(tmpfile) api.source = {"path": str(doc_path.resolve()), "type": "latex"} return api except Exception as e: raise RuntimeError( f"Failed to load API from documentation path '{doc_path}': {e}" ) # Fallback to default API return SparcAPI()
# Utilities taken from vasp_interactive project
[docs] class TimeoutException(Exception): """Simple class for timeout""" pass
[docs] @contextmanager def time_limit(seconds): """Usage: try: with time_limit(60): do_something() except TimeoutException: raise """ def signal_handler(signum, frame): raise TimeoutException("Timed out closing sparc process.") signal.signal(signal.SIGALRM, signal_handler) signal.alarm(seconds) try: yield finally: signal.alarm(0)
[docs] class ProcessReturned(Exception): """Simple class for process that has returned""" pass
[docs] @contextmanager def monitor_process(self, interval=1.0): """Usage: try: with monitor_process(process): do_something() except TimeoutException: raise """ def signal_handler(signum, frame): raise ProcessReturned( f"Process {self.process.pid} has returned with exit code {self.process.poll()}!" ) def check_process(): while True: if self.process.poll() is not None: # signal.alarm(0) print("The process has exited") self.in_socket.close() print(self.in_socket) signal(signal.SIGALRM) raise ProcessReturned( f"Process {self.process.pid} has returned with exit code {self.process.poll()}!" ) time.sleep(interval) if self.process is None: raise RuntimeError("No process selected!") signal.signal(signal.SIGALRM, signal_handler) monitor = threading.Thread(target=check_process) monitor.start() try: yield finally: monitor.join()
def _find_mpi_process(pid, mpi_program="mpirun", sparc_program="sparc"): """Recursively search children processes with PID=pid and return the one that mpirun (or synonyms) are the main command. If srun is found as the process, need to use `scancel` to pause / resume the job step """ allowed_names = set(["mpirun", "mpiexec", "orterun", "oshrun", "shmemrun"]) allowed_sparc_names = set(["sparc"]) if mpi_program: allowed_names.add(mpi_program) if sparc_program: allowed_sparc_names.add(sparc_program) try: process_list = [psutil.Process(pid)] except psutil.NoSuchProcess: warn( "Psutil cannot locate the pid. Your sparc program may have already exited." ) match = {"type": None, "process": None} return match process_list.extend(process_list[0].children(recursive=True)) mpi_candidates = [] match = {"type": None, "process": None} for proc in process_list: name = proc.name() if name in ["srun"]: match["type"] = "slurm" match["process"] = _locate_slurm_step(program=sparc_program) break elif proc.name() in allowed_names: # are the mpi process's direct children sparc binaries? children = proc.children() if len(children) > 0: if children[0].name() in allowed_sparc_names: mpi_candidates.append(proc) if len(mpi_candidates) > 1: warn( "More than 1 mpi processes are created. This may be a bug. I'll use the last one" ) if len(mpi_candidates) > 0: match["type"] = "mpi" match["process"] = mpi_candidates[-1] return match def _get_slurm_jobid(): jobid = os.environ.get("SLURM_JOB_ID", None) if jobid is None: jobid = os.environ.get("SLURM_JOBID", None) return jobid def _locate_slurm_step(program="sparc"): """If slurm job system is involved, search for the slurm step id that matches vasp_std (or other vasp commands) Steps: 1. Look for SLURM_JOB_ID in current env 2. Use `squeue` to locate the sparc step (latest) squeue """ allowed_names = set(["sparc"]) if program: allowed_names.add(program) jobid = _get_slurm_jobid() if jobid is None: # TODO: convert warn to logger warn(("Cannot locate the SLURM job id.")) return None # Only 2 column output (jobid and jobname) cmds = ["squeue", "-s", "--job", str(jobid), "-o", "%.30i %.30j"] proc = _run_process(cmds, capture_output=True) output = proc.stdout.decode("utf8").split("\n") # print(output) candidates = [] # breakpoint() for line in output[1:]: try: stepid, name = line.strip().split() except Exception: continue if any([v in name for v in allowed_names]): candidates.append(stepid) if len(candidates) > 1: warn("More than 1 slurm steps are found. I'll use the most recent one") if len(candidates) > 0: proc = candidates[0] else: proc = None return proc def _slurm_signal(stepid, sig=signal.SIGTSTP): if isinstance(sig, (str,)): sig = str(sig) elif isinstance(sig, (int,)): sig = signal.Signals(sig).name else: sig = sig.name cmds = ["scancel", "-s", sig, str(stepid)] proc = _run_process(cmds, capture_output=True) output = proc.stdout.decode("utf8").split("\n") return def _run_process(commands, shell=False, print_cmd=True, cwd=".", capture_output=False): """Wrap around subprocess.run Returns the process object """ full_cmd = " ".join(commands) if print_cmd: print(" ".join(commands)) if shell is False: proc = subprocess.run( commands, shell=shell, cwd=cwd, capture_output=capture_output ) else: proc = subprocess.run( full_cmd, shell=shell, cwd=cwd, capture_output=capture_output ) if proc.returncode == 0: return proc else: raise RuntimeError(f"Running {full_cmd} returned error code {proc.returncode}")