Coverage for sparc/sparc_parsers/inpt.py: 94%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-12 01:13 +0000

1import numpy as np 

2from ase.units import Bohr 

3 

4# Safe wrappers for both string and fd 

5from ase.utils import reader, writer 

6 

7from ..api import SparcAPI 

8from .utils import read_block_input, strip_comments 

9 

10defaultAPI = SparcAPI() 

11 

12 

@reader
def _read_inpt(fileobj, validator=defaultAPI):
    """Read the content of a SPARC .inpt file into a nested dict.

    Arguments:
        fileobj: object whose full text is obtained via `.read()`; the
                 `reader` decorator wraps this function (per the import
                 comment in this module it is meant to accept both a
                 string and a file descriptor)
        validator: passed through unchanged to `read_block_input`

    Returns:
        dict of the form
        {"inpt": {"params": <parsed blocks>, "comments": <comments>}}
    """
    raw_text = fileobj.read()
    # Split the comment lines off from the parameter payload
    payload, comment_lines = strip_comments(raw_text)
    # NOTE: the cell is not interpreted at this stage, only the raw
    # parameter blocks are parsed and validated.
    parsed_params = read_block_input(payload, validator=validator)
    inpt_section = {"params": parsed_params, "comments": comment_lines}
    return {"inpt": inpt_section}

24 

25 

@writer
def _write_inpt(fileobj, data_dict, validator=defaultAPI):
    """Write the "inpt" section of a data dict as a SPARC .inpt file.

    Arguments:
        fileobj: object with a `.write()` method; the `writer` decorator
                 wraps this function (per the import comment in this
                 module it is meant to accept both a string and a file
                 descriptor)
        data_dict: dict with an "inpt" entry, which itself must hold a
                   "params" mapping and may hold a "comments" list
        validator: object providing `convert_value_to_string(key, val)`

    Raises:
        ValueError: if the "inpt" section or its "params" field is missing
    """
    if "inpt" not in data_dict:
        raise ValueError("Your dict does not contain inpt section!")
    section = data_dict["inpt"]
    if "params" not in section:
        raise ValueError("Input dict for inpt file does not have `params` field!")

    # Header: the banner always comes first unless the first existing
    # comment already contains "ASE".
    banner = "Input File Generated By SPARC ASE Calculator"
    user_comments = section.get("comments", [])
    if len(user_comments) == 0:
        header_lines = [banner]
    elif "ASE" in user_comments[0]:
        header_lines = user_comments
    else:
        header_lines = [banner] + user_comments
    for text in header_lines:
        fileobj.write(f"# {text}\n")
    fileobj.write("\n")

    # Keys that are always written in block form (value on its own lines)
    block_style_keys = ("LATVEC",)
    # TODO: can we add a multiline argument?
    for name, value in section["params"].items():
        rendered = validator.convert_value_to_string(name, value)
        # Multi-line values start on the line after "KEY:", single-line
        # values follow "KEY:" after one space.
        if ("\n" in rendered) or (name in block_style_keys):
            separator = "\n"
        else:
            separator = " "
        fileobj.write(f"{name}:{separator}{rendered}\n")

60 

61 

def _inpt_cell_to_ase_cell(data_dict):
    """Build the real cell (ASE unit, Angstrom) from the inpt cell convention

    Arguments:
        data_dict: dict holding the already validated inpt blocks under
                   data_dict["inpt"]["params"] (i.e. parsed by _read_inpt)

    Returns:
        cell in ASE-unit

    Raises:
        ValueError: if both CELL and LATVEC_SCALE are given
        KeyError: if none of LATVEC, CELL and LATVEC_SCALE is given
    """
    params = data_dict["inpt"]["params"]
    has_cell = "CELL" in params
    has_scale = "LATVEC_SCALE" in params

    # TODO: customize the exception class
    # TODO: how do we convert the rule from doc?
    if has_cell and has_scale:
        raise ValueError("LATVEC_SCALE and CELL cannot be specified simultaneously!")

    # Lattice vectors: taken from LATVEC when present, otherwise the
    # identity matrix is used as long as CELL or LATVEC_SCALE is given.
    if "LATVEC" in params:
        lattice = np.array(params["LATVEC"]) * Bohr
    elif has_cell or has_scale:
        lattice = np.eye(3) * Bohr
    else:
        raise KeyError(
            "LATVEC is missing in inpt file and no CELL / LATVEC_SCALE provided!"
        )

    # LATVEC_SCALE: each lattice vector (row) is multiplied by its factor
    if has_scale:
        factors = np.array(params["LATVEC_SCALE"])
        return (lattice.T * factors).T

    # CELL: the lengths are in the LATVEC directions, so the rows are
    # normalized first and then stretched to the requested lengths.
    # TODO: the documentation about CELL is a bit messy. Is CELL always orthogonal?
    if has_cell:
        lengths = np.array(params["CELL"])
        row_norms = np.linalg.norm(lattice, axis=1, keepdims=True)
        directions = lattice / row_norms * Bohr
        return (directions.T * lengths).T

    # Neither CELL nor LATVEC_SCALE: LATVEC alone defines the cell
    return lattice