Coverage for sparc/socketio.py: 23%

199 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-12 01:13 +0000

1"""A i-PI compatible socket protocol implemented in SPARC 

2""" 

3import hashlib 

4import io 

5import os 

6import pickle 

7import random 

8import socket 

9import string 

10 

11import numpy as np 

12from ase.calculators.socketio import ( 

13 IPIProtocol, 

14 SocketClient, 

15 SocketClosed, 

16 SocketServer, 

17 actualunixsocketname, 

18) 

19 

20 

def generate_random_socket_name(prefix="sparc_", length=6):
    """Generate a random socket name.

    The name is ``prefix`` followed by ``length`` random lowercase
    hexadecimal characters (``0-9a-f``).

    Parameters
    ----------
    prefix : str
        Leading part of the socket name.
    length : int
        Number of random hex characters appended to ``prefix``.

    Returns
    -------
    str
        The generated socket name.
    """
    # ``string.hexdigits.lower()`` is "0123456789abcdefabcdef": the letters
    # a-f would appear twice and be drawn twice as often as the digits.
    # Use each of the 16 hexadecimal characters exactly once instead.
    hex_alphabet = string.digits + "abcdef"
    random_chars = "".join(random.choices(hex_alphabet, k=length))
    return prefix + random_chars

25 

26 

