Coverage for sparc/download_data.py: 88%

52 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-12 01:13 +0000

1"""Download the pseudopotential and other related files after sparc-x-api is installed 

2 

3Run: 

4 

5python -m sparc.download_data 

6""" 

7 

8import hashlib 

9import shutil 

10import tempfile 

11import zipfile 

12from io import BytesIO 

13from pathlib import Path 

14 

15# import urllib.request 

16from urllib.request import urlopen 

17 

18from .common import psp_dir 

19 

20sparc_tag = "b702c1061400a2d23c0e223e32182609d7958156" 

21sparc_source_url = "https://github.com/SPARC-X/SPARC/archive/{sparc_tag}.zip" 

22# Expected combined MD5 checksum of all *.psp8 files, compared against the result of checksum_all()

23all_psp8_checksum = "5ef42c4a81733a90b0e080b771c5a73a" 

24 

25 

def download_psp(sparc_tag=sparc_tag, psp_dir=psp_dir):
    """Download the external PSPs into the sparc/psp folder

    The SPARC source archive for ``sparc_tag`` is fetched, unpacked into a
    temporary directory, and the ``*.psp8``, ``*.psp`` and ``*.pot`` files
    found in its ``psps`` folder are copied into ``psp_dir``. Nothing is
    downloaded when ``psp_dir`` already passes the checksum test.

    Arguments:
        sparc_tag (str): Commit hash or git tag for the psp files
        psp_dir (str or PosixPath): Directory to download the psp files

    Raises:
        FileNotFoundError: The extracted archive has no ``SPARC-*/psps`` folder
        RuntimeError: The files copied to ``psp_dir`` fail the checksum test
    """
    # Check the directory requested by the caller, not the module default
    if is_psp_download_complete(psp_dir):
        print("PSPs have been successfully downloaded!")
        return

    download_url = sparc_source_url.format(sparc_tag=sparc_tag)
    print(f"Download link: {download_url}")
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)
        with urlopen(download_url) as zipresp:
            with zipfile.ZipFile(BytesIO(zipresp.read())) as zfile:
                zfile.extractall(tmpdir)

        # Use a default with next() so that a missing folder is reported as
        # FileNotFoundError instead of leaking a bare StopIteration
        source_dir = next(tmpdir.glob("SPARC-*/psps"), None)
        if source_dir is None or not source_dir.is_dir():
            raise FileNotFoundError("Error downloading or extracting zip")
        print(f"Found source_dir at {source_dir}")

        # shutil.copy would write a single file named psp_dir if the
        # directory did not exist yet, so make sure it is there first
        Path(psp_dir).mkdir(parents=True, exist_ok=True)
        print(f"Moving psp files to {psp_dir}")
        for ext in ("*.psp8", "*.psp", "*.pot"):
            for pspf in source_dir.glob(ext):
                print(f"Found {pspf} --> {psp_dir}")
                shutil.copy(pspf, psp_dir)

    if not is_psp_download_complete(psp_dir):
        raise RuntimeError(f"Files downloaded to {psp_dir} have different checksums!")

56 

57 

def checksum_all(psp_dir=psp_dir, extension="*.psp8"):
    """Checksum all the files under the psp_dir to make sure the psp8 files
    are the same as intended

    Each matching file is hashed with MD5 on its own; the hex digests are
    then fed, in sorted filename order, into a second MD5 whose hex digest
    is returned.

    Arguments:
        psp_dir (str or PosixPath): Directory for the psp files
        extension (str): Glob pattern for the psp files, e.g. '*.psp',
            '*.psp8' or '*.pot'

    Returns:
        str: Checksum for all the files concatenated
    """
    checker = hashlib.md5()
    psp_dir = Path(psp_dir)
    # Use sorted to make sure file order is correct
    for filename in sorted(psp_dir.glob(extension)):
        print(f"Checking {filename}")
        # The file is read as text (newlines translated) and re-encoded as
        # UTF-8, as in the original computation. The encoding is pinned so
        # the digest does not depend on the platform's locale default.
        with open(filename, "r", encoding="utf8") as f:
            content = f.read().encode("utf8")
        file_digest = hashlib.md5(content).hexdigest()
        checker.update(file_digest.encode("ascii"))
    final_checksum = checker.hexdigest()
    print(f"Final checksum is {final_checksum}")
    return final_checksum

83 

84 

def is_psp_download_complete(psp_dir=psp_dir):
    """Tell whether the files in ``psp_dir`` match the expected checksum

    Arguments:
        psp_dir (str or PosixPath): Directory for the psp files

    Returns:
        bool: True if ``checksum_all(psp_dir)`` equals ``all_psp8_checksum``
    """
    actual_checksum = checksum_all(psp_dir)
    return actual_checksum == all_psp8_checksum

87 

88 

# Script entry point: download the psp files with the default arguments
if __name__ == "__main__":
    print("Running command-line psp downloader")
    download_psp()