Coverage for sparc/socketio.py: 23%

211 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-18 16:19 +0000

1"""A i-PI compatible socket protocol implemented in SPARC 

2""" 

3import hashlib 

4import io 

5import os 

6import pickle 

7import random 

8import socket 

9import string 

10 

11import numpy as np 

12from ase.calculators.socketio import ( 

13 IPIProtocol, 

14 SocketClient, 

15 SocketClosed, 

16 SocketServer, 

17 actualunixsocketname, 

18) 

19from ase import units 

20 

21 

def generate_random_socket_name(prefix="sparc_", length=6):
    """Generate a random socket name.

    The name is ``prefix`` followed by ``length`` lowercase hexadecimal
    characters, each drawn uniformly from ``0-9a-f``.

    ``string.hexdigits`` is ``"0123456789abcdefABCDEF"``. Lower-casing it
    repeats ``a-f``, which would make letters twice as likely as digits.
    The alphabet below therefore lists every hex digit exactly once.

    The name only needs to avoid collisions and is not a secret, so
    :mod:`random` is used rather than :mod:`secrets`.

    Args:
        prefix: String prepended to the random part.
        length: Number of random hex characters to append.

    Returns:
        str: ``prefix`` followed by ``length`` random hex characters.
    """
    hex_alphabet = string.digits + "abcdef"
    random_chars = "".join(random.choices(hex_alphabet, k=length))
    return prefix + random_chars

26 

27 