class SPARCProtocol(IPIProtocol):
    """Extend the i-PI protocol with SPARC-specific routines.

    On top of the messages provided by ``IPIProtocol`` this class adds:

    * ``send_string``: an int32 length header followed by a space-padded
      ASCII string;
    * ``send_object`` / ``recv_object``: transfer of an arbitrary picklable
      object, framed by a ``PKLOBJ`` header and followed by its MD5 digest;
    * ``send_param``: a ``SETPARAM`` message carrying a name/value pair;
    * ``sendinit`` / ``calculate_new_protocol``: an ``INIT`` handshake whose
      init string is ``NEWPROTO``, used to ship atoms and parameters to the
      client before a calculation.
    """

    def send_string(self, msg, msglen=None):
        """Send ``msg`` as an int32 byte count followed by ASCII bytes.

        Parameters
        ----------
        msg : str
            ASCII text to send (non-ASCII text raises on encoding).
        msglen : int, optional
            Number of bytes to send. ``msg`` is right-padded with spaces to
            this length. Defaults to ``len(msg)``.
        """
        self.log(" send string", repr(msg))
        if msglen is None:
            msglen = len(msg)
        assert msglen >= len(msg)
        msg = msg.encode("ascii").ljust(msglen)
        self.send(msglen, np.int32)
        self.socket.sendall(msg)
        return

    def send_object(self, obj):
        """Send ``obj`` pickled, framed for ``recv_object``.

        Wire format: ``PKLOBJ`` message header, int32 payload size, pickle
        payload, int32 digest size, MD5 digest of the payload.
        """
        # Pickle protocol 5 is available on every supported Python (>= 3.8).
        pkl_bytes = pickle.dumps(obj, protocol=5)
        nbytes = len(pkl_bytes)
        md5_checksum = hashlib.md5(pkl_bytes)
        checksum_digest, checksum_count = (
            md5_checksum.digest(),
            md5_checksum.digest_size,
        )
        self.sendmsg("PKLOBJ")  # To distinguish from other methods like INIT
        self.log(" pickle bytes to send: ", str(nbytes))
        self.send(nbytes, np.int32)
        self.log(" sending pickle object....")
        self.socket.sendall(pkl_bytes)
        self.log(" sending md5 sum of size: ", str(checksum_count))
        self.send(checksum_count, np.int32)
        self.log(" sending md5 sum..... ", str(checksum_count))
        self.socket.sendall(checksum_digest)
        return

    def recv_object(self, include_header=True):
        """Receive and unpickle an object sent by ``send_object``.

        Parameters
        ----------
        include_header : bool
            If True, first read the message header and check that it is
            ``PKLOBJ``. Pass False when the caller already consumed it.

        Returns
        -------
        object
            The unpickled object.
        """
        if include_header:
            msg = self.recvmsg()
            assert (
                msg.strip() == "PKLOBJ"
            ), f"Incorrect header {msg} received when calling recv_object method! Please contact the developers"
        # ``recv(1, ...)`` returns a 1-element array. Index it explicitly:
        # int() on an ndim > 0 array is deprecated since NumPy 1.25.
        nbytes = int(self.recv(1, np.int32)[0])
        self.log(" Will receive pickle object with n-bytes: ", nbytes)
        bytes_received = self._recvall(nbytes)
        checksum_nbytes = int(self.recv(1, np.int32)[0])
        self.log(" Will receive cheksum digest of nbytes:", checksum_nbytes)
        digest_received = self._recvall(checksum_nbytes)
        digest_calc = hashlib.md5(bytes_received).digest()
        # Only the overlapping prefix of both digests is compared.
        # NOTE(review): an empty received digest therefore always passes.
        minlen = min(len(digest_calc), len(digest_received))
        assert (
            digest_calc[:minlen] == digest_received[:minlen]
        ), "MD5 checksum for the received object does not match!"
        # SECURITY: pickle.loads can execute arbitrary code. Only use this
        # with a trusted peer on the other end of the socket.
        obj = pickle.loads(bytes_received)
        return obj

    def send_param(self, name, value):
        """Send a single parameter setting to SPARC (experimental).

        Requires the client to report ``READY``. Sends ``SETPARAM`` followed
        by ``name`` and ``value``, each converted with ``str`` and sent via
        ``send_string``. After this step the client is expected to report
        ``READY`` again.

        TODO:
        1) test with just 2 string values to see if SPARC can receive
        """
        self.log(f"Setup param {name}, {value}")
        msg = self.status()
        assert msg == "READY", msg
        self.sendmsg("SETPARAM")
        self.send_string(str(name))
        self.send_string(str(value))
        return

    def sendinit(self):
        """Send an ``INIT`` message whose init string is ``NEWPROTO``.

        Replaces ``IPIProtocol.sendinit``. The marker tells the receiving
        side that a pickled ``(atoms, params)`` pair follows. How the
        calculator is (re)initialized is up to the receiver's INIT handling.
        """
        self.log(" New sendinit for SPARC protocol")
        self.sendmsg("INIT")
        self.send(0, np.int32)  # bead index
        msg_chars = [ord(c) for c in "NEWPROTO"]
        len_msg = len(msg_chars)
        self.send(len_msg, np.int32)
        self.send(msg_chars, np.byte)  # initialization string
        return

    def recvinit(self):
        """Fall back to ``IPIProtocol.recvinit``."""
        return super().recvinit()

    def calculate_new_protocol(self, atoms, params):
        """Run one calculation using the ``NEWPROTO`` handshake.

        If the client reports ``NEEDINIT``, send ``INIT`` (``sendinit``)
        followed by the pickled ``(atoms, params)`` pair. A copy of ``atoms``
        with its calculator removed is sent, so the caller's object is not
        modified. Otherwise (client already ``READY``) atoms and params are
        NOT sent; only the cell and positions below reach the client.

        Parameters
        ----------
        atoms : Atoms
            Structure to compute (presumably ``ase.Atoms``; uses ``copy``,
            ``get_cell`` and ``get_positions``).
        params : dict
            Parameters to ship along with the atoms.

        Returns
        -------
        tuple
            ``(r, moredata)``: ``r`` is a dict with keys ``energy``,
            ``forces``, ``virial`` and ``morebytes``. ``moredata`` is the
            extra object received via ``recv_object``.
        """
        atoms = atoms.copy()
        atoms.calc = None
        self.log(" calculate with new protocol")
        msg = self.status()
        # NEEDINIT is used here to push the new atoms and parameters.
        if msg == "NEEDINIT":
            self.sendinit()
            self.send_object((atoms, params))
            msg = self.status()
        cell = atoms.get_cell()
        positions = atoms.get_positions()  # Original order
        assert msg == "READY", msg
        icell = np.linalg.pinv(cell).transpose()
        self.sendposdata(cell, icell, positions)
        msg = self.status()
        assert msg == "HAVEDATA", msg
        e, forces, virial, morebytes = self.sendrecv_force()
        r = dict(energy=e, forces=forces, virial=virial, morebytes=morebytes)
        # Additional data (e.g. parsed from file output)
        moredata = self.recv_object()
        return r, moredata

