Coverage for sparc/utils.py: 41%

236 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-12 01:13 +0000

1"""Utilities that are loosely related to core sparc functionalities 

2""" 

3import _thread 

4import io 

5import os 

6import re 

7import shutil 

8import signal 

9import subprocess 

10import sys 

11import tempfile 

12import threading 

13import time 

14from contextlib import contextmanager 

15from pathlib import Path 

16from typing import List, Optional, Union 

17from warnings import warn 

18 

19import numpy as np 

20import psutil 

21 

22# 2024-11-28 @alchem0x2a add config 

23from ase.config import cfg as _cfg 

24 

25from .api import SparcAPI 

26from .docparser import SparcDocParser 

27 

28 

def deprecated(message):
    """Decorator factory that marks a function as deprecated.

    Parameters
    ----------
    message : str
        Extra text appended to the warning (e.g. the suggested replacement).

    Returns
    -------
    callable
        A decorator. The wrapped function emits a ``DeprecationWarning`` on
        every call and then delegates to the original function. Its name and
        docstring are preserved via ``functools.wraps``.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def new_func(*args, **kwargs):
            warn(
                "Function {} is deprecated! {}".format(func.__name__, message),
                category=DeprecationWarning,
                # Attribute the warning to the caller, not to this wrapper
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return new_func

    return decorator

41 

42 

def compare_dict(d1, d2):
    """Return True if two dictionaries have the same keys and equal values.

    Values are compared with ``!=`` and reduced with ``np.any``, so numpy
    arrays are compared elementwise. Shapes that broadcast (e.g. a scalar
    against an array) compare equal when every broadcast element matches.
    Values that cannot be compared elementwise, such as arrays whose shapes
    do not broadcast, are reported as different instead of raising.
    """
    # The symmetric difference is non-empty iff some key is not shared
    if set(d1) ^ set(d2):
        return False

    for key, value in d1.items():
        try:
            differs = np.any(value != d2[key])
        except ValueError:
            # NumPy >= 1.25 raises for shape-incompatible operands; such
            # values cannot be equal
            return False
        if differs:
            return False
    return True

55 

56 

def string2index(
    string: Union[str, int, slice, list]
) -> Union[int, slice, str, list]:
    """Convert an index string to an int or a slice.

    This method is a copy of ``ase.io.formats.string2index``, extended so
    that already-parsed indices (int, slice or list) pass through unchanged.

    Examples: ``"3"`` -> ``3``, ``"1:5:2"`` -> ``slice(1, 5, 2)``,
    ``"::-1"`` -> ``slice(None, None, -1)``. A string without ``":"`` that
    is not an integer (e.g. a database accessor) is returned as-is.
    """
    # Already-parsed indices need no conversion
    if isinstance(string, (int, list, slice)):
        return string
    if ":" not in string:
        # may contain database accessor
        try:
            return int(string)
        except ValueError:
            return string
    i: List[Optional[int]] = []
    for s in string.split(":"):
        if s == "":
            i.append(None)
        else:
            i.append(int(s))
    # Pad to (start, stop, step)
    i += (3 - len(i)) * [None]
    return slice(*i)

78 

79 

80def _find_default_sparc(): 

81 """Find the default sparc by $PATH and mpi location""" 

82 sparc_exe = shutil.which("sparc") 

83 

84 mpi_exe = shutil.which("mpirun") 

85 # TODO: more examples on pbs / lsf 

86 if mpi_exe is not None: 

87 try: 

88 num_cores = int( 

89 os.environ.get( 

90 "OMPI_COMM_WORLD_SIZE", 

91 os.environ.get( 

92 "OMPI_UNIVERSE_SIZE", 

93 os.environ.get("MPICH_RANK_REORDER_METHOD", ""), 

94 ).split(":")[-1], 

95 ) 

96 ) 

97 except Exception: 

98 num_cores = 1 

99 return sparc_exe, mpi_exe, num_cores 

100 

101 mpi_exe = shutil.which("srun") 

102 if mpi_exe is not None: 

103 # If srun is available, get the number of cores from the environment 

104 num_cores = int(os.environ.get("SLURM_JOB_CPUS_PER_NODE", 1)) 

105 return sparc_exe, mpi_exe, num_cores 

106 

107 return sparc_exe, None, 1 

108 

109 

def h2gpts(h, cell_cv, idiv=4):
    """Return the grid point count along each cell vector for spacing ``h``.

    ``h`` is the target grid spacing (Angstrom). ``cell_cv`` is a 2-D
    array-like whose rows are the cell vectors. Each count is
    ``ceil(|row| / h)``, but never smaller than ``idiv``. The result is a
    list of Python ints.
    """
    row_lengths = np.linalg.norm(np.array(cell_cv), axis=1)
    points = np.maximum(np.ceil(row_lengths / h), idiv)
    return list(map(int, points))

117 

118 

def cprint(content, color=None, bold=False, underline=False, **kwargs):
    """Color print wrapper for ansi terminal.

    Parameters
    ----------
    content : object
        Anything printable. It is converted with ``str()`` before the ANSI
        codes are added, so numbers, paths, etc. also work.
    color : str or None
        Case-insensitive color name: HEADER, COMMENT, OKBLUE, OKGREEN,
        OKCYAN, WARNING or FAIL. None prints without color.
    bold, underline : bool
        Wrap the (possibly colored) text in bold / underline codes, each
        followed by a reset code.
    **kwargs
        Forwarded to ``print`` (e.g. ``end``, ``file``, ``flush``).

    Raises
    ------
    ValueError
        If ``color`` is not one of the allowed names.
    """
    ansi_color = dict(
        HEADER="\033[95m",
        COMMENT="\033[90m",
        OKBLUE="\033[94m",
        OKGREEN="\033[92m",
        OKCYAN="\033[96m",
        WARNING="\033[93m",
        FAIL="\033[91m",
        ENDC="\033[0m",
    )
    style_codes = {"BOLD": "\033[1m", "UNDERLINE": "\033[4m"}
    reset = ansi_color["ENDC"]
    # ENDC is the reset code, not a selectable color
    allowed_colors = [name for name in ansi_color if name != "ENDC"]

    # print() would str() the object anyway; doing it here lets us concatenate
    output = str(content)

    if color is not None:
        color_key = color.upper()
        if color_key not in allowed_colors:
            raise ValueError(
                f"Unknown ANSI color name. Allowed values are {allowed_colors}"
            )
        output = ansi_color[color_key] + output + reset

    if bold:
        output = style_codes["BOLD"] + output + reset

    if underline:
        output = style_codes["UNDERLINE"] + output + reset

    print(output, **kwargs)
    return

153 

154 

def sanitize_path(path_string):
    """Turn a user-supplied path into an absolute, resolved ``Path``.

    String inputs first get ``~`` and ``$VARS`` expanded. Other path-like
    inputs are used verbatim. Either way the result is passed through
    ``Path.resolve()``.

    Use this on any path-like string coming from the cfg parser before
    using it.
    """
    candidate = path_string
    if isinstance(candidate, str):
        # expanduser first, then environment variables
        candidate = os.path.expandvars(os.path.expanduser(candidate))
    return Path(candidate).resolve()

168 

169 

def locate_api(json_file=None, doc_path=None, cfg=_cfg):
    """
    Locate the SPARC API setup file with the following priority:
    1) If `json_file` is provided (either from parameter or cfg), use it directly.
    2) If `doc_path` is provided:
       a) Function parameter takes precedence.
       b) Environment variable SPARC_DOC_PATH comes next.
       c) Configuration section [sparc] in the ini file is the last resort.
    3) If both `json_file` and `doc_path` are provided, raise an exception.
    4) Fallback to the default API setup if neither is provided.

    Parameters
    ----------
    json_file : str or path-like, optional
        Path to a JSON API schema. Falls back to ``json_schema`` in the
        ``[sparc]`` config section.
    doc_path : str or path-like, optional
        Directory of SPARC LaTeX documentation to parse into an API.
    cfg : ase config object
        Provides ``.parser`` (ini sections) and ``.get`` lookups.

    Returns
    -------
    SparcAPI

    Raises
    ------
    ValueError
        If both a json file and a documentation path end up set.
    FileNotFoundError
        If the chosen json file / documentation directory does not exist.
    RuntimeError
        If parsing the documentation directory fails.
    """
    # The section proxy's .get() returns None for missing keys, like {}.get()
    parser = cfg.parser["sparc"] if "sparc" in cfg.parser else {}
    if not json_file:
        json_file = parser.get("json_schema")

    # Environment variable SPARC_DOC_PATH can overwrite user settings
    if not doc_path:
        doc_path = cfg.get("SPARC_DOC_PATH")

    if not doc_path:
        doc_path = parser.get("doc_path")

    json_file = sanitize_path(json_file) if json_file else None
    doc_path = sanitize_path(doc_path) if doc_path else None

    # Step 3: json file and documentation path are mutually exclusive
    if json_file and doc_path:
        raise ValueError(
            "Cannot set both the path of json file and documentation "
            "at the same time!"
        )

    if json_file:
        if not json_file.is_file():
            raise FileNotFoundError(f"JSON file '{json_file}' does not exist.")
        return SparcAPI(json_file)

    if doc_path:
        if not doc_path.is_dir():
            raise FileNotFoundError(
                f"Documentation path '{doc_path}' does not exist or is not a directory."
            )
        try:
            # SparcAPI is constructed from a file, so the parsed JSON is
            # written to a throw-away location first
            with tempfile.TemporaryDirectory() as tmpdir:
                tmpfile = Path(tmpdir) / "parameters.json"
                with open(tmpfile, "w") as fd:
                    fd.write(
                        SparcDocParser.json_from_directory(
                            doc_path, include_subdirs=True
                        )
                    )
                api = SparcAPI(tmpfile)
            api.source = {"path": str(doc_path.resolve()), "type": "latex"}
            return api
        except Exception as e:
            raise RuntimeError(
                f"Failed to load API from documentation path '{doc_path}': {e}"
            ) from e

    # Step 4: fallback to default API
    return SparcAPI()

231 

232 

233# Utilities taken from vasp_interactive project 

234 

235 

class TimeoutException(Exception):
    """Raised by ``time_limit`` when the guarded code exceeds its time budget."""

240 

241 

@contextmanager
def time_limit(seconds):
    """Raise ``TimeoutException`` if the ``with`` body runs longer than ``seconds``.

    Usage::

        try:
            with time_limit(60):
                do_something()
        except TimeoutException:
            raise

    Parameters
    ----------
    seconds : int or float
        Wall-clock limit. Fractions of a second are honoured; 0 disables
        the timer.

    Notes
    -----
    Uses SIGALRM via ``ITIMER_REAL``, so it is POSIX-only and must run in
    the main thread (``signal.signal`` raises ValueError elsewhere). The
    SIGALRM handler that was installed before entering is restored on exit.
    """

    def signal_handler(signum, frame):
        raise TimeoutException("Timed out closing sparc process.")

    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    # setitimer (unlike signal.alarm) accepts fractional seconds
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        # Disarm before swapping the handler back so a late alarm cannot
        # reach the previous handler
        signal.setitimer(signal.ITIMER_REAL, 0)
        # None means the old handler was not installed from Python and
        # cannot be re-installed
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)

261 

262 

class ProcessReturned(Exception):
    """Raised by ``monitor_process`` when the monitored process has exited."""

267 

268 

@contextmanager
def monitor_process(self, interval=1.0):
    """Interrupt the ``with`` body with ``ProcessReturned`` if ``self.process`` exits.

    Usage::

        try:
            with monitor_process(obj):
                do_something()
        except ProcessReturned:
            raise

    Parameters
    ----------
    self : object
        Must expose ``process`` (needs ``.pid`` and ``.poll()``, e.g. a
        ``subprocess.Popen``). If it also has a non-None ``in_socket``, that
        socket is closed when the process exits. Presumably this is the
        calculator owning the sparc subprocess; confirm against callers.
    interval : float
        Seconds between polls of the process.

    Raises
    ------
    RuntimeError
        If ``self.process`` is None.
    ProcessReturned
        Raised in the main thread (through a SIGALRM handler) when the
        process exits while the context is active.

    Notes
    -----
    Installs a SIGALRM handler, so it must run in the main thread. The
    previous handler is restored on exit. The background poller stops as
    soon as the ``with`` body finishes.
    """
    if self.process is None:
        raise RuntimeError("No process selected!")

    def signal_handler(signum, frame):
        raise ProcessReturned(
            f"Process {self.process.pid} has returned with exit code {self.process.poll()}!"
        )

    stop = threading.Event()

    def check_process():
        # Poll until the process exits or the with-body finishes (stop is set)
        while not stop.is_set():
            if self.process.poll() is not None:
                print("The process has exited")
                in_socket = getattr(self, "in_socket", None)
                if in_socket is not None:
                    in_socket.close()
                # An exception raised in this thread would never reach the
                # caller. Instead send SIGALRM to our own process; Python runs
                # the handler in the main thread, which raises ProcessReturned.
                signal.raise_signal(signal.SIGALRM)
                return
            stop.wait(interval)

    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    monitor = threading.Thread(target=check_process, daemon=True)
    monitor.start()
    try:
        yield
    finally:
        stop.set()
        try:
            monitor.join()
        finally:
            # Restore even if the handler fires while joining
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)

307 

308 

def _find_mpi_process(pid, mpi_program="mpirun", sparc_program="sparc"):
    """Search process ``pid`` and all its descendants for the launcher running sparc.

    Parameters
    ----------
    pid : int
        Root process id to scan (itself plus recursive children).
    mpi_program : str, optional
        Extra launcher name accepted alongside mpirun/mpiexec/orterun/
        oshrun/shmemrun.
    sparc_program : str, optional
        Extra binary name accepted alongside ``sparc``.

    Returns
    -------
    dict
        ``{"type": ..., "process": ...}``:

        * ``type == "slurm"``: an ``srun`` process was found. ``process``
          is the step id from ``_locate_slurm_step`` (or None). If srun is
          found, use ``scancel`` to pause / resume the job step.
        * ``type == "mpi"``: ``process`` is the last launcher
          ``psutil.Process`` whose first direct child is a sparc binary.
        * both None: nothing found, or ``pid`` no longer exists.

    Processes that exit or are not inspectable while being scanned are
    skipped instead of aborting the search.
    """
    allowed_names = {"mpirun", "mpiexec", "orterun", "oshrun", "shmemrun"}
    allowed_sparc_names = {"sparc"}
    if mpi_program:
        allowed_names.add(mpi_program)
    if sparc_program:
        allowed_sparc_names.add(sparc_program)

    match = {"type": None, "process": None}
    try:
        root = psutil.Process(pid)
        process_list = [root] + root.children(recursive=True)
    except psutil.NoSuchProcess:
        warn(
            "Psutil cannot locate the pid. Your sparc program may have already exited."
        )
        return match

    mpi_candidates = []
    for proc in process_list:
        try:
            name = proc.name()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # Process vanished (or is not ours) between listing and inspection
            continue
        if name == "srun":
            match["type"] = "slurm"
            match["process"] = _locate_slurm_step(program=sparc_program)
            break
        if name in allowed_names:
            # are the mpi process's direct children sparc binaries?
            try:
                children = proc.children()
                launches_sparc = (
                    len(children) > 0 and children[0].name() in allowed_sparc_names
                )
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue
            if launches_sparc:
                mpi_candidates.append(proc)

    if len(mpi_candidates) > 1:
        warn(
            "More than 1 mpi processes are created. This may be a bug. I'll use the last one"
        )
    if len(mpi_candidates) > 0:
        match["type"] = "mpi"
        match["process"] = mpi_candidates[-1]

    return match

354 

355 

356def _get_slurm_jobid(): 

357 jobid = os.environ.get("SLURM_JOB_ID", None) 

358 if jobid is None: 

359 jobid = os.environ.get("SLURM_JOBID", None) 

360 return jobid 

361 

362 

def _locate_slurm_step(program="sparc"):
    """Return the id of the most recent SLURM job step running sparc.

    Steps:
    1. Look for SLURM_JOB_ID (or SLURM_JOBID) in the current env.
    2. Use ``squeue -s`` to list the steps of that job and keep those whose
       name contains ``"sparc"`` or ``program``.
    3. Pick the step with the highest step number (the part after the
       last "."), i.e. the latest one.

    Returns
    -------
    str or None
        Step id such as ``"12345.2"``, or None if there is no job id or no
        matching step.

    Raises
    ------
    RuntimeError
        Propagated from ``_run_process`` if ``squeue`` fails.
    """
    allowed_names = {"sparc"}
    if program:
        allowed_names.add(program)
    jobid = _get_slurm_jobid()
    if jobid is None:
        # TODO: convert warn to logger
        warn("Cannot locate the SLURM job id.")
        return None
    # Only 2 column output (step id and step name)
    cmds = ["squeue", "-s", "--job", str(jobid), "-o", "%.30i %.30j"]
    proc = _run_process(cmds, capture_output=True)
    output = proc.stdout.decode("utf8").split("\n")

    candidates = []
    for line in output[1:]:  # first line is the header
        fields = line.split()
        if len(fields) != 2:
            continue
        stepid, name = fields
        if any(v in name for v in allowed_names):
            candidates.append(stepid)

    if not candidates:
        return None
    if len(candidates) > 1:
        warn("More than 1 slurm steps are found. I'll use the most recent one")

    def step_number(stepid):
        # "<jobid>.<step>" -> step. Unparsable ids rank lowest, so when
        # nothing parses the first listed candidate is kept.
        _, dot, step = stepid.rpartition(".")
        return int(step) if dot and step.isdigit() else -1

    # squeue lists steps in ascending id order by default, so the latest
    # step is the one with the largest step number, not candidates[0]
    return max(candidates, key=step_number)

403 

404 

def _slurm_signal(stepid, sig=signal.SIGTSTP):
    """Send signal ``sig`` to SLURM job step ``stepid`` using ``scancel -s``.

    Parameters
    ----------
    stepid : str or int
        Job step id, e.g. ``"12345.0"``.
    sig : str, int or signal.Signals
        A signal name is passed through unchanged. A number (or a
        ``signal.Signals`` member, which is an int) is converted to its
        name, e.g. ``"SIGTSTP"``.

    Raises
    ------
    RuntimeError
        Propagated from ``_run_process`` if ``scancel`` fails.
    """
    if isinstance(sig, str):
        sig_name = sig
    elif isinstance(sig, int):
        sig_name = signal.Signals(sig).name
    else:
        sig_name = sig.name
    cmds = ["scancel", "-s", sig_name, str(stepid)]
    # Output is captured only to keep scancel quiet; it is not inspected
    _run_process(cmds, capture_output=True)
    return

416 

417 

418def _run_process(commands, shell=False, print_cmd=True, cwd=".", capture_output=False): 

419 """Wrap around subprocess.run 

420 Returns the process object 

421 """ 

422 full_cmd = " ".join(commands) 

423 if print_cmd: 

424 print(" ".join(commands)) 

425 if shell is False: 

426 proc = subprocess.run( 

427 commands, shell=shell, cwd=cwd, capture_output=capture_output 

428 ) 

429 else: 

430 proc = subprocess.run( 

431 full_cmd, shell=shell, cwd=cwd, capture_output=capture_output 

432 ) 

433 if proc.returncode == 0: 

434 return proc 

435 else: 

436 raise RuntimeError(f"Running {full_cmd} returned error code {proc.returncode}")