class SPARCProtocol(IPIProtocol):
    """i-PI compatible socket protocol adapted for SPARC.

    Compared with :class:`ase.calculators.socketio.IPIProtocol`, this class:

    * sends the cell and inverse cell in ``sendposdata`` without
      transposing them, to account for the row-major layout in SPARC;
    * extends the i-PI protocol with extra routines:

      - length-prefixed ASCII strings (``send_string``);
      - pickled objects guarded by an MD5 checksum (``send_object`` /
        ``recv_object``, header ``PKLOBJ``);
      - parameter updates (``send_param``, header ``SETPARAM``);
      - a ``NEWPROTO`` initialization used by ``calculate_new_protocol``.
    """

    def sendposdata(self, cell, icell, positions):
        """Send a ``POSDATA`` message with the cell, inverse cell and positions.

        Lengths are converted from Angstrom to Bohr: ``cell`` and
        ``positions`` are divided by ``units.Bohr``, and ``icell`` is
        multiplied by it. Unlike the parent class, ``cell`` and ``icell``
        are sent as-is (row-major), not transposed.
        """
        # NOTE(review): the inherited ``recvposdata`` is not overridden here.
        # If it transposes what it receives (as the parent's counterpart to
        # a transposing sender would), a Python peer that also uses this
        # class would read ``cell.T``. Confirm this asymmetry is intended.
        assert cell.size == 9
        assert icell.size == 9
        assert positions.size % 3 == 0

        self.log(' sendposdata')
        self.sendmsg('POSDATA')
        self.send(cell / units.Bohr, np.float64)
        self.send(icell * units.Bohr, np.float64)
        self.send(len(positions), np.int32)
        self.send(positions / units.Bohr, np.float64)

    def _recv_int32(self):
        """Receive a single int32 from the socket and return it as an ``int``."""
        # Calling ``int()`` on a 1-element array with ndim > 0 is deprecated
        # since NumPy 1.25, so take the scalar element explicitly.
        return int(self.recv(1, np.int32)[0])

    def send_string(self, msg, msglen=None):
        """Send ``msg`` as an int32 byte count followed by its ASCII bytes.

        Args:
            msg: ASCII string to send.
            msglen: Number of bytes to send. ``msg`` is right-padded with
                spaces up to this length. Defaults to ``len(msg)``.
        """
        self.log(" send string", repr(msg))
        if msglen is None:
            msglen = len(msg)
        assert msglen >= len(msg)
        msg = msg.encode("ascii").ljust(msglen)
        self.send(msglen, np.int32)
        self.socket.sendall(msg)
        return

    def send_object(self, obj):
        """Send ``obj`` pickled, framed by a header and an MD5 checksum.

        Wire layout, in order:

        1. the ``PKLOBJ`` message;
        2. an int32 byte count;
        3. the pickle bytes;
        4. an int32 digest size;
        5. the MD5 digest of the pickle bytes.
        """
        # Pickle protocol 5 is available because the Python requirement
        # is >= 3.8.
        pkl_bytes = pickle.dumps(obj, protocol=5)
        nbytes = len(pkl_bytes)
        md5_checksum = hashlib.md5(pkl_bytes)
        checksum_digest, checksum_count = (
            md5_checksum.digest(),
            md5_checksum.digest_size,
        )
        self.sendmsg("PKLOBJ")  # To distinguish from other methods like INIT
        self.log(" pickle bytes to send: ", str(nbytes))
        self.send(nbytes, np.int32)
        self.log(" sending pickle object....")
        self.socket.sendall(pkl_bytes)
        self.log(" sending md5 sum of size: ", str(checksum_count))
        self.send(checksum_count, np.int32)
        self.log(" sending md5 sum..... ", str(checksum_count))
        self.socket.sendall(checksum_digest)
        return

    def recv_object(self, include_header=True):
        """Receive and unpickle an object framed as by :meth:`send_object`.

        Args:
            include_header: If True, first read the message header and
                check that it is ``PKLOBJ``.

        Returns:
            The unpickled object.

        Raises:
            AssertionError: (via ``assert``) on a wrong header or a
                checksum mismatch.
        """
        if include_header:
            msg = self.recvmsg()
            assert (
                msg.strip() == "PKLOBJ"
            ), f"Incorrect header {msg} received when calling recv_object method! Please contact the developers"
        nbytes = self._recv_int32()
        self.log(" Will receive pickle object with n-bytes: ", nbytes)
        bytes_received = self._recvall(nbytes)
        checksum_nbytes = self._recv_int32()
        self.log(" Will receive cheksum digest of nbytes:", checksum_nbytes)
        digest_received = self._recvall(checksum_nbytes)
        digest_calc = hashlib.md5(bytes_received).digest()
        # Only the common prefix of the two digests is compared.
        minlen = min(len(digest_calc), len(digest_received))
        assert (
            digest_calc[:minlen] == digest_received[:minlen]
        ), "MD5 checksum for the received object does not match!"
        # SECURITY: pickle.loads can execute arbitrary code supplied by the
        # peer. MD5 only detects corruption, not tampering, so only connect
        # to trusted peers.
        obj = pickle.loads(bytes_received)
        return obj

    def send_param(self, name, value):
        """Send a single ``SETPARAM`` (name, value) pair to SPARC.

        This is an experimental helper. Both ``name`` and ``value`` are sent
        as strings via :meth:`send_string` after checking that the peer
        reports ``READY``.

        TODO:
        1) test with just 2 string values to see if SPARC can receive
        """
        self.log(f"Setup param {name}, {value}")
        msg = self.status()
        assert msg == "READY", msg
        self.sendmsg("SETPARAM")
        self.send_string(str(name))
        self.send_string(str(value))
        # After this step, socket client should return READY
        return

    def sendinit(self):
        """Send an ``INIT`` message whose init string is ``NEWPROTO``.

        This mimics the parent ``sendinit`` layout: an int32 bead index
        (always 0), an int32 string length, then the string bytes. The
        ``NEWPROTO`` string signals that atoms and parameters follow as a
        pickled object. How the calculator is (re-)initialized depends on
        the peer's ``recvinit`` implementation.
        """
        self.log(" New sendinit for SPARC protocol")
        self.sendmsg("INIT")
        self.send(0, np.int32)  # fallback
        msg_chars = [ord(c) for c in "NEWPROTO"]
        len_msg = len(msg_chars)
        self.send(len_msg, np.int32)
        self.send(msg_chars, np.byte)  # initialization string
        return

    def recvinit(self):
        """Fallback recvinit method: defer to the parent implementation."""
        return super().recvinit()

    def calculate_new_protocol(self, atoms, params):
        """Run one calculation using the ``NEWPROTO`` extension.

        Steps:

        1. If the peer reports ``NEEDINIT``, send a ``NEWPROTO`` ``INIT``
           followed by the pickled ``(atoms, params)`` pair, where the atoms
           are a calculator-free copy.
        2. Send the positions and collect energy, forces and virial through
           the standard i-PI exchange.
        3. Receive one extra pickled object.

        Returns:
            tuple: ``(results, moredata)``. ``results`` is a dict with keys
            ``energy``, ``forces``, ``virial`` and ``morebytes``. ``moredata``
            is the object returned by :meth:`recv_object`.
        """
        atoms = atoms.copy()
        atoms.calc = None
        self.log(" calculate with new protocol")
        msg = self.status()
        # We don't know how NEEDINIT is supposed to work, but some codes
        # seem to be okay if we skip it and send the positions instead.
        if msg == "NEEDINIT":
            self.sendinit()
            self.send_object((atoms, params))
            msg = self.status()
        cell = atoms.get_cell()
        positions = atoms.get_positions()  # Original order
        assert msg == "READY", msg
        icell = np.linalg.pinv(cell).transpose()
        self.sendposdata(cell, icell, positions)
        msg = self.status()
        assert msg == "HAVEDATA", msg
        e, forces, virial, morebytes = self.sendrecv_force()
        r = dict(energy=e, forces=forces, virial=virial, morebytes=morebytes)
        # Additional data (e.g. parsed from file output).
        # NOTE(review): this always waits for a PKLOBJ. A peer that was not
        # initialized via NEWPROTO may never send one, and this call would
        # then block. Confirm that callers only use this method with
        # NEWPROTO-capable peers.
        moredata = self.recv_object()
        return r, moredata