147 

148 

149# TODO: make sure both calc are ok 

150 

151 

class SPARCSocketServer(SocketServer):
    """Socket server that talks to clients through ``SPARCProtocol``.

    We only implement the unix socket version due to simplicity.

    parent: the SPARC parent calculator. Its ``process`` attribute (if any)
    is exposed as ``proc``.
    """

    def __init__(
        self,
        port=None,
        unixsocket=None,
        timeout=None,
        log=None,
        parent=None,
        # launch_client=None,
    ):
        # Set ``parent`` before anything else: the ``proc`` property reads
        # it, so it must exist whenever ``self.proc`` is accessed.
        self.parent = parent
        super().__init__(port=port, unixsocket=unixsocket, timeout=timeout, log=log)

    # TODO: guard cases for non-unix sockets
    @property
    def socket_filename(self):
        """Address the server socket is bound to (``getsockname()``)."""
        return self.serversocket.getsockname()

    @property
    def proc(self):
        """Process of the parent calculator, or None without a parent."""
        if self.parent is not None:
            return getattr(self.parent, "process", None)
        return None

    @proc.setter
    def proc(self, value):
        # Deliberately ignore assignments (e.g. from the base class); the
        # process is always looked up on the parent calculator.
        return

    def _accept(self):
        """Accept a client connection, then switch to ``SPARCProtocol``."""
        super()._accept()
        # The base class installs its own protocol object on the accepted
        # client socket; replace it with the SPARC-extended protocol.
        if self.protocol:
            self.protocol = SPARCProtocol(self.clientsocket, txt=self.log)
        return

    def send_atoms_and_params(self, atoms, params):
        """Send a copy of ``atoms`` (without calculator) and ``params``.

        The params should be assignable to SPARC.set. The caller's
        ``atoms`` object is not modified.
        """
        # Work on a copy so the caller keeps its own calculator attached.
        atoms = atoms.copy()
        atoms.calc = None
        pair = (atoms, dict(params))
        self.protocol.send_object(pair)
        return

    def calculate_origin_protocol(self, atoms):
        """Send geometry to client and return calculated things as dict.

        This will block until client has established connection, then
        wait for the client to finish the calculation."""
        assert not self._closed

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate(atoms.positions, atoms.cell)

    def calculate_new_protocol(self, atoms, params=None):
        """Run one calculation through ``SPARCProtocol.calculate_new_protocol``.

        Blocks until a client connects if no connection exists yet.
        ``params`` defaults to an empty dict.

        Returns
        -------
        tuple
            Whatever ``SPARCProtocol.calculate_new_protocol`` returns.
        """
        # Avoid a shared mutable default argument.
        if params is None:
            params = {}
        assert not self._closed

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate_new_protocol(atoms, params)

236 

237 

