Coverage for sparc/quicktest.py: 80%

216 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-12 01:13 +0000

1"""A simple test module for sparc python api 

2Usage: 

3python -m sparc.quicktest 

4""" 

5from pathlib import Path 

6 

7from ase.data import chemical_symbols 

8 

9from .utils import cprint 

10 

11 

class BaseTest(object):
    """Base class for all tests providing functionalities

    Each child class implements its own `make_test` method to update
    the `result`, `error_msg` and `info` fields; `run_test` wraps it and
    fills in `error_handling` when the test fails.

    If you wish to include a simple error handling message for each
    child class, add a line starting with `Error handling` followed by
    the helper message at the end of the docstring.
    """

    def __init__(self):
        self.result = None  # set to True / False by make_test (or run_test)
        self.error_msg = ""  # short description of what went wrong
        self.error_handling = ""  # hints extracted from the class docstring
        self.info = {}  # configuration details reported in the summary

    @property
    def display_name(self):
        """Human-readable name of the test.

        Subclasses normally shadow this with a plain class attribute;
        the fallback is the class name.
        """
        return type(self).__name__

    @property
    def dislay_name(self):
        """Misspelled alias of `display_name`, kept for backward compatibility."""
        return self.display_name

    def display_docstring(self):
        """Extract the "Error handling" section of the class docstring.

        Every non-empty line after the single line starting with
        "Error handling" is stripped and re-indented by however far it was
        indented beyond that header line.

        Returns:
            str: the collected lines joined by newlines, or "" when the
            class has no docstring or no such section.

        Raises:
            ValueError: if the docstring has more than one "Error handling" line.
        """
        doc = type(self).__doc__
        if not doc:
            return ""

        error_handling_lines = []
        header_indent = None  # indentation of the header; None until it is seen
        for line in doc.splitlines():
            stripped = line.strip()
            current_indent = len(line) - len(line.lstrip())
            if stripped.startswith("Error handling"):
                if header_indent is not None:
                    msg = (
                        "There are multiple Error handlings "
                        "in the docstring of "
                        f"{self.__class__.__name__}."
                    )
                    raise ValueError(msg)
                header_indent = current_indent
            elif header_indent is not None and stripped:
                # Preserve nesting (e.g. wrapped bullet points) relative
                # to the header line; blank lines are dropped.
                spaces = max(0, current_indent - header_indent) * " "
                error_handling_lines.append(spaces + stripped)
        return "\n".join(error_handling_lines)

    def make_test(self):
        """Each class should implement ways to update `result` and `info`

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def run_test(self):
        """Run `make_test` and record the outcome.

        Any exception raised by `make_test` marks the test as failed and
        its text becomes `error_msg`. On failure, `error_handling` is filled
        from the class docstring.

        Raises:
            ValueError: if `make_test` returned without setting `result`.
        """
        try:
            self.make_test()
        except Exception as e:
            self.result = False
            # Some exceptions carry no message; fall back to the repr so the
            # report never shows an empty reason.
            self.error_msg = str(e) or repr(e)

        if self.result is None:
            raise ValueError(
                "Test result is not updated for " f"{self.__class__.__name__} !"
            )
        if self.result is False:
            self.error_handling = self.display_docstring()

88 

89 

class ImportTest(BaseTest):
    """Check if external io format `sparc` can be registered in ASE

    Error handling:
    - Make sure SPARC-X-API is installed via conda / pip / setuptools
    - If you wish to work on SPARC-X-API source code, use `pip install -e`
      instead of setting up $PYTHON_PATH
    """

    display_name = "Import"

    def make_test(self):
        # Pass iff `sparc` is a key of ASE's io-format registry.
        cprint("Testing import...", color="COMMENT")
        from ase.io.formats import ioformats

        found = "sparc" in ioformats.keys()
        self.result = found
        if not found:
            self.error_msg = (
                "Cannot find `sparc` as a valid " "external ioformat for ASE."
            )

111 

112 

class PspTest(BaseTest):
    """Check at least one directory of Pseudopotential files exist
    info[`psp_dir`] contains the first psp dir found on system
    # TODO: check if all psp files can be located
    #TODO: update to the ASE 3.23 config method

    Error handling:
    - Default version of psp files can be downloaded by
      `python -m sparc.download_data`
    - Alternatively, specify the variable $SPARC_PSP_PATH
      to the custom pseudopotential files
    """

    display_name = "Pseudopotential"

    def make_test(self):
        cprint("Testing pseudo potential path...", color="COMMENT")
        import tempfile

        from .io import SparcBundle
        from .sparc_parsers.pseudopotential import find_pseudo_path

        with tempfile.TemporaryDirectory() as scratch_dir:
            # Ask a SparcBundle built on a throwaway directory which psp
            # directory it resolves to.
            bundle = SparcBundle(directory=scratch_dir)
            found_dir = bundle.psp_dir

            # Guard: no psp directory could be resolved at all.
            if found_dir is None:
                self.info["psp_dir"] = "None"
                self.result = False
                self.error_msg = (
                    "Pseudopotential file path not defined and/or "
                    "default psp files are incomplete."
                )
                return

            psp_path = Path(found_dir)
            self.info["psp_dir"] = f"{psp_path.resolve()}"

            # Guard: the resolved path is not an existing directory.
            if not psp_path.is_dir():
                self.result = False
                self.error_msg = (
                    "Pseudopotential files path " f"{psp_path.resolve()} does not exist."
                )
                return

            # The default psp set covers atomic numbers 1-57 and 72-83.
            expected = chemical_symbols[1:58] + chemical_symbols[72:84]
            missing_elements = []
            for symbol in expected:
                try:
                    find_pseudo_path(symbol, psp_path)
                except Exception:
                    missing_elements.append(symbol)

            self.result = not missing_elements
            if missing_elements:
                self.error_msg = (
                    "Pseudopotential files for "
                    f"{len(missing_elements)} elements are "
                    "missing or incompatible: \n"
                    f"{missing_elements}"
                )

174 

175 

class ApiTest(BaseTest):
    """Check if the API can be loaded, and store the Schema version.

    # TODO: consider change to schema instead of api
    # TODO: allow config to change json file path
    Error handling:
    - Check if default JSON schema exists in
      `<sparc-x-api-root>/sparc_json_api/parameters.json`
    - Use $SPARC_DOC_PATH to specify the raw LaTeX files
    """

    display_name = "JSON API"

    def make_test(self):
        from .utils import locate_api

        # Any failure while locating the API (or reading its attributes)
        # turns into a failed test with placeholder info values.
        try:
            api = locate_api()
            self.result = True
            self.info.update(api_version=api.sparc_version, api_source=api.source)
        except Exception as err:
            self.result = False
            self.info.update(api_version="NaN", api_source="not found")
            self.error_msg = (
                "Error when locating a JSON schema or "
                f"LaTeX source files for SPARC. Error is {err}"
            )

207 

208 

class CommandTest(BaseTest):
    """Check validity of command to run SPARC calculation. This test
    also checks the version reported by the SPARC binary

    # TODO: check ase 3.23 config with separate binary
    Error handling:
    - The command prefix to run SPARC calculation should look like
      `<mpi instructions> <sparc binary>`
    - Use $ASE_SPARC_COMMAND to set the command string
    - Check HPC resources and compatibility (e.g. `srun` on a login node)
    """

    display_name = "SPARC Command"

    def make_test(self):
        """Build the SPARC command and query the binary's version.

        Both steps always run; the test passes only if the command can be
        built AND a (truthy) version is detected. Error messages from the
        failed steps are joined with single newlines into `error_msg`.
        """
        import tempfile

        from sparc.calculator import SPARC

        self.info["command"] = ""
        self.info["sparc_version"] = ""
        errors = []  # one entry per failed step

        with tempfile.TemporaryDirectory() as tmpdir:
            calc = SPARC(directory=tmpdir)

            # Step 1: validity of sparc command
            try:
                self.info["command"] = calc._make_command()
                command_ok = True
            except Exception as e:
                command_ok = False
                self.info["command"] = "not found"
                errors.append(f"Error setting SPARC command:\n{e}")

            # Step 2: check SPARC binary version. The version may be None
            # (or otherwise falsy) if it failed to be retrieved.
            try:
                sparc_version = calc.detect_sparc_version()
            except Exception as e:
                sparc_version = None
                version_error = f"Error detecting SPARC version:\n{e}"
            else:
                version_error = "Error detecting SPARC version"

            if sparc_version:
                self.info["sparc_version"] = sparc_version
            else:
                self.info["sparc_version"] = "NaN"
                errors.append(version_error)

        self.result = command_ok and bool(sparc_version)
        # A single "\n" between messages; no leading newline when only the
        # second step failed.
        self.error_msg = "\n".join(errors)

261 

262 

class FileIOCalcTest(BaseTest):
    """Run a simple calculation in File IO mode.

    # TODO: check ase 3.23 config
    Error handling:
    - Check if settings for pseudopotential files are correct
    - Check if SPARC binary exists and functional
    - Check if specific HPC requirements are met:
      (module files, libraries, parallel settings, resources)
    """

    display_name = "Calculation (File I/O)"

    def make_test(self):
        import tempfile

        from ase.build import bulk

        from sparc.calculator import SPARC

        # Bulk Al (cubic=False) with deliberately loose settings to keep it cheap
        atoms = bulk("Al", cubic=False)

        with tempfile.TemporaryDirectory() as workdir:
            calc = SPARC(h=0.3, kpts=(1, 1, 1), tol_scf=1e-3, directory=workdir)
            try:
                atoms.calc = calc
                atoms.get_potential_energy()
            except Exception as err:
                self.result = False
                self.error_msg = "Simple calculation in file I/O mode failed: \n" f"{err}"
            else:
                self.result = True

296 

297 

class SocketCalcTest(BaseTest):
    """Run a simple calculation in Socket mode (UNIX socket).

    # TODO: check ase 3.23 config
    Error handling:
    - The same as error handling in file I/O calculation test
    - Check if SPARC binary supports socket
    """

    display_name = "Calculation (UNIX socket)"

    def make_test(self):
        """Record socket compatibility, then run a tiny socket-mode calculation.

        The compatibility probe is informational only (stored in `info`);
        pass/fail is decided by whether the energy calculation succeeds.
        """
        import tempfile

        from ase.build import bulk

        from sparc.calculator import SPARC

        # Check SPARC binary socket compatibility
        with tempfile.TemporaryDirectory() as tmpdir:
            calc = SPARC(directory=tmpdir)
            try:
                sparc_compat = calc.detect_socket_compatibility()
            except Exception:
                sparc_compat = False
            self.info["sparc_socket_compatibility"] = sparc_compat

        # 1x Al atoms with super bad calculation condition
        al = bulk("Al", cubic=False)

        with tempfile.TemporaryDirectory() as tmpdir:
            calc = SPARC(
                h=0.3, kpts=(1, 1, 1), tol_scf=1e-3, use_socket=True, directory=tmpdir
            )
            try:
                al.calc = calc
                al.get_potential_energy()
                self.result = True
            except Exception as e:
                self.result = False
                self.error_msg = (
                    "Simple calculation in socket mode (UNIX socket) failed: \n" f"{e}"
                )
            finally:
                # Shut the calculator down before its working directory is
                # removed, so no socket/child process outlives the test.
                self._close_calculator(calc)

    @staticmethod
    def _close_calculator(calc):
        """Best-effort `calc.close()`; a no-op if the calculator has no close()."""
        close = getattr(calc, "close", None)
        if callable(close):
            try:
                close()
            except Exception:
                # Deliberately ignored: cleanup must not override the test
                # outcome already recorded in `result` / `error_msg`.
                pass

342 

343 

def main():
    """Run every quick test and print a summary report.

    The report lists the merged `info` of all tests, a PASS/FAIL line per
    test and, for each failed test, its error message followed by the
    error-handling hints from its docstring (when it has any).
    """
    cprint(
        ("Performing a quick test on your " "SPARC and python API setup"),
        color=None,
    )

    test_classes = [
        ImportTest(),
        PspTest(),
        ApiTest(),
        CommandTest(),
        FileIOCalcTest(),
        SocketCalcTest(),
    ]

    system_info = {}
    for test in test_classes:
        test.run_test()
        system_info.update(test.info)

    separator = "-" * 80

    # Header section
    print(separator)
    cprint(
        "Summary",
        bold=True,
        color="HEADER",
    )
    print(separator)
    cprint("Configuration", bold=True, color="HEADER")
    for key, val in system_info.items():
        print(f"{key}: {val}")

    print(separator)
    # Body section
    cprint("Tests", bold=True, color="HEADER")

    print_wiki = False
    for test in test_classes:
        cprint(f"{test.display_name}:", bold=True, end="")
        if test.result is True:
            cprint(" PASS", color="OKGREEN")
        else:
            cprint(" FAIL", color="FAIL")
            print_wiki = True

    print(separator)
    # Error information section: every failed test reports its error
    # message, even when its class docstring has no "Error handling" hints.
    failed_tests = [test for test in test_classes if test.result is False]
    if failed_tests:
        cprint(
            ("Some tests failed! " "Please check the following information.\n"),
            color="FAIL",
        )
    for test in failed_tests:
        cprint(f"{test.display_name}:", bold=True)
        cprint(f"{test.error_msg}", color="FAIL")
        if test.error_handling:
            print(test.error_handling)
        print("\n")

    if print_wiki:
        print(separator)
        cprint(
            "Please check additional information from:\n"
            "1. SPARC's documentation: https://github.com/SPARC-X/SPARC/blob/master/doc/Manual.pdf \n"
            "2. Python API documentation: https://sparc-x.github.io/SPARC-X-API\n",
            color=None,
        )


if __name__ == "__main__":
    main()