161 

162 

163# TODO: make sure both calc are ok 

164 

165 

class SPARCSocketServer(SocketServer):
    """Socket server that talks to its client through :class:`SPARCProtocol`.

    We only implement the unix socket version due to simplicity.

    parent: the SPARC parent calculator. When set, ``proc`` mirrors
    ``parent.process``.
    """

    def __init__(
        self,
        port=None,
        unixsocket=None,
        timeout=None,
        log=None,
        parent=None,
        # launch_client=None,
    ):
        super().__init__(port=port, unixsocket=unixsocket, timeout=timeout, log=log)
        # ``proc`` is derived from ``parent`` by the property below. Its
        # setter ignores assignments, so ``proc`` is not assigned here.
        self.parent = parent

    # TODO: guard cases for non-unix sockets
    @property
    def socket_filename(self):
        """Return the address the server socket is bound to."""
        return self.serversocket.getsockname()

    @property
    def proc(self):
        """Return the parent calculator's ``process``, or None without a parent."""
        # getattr: ``proc`` may be read before ``self.parent`` is assigned,
        # e.g. from within the base-class ``__init__``.
        parent = getattr(self, "parent", None)
        if parent is not None:
            return parent.process
        return None

    @proc.setter
    def proc(self, value):
        """Ignore assignments: ``proc`` always mirrors ``parent.process``."""
        # NOTE(review): presumably kept so that assignments to ``self.proc``
        # from the base class do not fail on a read-only property. Confirm.
        return

    def _accept(self):
        """Accept a client connection, then switch to :class:`SPARCProtocol`."""
        super()._accept()
        # Swap the protocol set up by the base class for a SPARCProtocol
        # on the same client socket.
        if self.protocol is not None:
            self.protocol = SPARCProtocol(self.clientsocket, txt=self.log)
        return

    def send_atoms_and_params(self, atoms, params):
        """Update the atoms and parameters for the SPARC calculator.

        The params should be assignable to SPARC.set.

        A calculator-free copy of ``atoms`` is sent. The caller's ``atoms``,
        including its ``calc``, is left untouched. Blocks until a client
        connects if no connection exists yet.
        """
        assert not self._closed
        if self.protocol is None:
            self._accept()
        atoms = atoms.copy()
        atoms.calc = None
        pair = (atoms, dict(params))
        self.protocol.send_object(pair)
        return

    def calculate_origin_protocol(self, atoms):
        """Send geometry to client and return calculated things as dict.

        This will block until client has established connection, then
        wait for the client to finish the calculation."""
        assert not self._closed

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate(atoms.positions, atoms.cell)

    def calculate_new_protocol(self, atoms, params=None):
        """Run one calculation using the SPARC ``NEWPROTO`` extension.

        Args:
            atoms: Structure to calculate.
            params: Optional dict of parameters for the SPARC calculator.
                ``None`` means no parameters.

        Returns:
            The ``(results, moredata)`` tuple from
            ``SPARCProtocol.calculate_new_protocol``.
        """
        assert not self._closed
        if params is None:
            params = {}

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate_new_protocol(atoms, params)

250 

251 

