Coverage for sparc/download_data.py: 88%
52 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-12 01:13 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-12 01:13 +0000
1"""Download the pseudopotential and other related files after sparc-x-api is installed
3Run:
5python -m sparc.download_data
6"""
8import hashlib
9import shutil
10import tempfile
11import zipfile
12from io import BytesIO
13from pathlib import Path
15# import urllib.request
16from urllib.request import urlopen
18from .common import psp_dir
# Commit of the SPARC repository from which the pseudopotential files are taken
sparc_tag = "b702c1061400a2d23c0e223e32182609d7958156"
# URL template for the source archive; ``{sparc_tag}`` is filled in with str.format
sparc_source_url = "https://github.com/SPARC-X/SPARC/archive/{sparc_tag}.zip"
# Reference value that checksum_all() is compared against for the *.psp8 files
all_psp8_checksum = "5ef42c4a81733a90b0e080b771c5a73a"
def download_psp(sparc_tag=sparc_tag, psp_dir=psp_dir):
    """Download the external PSPs into the sparc/psp folder

    The SPARC source archive for ``sparc_tag`` is fetched, unpacked into a
    temporary directory, and every ``*.psp8``, ``*.psp`` and ``*.pot`` file
    found in its ``psps`` folder is copied into ``psp_dir``. Nothing is
    downloaded when ``psp_dir`` already passes the checksum test.

    Arguments:
        sparc_tag (str): Commit hash or git tag for the psp files
        psp_dir (str or PosixPath): Directory to download the psp files

    Raises:
        FileNotFoundError: The archive contains no ``SPARC-*/psps`` directory
        RuntimeError: The files copied to ``psp_dir`` fail the checksum test
    """
    # Test the directory the caller asked for, not the package default
    if is_psp_download_complete(psp_dir):
        print("PSPs have been successfully downloaded!")
        return
    download_url = sparc_source_url.format(sparc_tag=sparc_tag)
    print(f"Download link: {download_url}")
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)
        with urlopen(download_url) as zipresp:
            with zipfile.ZipFile(BytesIO(zipresp.read())) as zfile:
                zfile.extractall(tmpdir)
        # The archive unpacks into a single "SPARC-<tag>" folder; a default of
        # None keeps a missing folder from surfacing as a bare StopIteration
        source_dir = next(tmpdir.glob("SPARC-*/psps"), None)
        if source_dir is None or not source_dir.is_dir():
            raise FileNotFoundError("Error downloading or extracting zip")
        print(f"Found source_dir at {source_dir}")
        print(f"Moving psp files to {psp_dir}")
        # Without an existing directory shutil.copy would write every source
        # file to a single regular file named ``psp_dir``
        Path(psp_dir).mkdir(parents=True, exist_ok=True)
        for ext in ("*.psp8", "*.psp", "*.pot"):
            for pspf in source_dir.glob(ext):
                print(f"Found {pspf} --> {psp_dir}")
                shutil.copy(pspf, psp_dir)
    if not is_psp_download_complete(psp_dir):
        raise RuntimeError(f"Files downloaded to {psp_dir} have different checksums!")
def checksum_all(psp_dir=psp_dir, extension="*.psp8"):
    """Checksum all the files under the psp_dir to make sure the psp8 files
    are the same as intended

    Every matching file gets its own MD5 digest; the hex digests are then fed,
    in sorted path order, into a second MD5 whose hex digest is returned.

    Arguments:
        psp_dir (str or PosixPath): Directory for the psp files
        extension (str): Glob pattern for the psp files, e.g. '*.psp', '*.psp8' or '*.pot'

    Returns:
        str: Hex digest of the MD5 taken over all per-file hex digests
    """
    checker = hashlib.md5()
    psp_dir = Path(psp_dir)
    # Sort so the combined digest does not depend on directory listing order
    for filename in sorted(psp_dir.glob(extension)):
        print(f"Checking {filename}")
        # Read in text mode (line endings normalised to "\n") with a fixed
        # codec, so the digest does not vary with the platform's locale encoding
        with open(filename, "r", encoding="utf8") as f:
            content = f.read().encode("utf8")
        f_checker = hashlib.md5(content)
        checker.update(f_checker.hexdigest().encode("ascii"))
    final_checksum = checker.hexdigest()
    print(f"Final checksum is {final_checksum}")
    return final_checksum
def is_psp_download_complete(psp_dir=psp_dir):
    """Tell whether the psp8 files under psp_dir match the reference checksum

    Arguments:
        psp_dir (str or PosixPath): Directory holding the psp files

    Returns:
        bool: True if checksum_all(psp_dir) equals all_psp8_checksum
    """
    observed = checksum_all(psp_dir)
    return observed == all_psp8_checksum
# Script entry point, reached via ``python -m sparc.download_data``
if __name__ == "__main__":
    print("Running command-line psp downloader")
    download_psp()