Coverage for sparc/utils.py: 44%

250 statements  

coverage.py v7.9.1, created at 2025-06-18 16:19 +0000

1"""Utilities that are loosely related to core sparc functionalities 

2""" 

3import _thread 

4import io 

5import os 

6import re 

7import shutil 

8import signal 

9import subprocess 

10import sys 

11import tempfile 

12import threading 

13import time 

14from contextlib import contextmanager 

15from pathlib import Path 

16from typing import List, Optional, Union 

17from warnings import warn 

18 

19import numpy as np 

20import psutil 

21 

22# 2024-11-28 @alchem0x2a add config 

23from ase.config import cfg as _cfg 

24from ase.units import Hartree 

25 

26from .api import SparcAPI 

27from .docparser import SparcDocParser 

28 

29 

def deprecated(message):
    """Decorator factory that marks a function as deprecated with a custom message"""

    def decorator(func):
        def new_func(*args, **kwargs):
            warn(
                "Function {} is deprecated! {}".format(func.__name__, message),
                category=DeprecationWarning,
            )
            return func(*args, **kwargs)

        return new_func

    return decorator

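
# Hedged usage sketch (not part of the original module): `deprecated` is meant
# to be used as a decorator factory, e.g.
#
#   @deprecated("Use compare_dict instead.")
#   def old_compare(d1, d2):
#       return compare_dict(d1, d2)
#
# Calling `old_compare(...)` then emits a DeprecationWarning before delegating
# to the wrapped function.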

def compare_dict(d1, d2):
    """Helper function to compare dictionaries"""
    # Use symmetric difference to find keys which aren't shared
    # for python 2.7 compatibility
    if set(d1.keys()) ^ set(d2.keys()):
        return False

    # Check for differences in values
    for key, value in d1.items():
        if np.any(value != d2[key]):
            return False
    return True

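
# Hedged example (illustrative values only): array-valued entries are compared
# element-wise via np.any, so
#
#   >>> compare_dict({"a": np.array([1, 2])}, {"a": np.array([1, 2])})
#   True
#   >>> compare_dict({"a": np.array([1, 2])}, {"a": np.array([1, 3])})
#   False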

def string2index(string: str) -> Union[int, slice, str]:
    """Convert index string to either int or slice
    This method is a copy of ase.io.formats.string2index
    """
    # A quick fix for slice
    if isinstance(string, (list, slice)):
        return string
    if ":" not in string:
        # may contain database accessor
        try:
            return int(string)
        except ValueError:
            return string
    i: List[Optional[int]] = []
    for s in string.split(":"):
        if s == "":
            i.append(None)
        else:
            i.append(int(s))
    i += (3 - len(i)) * [None]
    return slice(*i)

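
# Hedged examples (illustrative only): string2index mirrors the ASE accessor
# syntax, so
#
#   >>> string2index("3")
#   3
#   >>> string2index("1:5")
#   slice(1, 5, None)
#   >>> string2index("::2")
#   slice(None, None, 2)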

def _find_default_sparc():
    """Find the default sparc executable and MPI launcher from $PATH"""
    sparc_exe = shutil.which("sparc")

    mpi_exe = shutil.which("mpirun")
    # TODO: more examples on pbs / lsf
    if mpi_exe is not None:
        try:
            num_cores = int(
                os.environ.get(
                    "OMPI_COMM_WORLD_SIZE",
                    os.environ.get(
                        "OMPI_UNIVERSE_SIZE",
                        os.environ.get("MPICH_RANK_REORDER_METHOD", ""),
                    ).split(":")[-1],
                )
            )
        except Exception:
            num_cores = 1
        return sparc_exe, mpi_exe, num_cores

    mpi_exe = shutil.which("srun")
    if mpi_exe is not None:
        # If srun is available, get the number of cores from the environment
        num_cores = int(os.environ.get("SLURM_JOB_CPUS_PER_NODE", 1))
        return sparc_exe, mpi_exe, num_cores

    return sparc_exe, None, 1

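
# Hedged usage sketch: the helper returns a (sparc_path, mpi_path, num_cores)
# tuple; one plausible way to assemble a launch command from it would be
#
#   sparc_exe, mpi_exe, num_cores = _find_default_sparc()
#   if mpi_exe is not None:
#       command = f"{mpi_exe} -n {num_cores} {sparc_exe}"
#   else:
#       command = str(sparc_exe)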

def h2gpts(h, cell_cv, idiv=4):
    """Convert a h-parameter (Angstrom) to gpts"""
    cell_cv = np.array(cell_cv)
    cell_lengths = np.linalg.norm(cell_cv, axis=1)
    grid = np.ceil(cell_lengths / h)
    grid = np.maximum(idiv, grid)
    return [int(a) for a in grid]

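
# Hedged example (illustrative numbers): for a 10 Angstrom cubic cell and a
# target grid spacing of h = 0.25 Angstrom, each lattice vector needs
# ceil(10 / 0.25) = 40 points, already above the idiv=4 floor:
#
#   >>> h2gpts(0.25, np.eye(3) * 10.0)
#   [40, 40, 40]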

def parse_hubbard_string(hubbard_str):
    """setups: {element: hubbard_string} parser adapted from gpaw"""
    # Parse DFT+U parameters from a type-string:
    # Examples: "type:l,U" or "type:l,U,scale"
    # We rarely use the type since SPARC is a real-space code.
    # The input U values should be in eV.
    # Returns a 4-element U array (s, p, d, f channels) in Hartree.
    _, lus = hubbard_str.split(":")

    U = [0, 0, 0, 0]  # s, p, d, f
    for lu in lus.split(";"):  # Multiple U corrections
        # The scale field is parsed but currently not used
        l_, u_, scale_ = (lu + ",,").split(",")[:3]
        l_ind = "spdf".find(l_)
        if l_ind < 0:
            raise ValueError(f"Unknown angular momentum channel '{l_}' in Hubbard string.")
        if U[l_ind] != 0:
            raise ValueError("Bad HUBBARD U value formatting. Multiple keys?")
        U[l_ind] = float(u_) / Hartree  # eV --> Hartree
    return U

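
# Hedged example (hypothetical setup string): a GPAW-style entry such as
# ":d,4.0" requests U = 4.0 eV on the d channel, so
#
#   >>> parse_hubbard_string(":d,4.0")
#   [0, 0, 0.14699..., 0]    # 4.0 eV / Hartree (~27.211 eV) in the d slot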

def cprint(content, color=None, bold=False, underline=False, **kwargs):
    """Color print wrapper for ansi terminal.
    Only a few color names are provided
    """
    ansi_color = dict(
        HEADER="\033[95m",
        COMMENT="\033[90m",
        OKBLUE="\033[94m",
        OKGREEN="\033[92m",
        OKCYAN="\033[96m",
        WARNING="\033[93m",
        FAIL="\033[91m",
        ENDC="\033[0m",
    )

    style_codes = {"BOLD": "\033[1m", "UNDERLINE": "\033[4m"}

    if color is None:
        output = content
    elif color.upper() in ansi_color.keys() and color.upper() != "ENDC":
        output = ansi_color[color.upper()] + content + ansi_color["ENDC"]
    else:
        raise ValueError(
            f"Unknown ANSI color name. Allowed values are {list(ansi_color.keys())}"
        )

    if bold:
        output = style_codes["BOLD"] + output + ansi_color["ENDC"]

    if underline:
        output = style_codes["UNDERLINE"] + output + ansi_color["ENDC"]

    print(output, **kwargs)
    return

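
# Hedged usage sketch: color names are matched case-insensitively against the
# keys above ("ENDC" excluded), and extra keyword arguments are forwarded to
# print(), e.g.
#
#   cprint("SPARC calculation finished", color="okgreen", bold=True)
#   cprint("Missing pseudopotential!", color="fail", file=sys.stderr)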

def sanitize_path(path_string):
    """Sanitize a path-like string on UNIX systems and return a PosixPath object.

    It is recommended to use this sanitize function
    before passing any path-like strings from the cfg parser
    """
    if isinstance(path_string, str):
        path = os.path.expandvars(os.path.expanduser(path_string))
        path = Path(path).resolve()
    else:
        path = Path(path_string).resolve()
    return path

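
# Hedged example (hypothetical environment): both "~" and environment variables
# are expanded before resolving, so with HOME=/home/alice and
# SPARC_ROOT=/opt/sparc one would get roughly
#
#   sanitize_path("~/psp")            # PosixPath("/home/alice/psp")
#   sanitize_path("$SPARC_ROOT/lib")  # PosixPath("/opt/sparc/lib")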

def locate_api(json_file=None, doc_path=None, cfg=_cfg):
    """
    Locate the SPARC API setup file with the following priority:
    1) If `json_file` is provided (either from parameter or cfg), use it directly.
    2) If `doc_path` is provided:
       a) Function parameter takes precedence.
       b) Environment variable SPARC_DOC_PATH comes next.
       c) Configuration section [sparc] in the ini file is the last resort.
    3) If both `json_file` and `doc_path` are provided, raise an exception.
    4) Fall back to the default API setup if neither is provided.
    """
    parser = cfg.parser["sparc"] if "sparc" in cfg.parser else {}
    if not json_file:
        json_file = parser.get("json_schema") if parser else None

    # Environment variable SPARC_DOC_PATH can overwrite user settings
    if not doc_path:
        doc_path = cfg.get("SPARC_DOC_PATH")

    if not doc_path:
        doc_path = parser.get("doc_path") if parser else None

    json_file = sanitize_path(json_file) if json_file else None
    doc_path = sanitize_path(doc_path) if doc_path else None

    # Ensure mutual exclusivity (priority rule 3 above)
    if json_file and doc_path:
        raise ValueError(
            "Cannot set both the path of json file and documentation "
            "at the same time!"
        )

    if json_file:
        if not json_file.is_file():
            raise FileNotFoundError(f"JSON file '{json_file}' does not exist.")
        return SparcAPI(json_file)

    if doc_path:
        if not doc_path.is_dir():
            raise FileNotFoundError(
                f"Documentation path '{doc_path}' does not exist or is not a directory."
            )
        try:
            with tempfile.TemporaryDirectory() as tmpdir:
                tmpfile = Path(tmpdir) / "parameters.json"
                with open(tmpfile, "w") as fd:
                    fd.write(
                        SparcDocParser.json_from_directory(
                            doc_path, include_subdirs=True
                        )
                    )
                api = SparcAPI(tmpfile)
                api.source = {"path": str(doc_path.resolve()), "type": "latex"}
                return api
        except Exception as e:
            raise RuntimeError(
                f"Failed to load API from documentation path '{doc_path}': {e}"
            )

    # Fallback to default API
    return SparcAPI()

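
# Hedged usage sketch (paths are hypothetical): locate_api is typically driven
# by the ASE configuration, but it can also be called directly, e.g.
#
#   api = locate_api()                                       # default bundled API
#   api = locate_api(json_file="~/sparc/parameters.json")    # explicit JSON schema
#   api = locate_api(doc_path="~/SPARC/doc/.LaTeX")          # parse LaTeX docs
#
# Passing both json_file and doc_path raises ValueError.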

# Utilities taken from vasp_interactive project


class TimeoutException(Exception):
    """Simple class for timeout"""

    pass


@contextmanager
def time_limit(seconds):
    """Usage:
    try:
        with time_limit(60):
            do_something()
    except TimeoutException:
        raise
    """

    def signal_handler(signum, frame):
        raise TimeoutException("Timed out closing sparc process.")

    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)


class ProcessReturned(Exception):
    """Simple class for process that has returned"""

    pass


@contextmanager
def monitor_process(self, interval=1.0):
    """Usage:
    try:
        with monitor_process(process):
            do_something()
    except ProcessReturned:
        raise
    """

    def signal_handler(signum, frame):
        raise ProcessReturned(
            f"Process {self.process.pid} has returned with exit code {self.process.poll()}!"
        )

    def check_process():
        while True:
            if self.process.poll() is not None:
                # signal.alarm(0)
                print("The process has exited")
                self.in_socket.close()
                print(self.in_socket)
                # Deliver SIGALRM so the registered signal_handler interrupts the main thread
                signal.raise_signal(signal.SIGALRM)
                raise ProcessReturned(
                    f"Process {self.process.pid} has returned with exit code {self.process.poll()}!"
                )
            time.sleep(interval)

    if self.process is None:
        raise RuntimeError("No process selected!")

    signal.signal(signal.SIGALRM, signal_handler)
    monitor = threading.Thread(target=check_process)
    monitor.start()
    try:
        yield
    finally:
        monitor.join()


def _find_mpi_process(pid, mpi_program="mpirun", sparc_program="sparc"):
    """Recursively search the children processes of PID=pid and return the one
    where mpirun (or one of its synonyms) is the main command.

    If srun is found as the process, `scancel` must be used to pause / resume the job step
    """
    allowed_names = set(["mpirun", "mpiexec", "orterun", "oshrun", "shmemrun"])
    allowed_sparc_names = set(["sparc"])
    if mpi_program:
        allowed_names.add(mpi_program)
    if sparc_program:
        allowed_sparc_names.add(sparc_program)
    try:
        process_list = [psutil.Process(pid)]
    except psutil.NoSuchProcess:
        warn(
            "Psutil cannot locate the pid. Your sparc program may have already exited."
        )
        match = {"type": None, "process": None}
        return match

    process_list.extend(process_list[0].children(recursive=True))
    mpi_candidates = []
    match = {"type": None, "process": None}
    for proc in process_list:
        name = proc.name()
        if name in ["srun"]:
            match["type"] = "slurm"
            match["process"] = _locate_slurm_step(program=sparc_program)
            break
        elif proc.name() in allowed_names:
            # Are the mpi process's direct children sparc binaries?
            children = proc.children()
            if len(children) > 0:
                if children[0].name() in allowed_sparc_names:
                    mpi_candidates.append(proc)
    if len(mpi_candidates) > 1:
        warn(
            "More than 1 mpi processes are created. This may be a bug. I'll use the last one"
        )
    if len(mpi_candidates) > 0:
        match["type"] = "mpi"
        match["process"] = mpi_candidates[-1]

    return match

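
# Hedged usage sketch (calc_process_pid is a placeholder for the PID of the
# launched calculation): the returned dict distinguishes plain MPI runs from
# SLURM job steps, e.g.
#
#   match = _find_mpi_process(calc_process_pid)
#   if match["type"] == "mpi":
#       match["process"].suspend()   # psutil.Process: pause the mpirun tree
#   elif match["type"] == "slurm":
#       _slurm_signal(match["process"], signal.SIGTSTP)  # slurm step id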

def _get_slurm_jobid():
    jobid = os.environ.get("SLURM_JOB_ID", None)
    if jobid is None:
        jobid = os.environ.get("SLURM_JOBID", None)
    return jobid


def _locate_slurm_step(program="sparc"):
    """If the slurm job system is involved, search for the slurm step id
    that matches sparc (or another user-specified program)

    Steps:
    1. Look for SLURM_JOB_ID in current env
    2. Use `squeue` to locate the sparc step (latest)
    """
    allowed_names = set(["sparc"])
    if program:
        allowed_names.add(program)
    jobid = _get_slurm_jobid()
    if jobid is None:
        # TODO: convert warn to logger
        warn("Cannot locate the SLURM job id.")
        return None
    # Only 2 column output (jobid and jobname)
    cmds = ["squeue", "-s", "--job", str(jobid), "-o", "%.30i %.30j"]
    proc = _run_process(cmds, capture_output=True)
    output = proc.stdout.decode("utf8").split("\n")
    # print(output)
    candidates = []
    # breakpoint()
    for line in output[1:]:
        try:
            stepid, name = line.strip().split()
        except Exception:
            continue
        if any([v in name for v in allowed_names]):
            candidates.append(stepid)

    if len(candidates) > 1:
        warn("More than 1 slurm steps are found. I'll use the most recent one")
    if len(candidates) > 0:
        proc = candidates[0]
    else:
        proc = None
    return proc

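
# Hedged illustration (made-up job/step ids; exact header text depends on the
# slurm version): with the "%.30i %.30j" format the parsed `squeue -s` output
# looks roughly like
#
#   STEPID                          NAME
#   1234567.0                       sparc
#
# so _locate_slurm_step would return "1234567.0" here.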

def _slurm_signal(stepid, sig=signal.SIGTSTP):
    if isinstance(sig, (str,)):
        sig = str(sig)
    elif isinstance(sig, (int,)):
        sig = signal.Signals(sig).name
    else:
        sig = sig.name
    cmds = ["scancel", "-s", sig, str(stepid)]
    proc = _run_process(cmds, capture_output=True)
    output = proc.stdout.decode("utf8").split("\n")
    return

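
# Hedged example (made-up step id): signals may be given as names, numbers, or
# signal.Signals members; all three calls below end up running
# "scancel -s SIGTSTP 1234567.0" (use SIGCONT analogously to resume):
#
#   _slurm_signal("1234567.0", "SIGTSTP")
#   _slurm_signal("1234567.0", signal.SIGTSTP.value)
#   _slurm_signal("1234567.0", signal.SIGTSTP)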

def _run_process(commands, shell=False, print_cmd=True, cwd=".", capture_output=False):
    """Wrap around subprocess.run
    Returns the process object
    """
    full_cmd = " ".join(commands)
    if print_cmd:
        print(" ".join(commands))
    if shell is False:
        proc = subprocess.run(
            commands, shell=shell, cwd=cwd, capture_output=capture_output
        )
    else:
        proc = subprocess.run(
            full_cmd, shell=shell, cwd=cwd, capture_output=capture_output
        )
    if proc.returncode == 0:
        return proc
    else:
        raise RuntimeError(f"Running {full_cmd} returned error code {proc.returncode}")
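
# Hedged usage sketch: _run_process raises on non-zero exit codes, so callers
# only need to handle the success path, e.g.
#
#   proc = _run_process(["echo", "hello"], capture_output=True)
#   print(proc.stdout.decode("utf8"))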