class SPARCSocketClient(SocketClient):
    """Socket client that talks to the server through ``SPARCProtocol``.

    ``parent_calc`` is the calculator used to evaluate the structures
    received from the server.
    """

    def __init__(
        self,
        host="localhost",
        port=None,
        unixsocket=None,
        timeout=None,
        log=None,
        parent_calc=None,
    ):
        """Connect via ``SocketClient``, then wrap the socket in SPARCProtocol."""
        super().__init__(
            host=host,
            port=port,
            unixsocket=unixsocket,
            timeout=timeout,
            log=log,
        )
        sock = self.protocol.socket
        self.protocol = SPARCProtocol(sock, txt=log)
        self.parent_calc = parent_calc  # Track the actual calculator
        # TODO: make sure the client is compatible with the default socketclient
        # NOTE: the initial state comes from the base class; ``run`` switches
        # it to NEEDINIT when no atoms are supplied.

    def calculate(self, atoms, use_stress):
        """Calculate with ``parent_calc`` if ``atoms`` has no calculator."""
        if atoms.calc is None:
            atoms.calc = self.parent_calc
        return super().calculate(atoms, use_stress)

    def irun(self, atoms, use_stress=True):
        """Serve server requests, yielding once per completed calculation.

        Handles the STATUS, POSDATA, GETFORCE, INIT and EXIT messages.

        * If the server sends INIT with the ``NEWPROTO`` init string, the
          pickled ``(atoms, params)`` pair that follows replaces the local
          ``atoms`` and updates ``parent_calc``.
        * In that case, positions later received through POSDATA are
          discarded.
        * Also in that case, ``parent_calc.raw_results`` is sent back after
          the forces.
        * Any other INIT only moves the client to READY.

        After every GETFORCE the state becomes NEEDINIT. We're free to use
        INIT this way, since most calculators do not use it; C-SPARC may
        report the needinit error itself.

        Raises
        ------
        KeyError
            On an unknown message.
        """
        # Discard positions received from POSDATA if the server has sent
        # atoms through the INIT/NEWPROTO handshake.
        discard_posdata = False
        new_protocol = False
        try:
            while True:
                try:
                    msg = self.protocol.recvmsg()
                except SocketClosed:
                    # Server closed the connection, but we want to
                    # exit gracefully anyway
                    msg = "EXIT"

                if msg == "EXIT":
                    # Send stop signal to clients:
                    self.comm.broadcast(np.ones(1, bool), 0)
                    # (When otherwise exiting, things crashed and we should
                    # let MPI_ABORT take care of the mess instead of trying
                    # to synchronize the exit)
                    return
                elif msg == "STATUS":
                    self.protocol.sendmsg(self.state)
                elif msg == "POSDATA":
                    assert self.state == "READY"
                    assert (
                        atoms is not None
                    ), "Your SPARCSocketClient isn't properly initialized!"
                    cell, icell, positions = self.protocol.recvposdata()
                    if not discard_posdata:
                        atoms.cell[:] = cell
                        atoms.positions[:] = positions

                    # At this stage, we should only rely on self.calculate
                    # to continue the socket calculation or restart
                    self.comm.broadcast(np.zeros(1, bool), 0)
                    energy, forces, virial = self.calculate(atoms, use_stress)

                    self.state = "HAVEDATA"
                    yield
                elif msg == "GETFORCE":
                    assert self.state == "HAVEDATA", self.state
                    self.protocol.sendforce(energy, forces, virial)
                    if new_protocol:
                        # TODO: implement more raw results
                        raw_results = self.parent_calc.raw_results
                        self.protocol.send_object(raw_results)
                    self.state = "NEEDINIT"
                elif msg == "INIT":
                    assert self.state == "NEEDINIT"
                    # Fall back to the default socketio for the INIT payload
                    _bead_index, initbytes = self.protocol.recvinit()
                    self.protocol.log(" Init bytes: ", initbytes)
                    # Compare raw bytes: chr() on signed int8 values would
                    # fail for any byte >= 0x80.
                    init_msg = np.asarray(initbytes).tobytes()
                    if init_msg.startswith(b"NEWPROTO"):
                        # The parts below use the new sparc protocol
                        new_protocol = True
                        recv_atoms, params = self.protocol.recv_object()
                        self.protocol.log(" Received atoms and params: ", recv_atoms, params)
                        if params:
                            self.parent_calc.set(**params)
                        # TODO: should we update the atoms directly or keep copy?
                        atoms = recv_atoms
                        atoms.calc = self.parent_calc
                        discard_posdata = True
                    self.state = "READY"
                else:
                    raise KeyError("Bad message", msg)
        finally:
            self.close()

    def run(self, atoms=None, use_stress=False):
        """Socket mode in SPARC should allow arbitrary start.

        Without initial ``atoms`` the client starts in NEEDINIT so that the
        server must provide atoms through INIT first.
        """
        if atoms is None:
            self.state = "NEEDINIT"
        for _ in self.irun(atoms=atoms, use_stress=use_stress):
            pass