class SPARCSocketClient(SocketClient):
    """Socket client that serves calculations with a SPARC calculator.

    It communicates through :class:`SPARCProtocol`. In addition to the
    standard i-PI loop, it understands an ``INIT`` message whose init string
    starts with ``NEWPROTO``. That message is followed by the atoms and
    calculator parameters as a pickled object.
    """

    def __init__(
        self,
        host="localhost",
        port=None,
        unixsocket=None,
        timeout=None,
        log=None,
        parent_calc=None,
    ):
        """Connect like ``SocketClient``, then switch to :class:`SPARCProtocol`.

        parent_calc: the SPARC calculator that performs the calculations.
        """
        super().__init__(
            host=host,
            port=port,
            unixsocket=unixsocket,
            timeout=timeout,
            log=log,
        )
        # Reuse the already-connected socket with the SPARC protocol.
        sock = self.protocol.socket
        self.protocol = SPARCProtocol(sock, txt=log)
        self.parent_calc = parent_calc  # Track the actual calculator
        # TODO: make sure the client is compatible with the default socketclient

    def calculate(self, atoms, use_stress):
        """Calculate with ``atoms``, attaching ``parent_calc`` if it has no calc."""
        if atoms.calc is None:
            atoms.calc = self.parent_calc
        return super().calculate(atoms, use_stress)

    def irun(self, atoms, use_stress=True):
        """Serve socket messages, yielding once per completed calculation.

        This reimplements the single-step loop of ``SocketClient``. It handles
        ``EXIT``, ``STATUS``, ``POSDATA``, ``GETFORCE`` and ``INIT``. When an
        ``INIT`` carries the ``NEWPROTO`` init string:

        - the atoms and parameters are received as a pickled object;
        - the parameters are applied to ``parent_calc``;
        - positions in later ``POSDATA`` messages are discarded;
        - after each ``GETFORCE``, ``parent_calc.raw_results`` is sent back.

        We're free to implement the INIT method in socket protocol as most
        calculators do not involve using these. We can let the C-SPARC to spit
        out error about needinit error.

        The connection is closed when the loop exits, including on error.

        Raises:
            KeyError: on an unknown message.
        """
        # Discard positions received from POSDATA once the server has sent
        # the atoms through the NEWPROTO init.
        discard_posdata = False
        new_protocol = False
        try:
            while True:
                try:
                    msg = self.protocol.recvmsg()
                except SocketClosed:
                    # Server closed the connection, but we want to
                    # exit gracefully anyway
                    msg = "EXIT"

                if msg == "EXIT":
                    # Send stop signal to clients:
                    self.comm.broadcast(np.ones(1, bool), 0)
                    # (When otherwise exiting, things crashed and we should
                    # let MPI_ABORT take care of the mess instead of trying
                    # to synchronize the exit)
                    return
                elif msg == "STATUS":
                    self.protocol.sendmsg(self.state)
                elif msg == "POSDATA":
                    assert self.state == "READY"
                    assert (
                        atoms is not None
                    ), "Your SPARCSocketClient isn't properly initialized!"
                    cell, icell, positions = self.protocol.recvposdata()
                    if not discard_posdata:
                        atoms.cell[:] = cell
                        atoms.positions[:] = positions

                    # At this stage, we should only rely on self.calculate
                    # to continue the socket calculation or restart
                    self.comm.broadcast(np.zeros(1, bool), 0)
                    energy, forces, virial = self.calculate(atoms, use_stress)

                    self.state = "HAVEDATA"
                    yield
                elif msg == "GETFORCE":
                    assert self.state == "HAVEDATA", self.state
                    self.protocol.sendforce(energy, forces, virial)
                    if new_protocol:
                        # TODO: implement more raw results
                        raw_results = self.parent_calc.raw_results
                        self.protocol.send_object(raw_results)
                    self.state = "NEEDINIT"
                elif msg == "INIT":
                    assert self.state == "NEEDINIT"
                    # Fall back to the default socketio; the bead index is unused.
                    _bead_index, initbytes = self.protocol.recvinit()
                    # The parts below use the new sparc protocol
                    self.protocol.log("Init bytes: ", initbytes)
                    init_msg = "".join(chr(d) for d in initbytes)
                    if init_msg.startswith("NEWPROTO"):
                        new_protocol = True
                        recv_atoms, params = self.protocol.recv_object()
                        self.protocol.log(recv_atoms, params)
                        if params:
                            self.parent_calc.set(**params)
                        # TODO: should we update the atoms directly or keep copy?
                        atoms = recv_atoms
                        atoms.calc = self.parent_calc
                        discard_posdata = True
                    self.state = "READY"
                else:
                    raise KeyError("Bad message", msg)
        finally:
            self.close()

    def run(self, atoms=None, use_stress=False):
        """Socket mode in SPARC should allow arbitrary start.

        Without ``atoms`` the client starts in ``NEEDINIT`` state, so the
        server must send an ``INIT`` first. The loop runs until ``EXIT``.
        """
        if atoms is None:
            self.state = "NEEDINIT"
        for _ in self.irun(atoms=atoms, use_stress=use_stress):
            pass