Coverage for sparc/socketio.py: 23%
211 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-18 16:19 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-18 16:19 +0000
1"""A i-PI compatible socket protocol implemented in SPARC
2"""
3import hashlib
4import io
5import os
6import pickle
7import random
8import socket
9import string
11import numpy as np
12from ase.calculators.socketio import (
13 IPIProtocol,
14 SocketClient,
15 SocketClosed,
16 SocketServer,
17 actualunixsocketname,
18)
19from ase import units
def generate_random_socket_name(prefix="sparc_", length=6):
    """Generate a random socket name.

    Args:
        prefix: Text placed in front of the random part.
        length: Number of random lowercase hexadecimal characters to append.

    Returns:
        ``prefix`` followed by ``length`` characters drawn from ``0-9a-f``.
    """
    # ``string.hexdigits.lower()`` evaluates to "0123456789abcdefabcdef":
    # the letters appear twice, which makes a-f twice as likely as the
    # digits. Sample from the 16 distinct lowercase hex digits instead.
    hex_alphabet = string.digits + "abcdef"
    random_chars = "".join(random.choices(hex_alphabet, k=length))
    return prefix + random_chars
class SPARCProtocol(IPIProtocol):
    """i-PI protocol variant used between the SPARC socket server and client.

    Accounts for the row major layout in SPARC when sending position data and
    extends the i-PI protocol with extra routines: sending plain strings,
    sending/receiving pickled Python objects, and a combined
    "atoms + params" calculation step.
    """

    def sendposdata(self, cell, icell, positions):
        """Send a POSDATA message with cell, inverse cell and positions.

        Args:
            cell: Cell array; must contain exactly 9 elements.
            icell: Inverse cell array; must contain exactly 9 elements.
            positions: Atomic positions; the element count must be a
                multiple of 3.

        ``cell`` and ``positions`` are divided by ``units.Bohr`` and ``icell``
        is multiplied by it before sending. The arrays are sent as given
        (no transpose is applied here).
        """
        assert cell.size == 9
        assert icell.size == 9
        assert positions.size % 3 == 0

        self.log(" sendposdata")
        self.sendmsg("POSDATA")
        self.send(cell / units.Bohr, np.float64)
        self.send(icell * units.Bohr, np.float64)
        self.send(len(positions), np.int32)
        self.send(positions / units.Bohr, np.float64)

    def _recv_int(self):
        """Receive one ``np.int32`` value and return it as a Python int."""
        # ``recv`` is asked for shape 1; calling ``int()`` directly on an
        # array with ndim > 0 is deprecated in NumPy >= 1.25, so pull the
        # single element out explicitly.
        return int(np.asarray(self.recv(1, np.int32)).reshape(-1)[0])

    def send_string(self, msg, msglen=None):
        """Send a string as an int32 length followed by ASCII bytes.

        Args:
            msg: Text to send; must be ASCII-encodable.
            msglen: Total number of bytes to send. Defaults to ``len(msg)``.
                When larger than the message, the encoded bytes are
                right-padded with spaces. Must not be smaller than
                ``len(msg)``.
        """
        self.log(" send string", repr(msg))
        # assert msg in self.statements, msg
        if msglen is None:
            msglen = len(msg)
        assert msglen >= len(msg)
        payload = msg.encode("ascii").ljust(msglen)
        self.send(msglen, np.int32)
        self.socket.sendall(payload)

    def send_object(self, obj):
        """Send an object dumped into pickle, followed by its MD5 digest.

        Wire layout: the ``PKLOBJ`` message, the pickle size (int32), the
        pickle bytes, the digest size (int32) and the digest bytes.
        """
        # We can use the highest protocol since the
        # python requirement >= 3.8
        pkl_bytes = pickle.dumps(obj, protocol=5)
        nbytes = len(pkl_bytes)
        md5_checksum = hashlib.md5(pkl_bytes)
        checksum_digest = md5_checksum.digest()
        checksum_count = md5_checksum.digest_size
        self.sendmsg("PKLOBJ")  # To distinguish from other methods like INIT
        self.log(" pickle bytes to send: ", str(nbytes))
        self.send(nbytes, np.int32)
        self.log(" sending pickle object....")
        self.socket.sendall(pkl_bytes)
        self.log(" sending md5 sum of size: ", str(checksum_count))
        self.send(checksum_count, np.int32)
        self.log(" sending md5 sum..... ", str(checksum_count))
        self.socket.sendall(checksum_digest)

    def recv_object(self, include_header=True):
        """Receive an object sent by ``send_object`` and return it unpickled.

        Args:
            include_header: When true, first receive the message header and
                check that it is ``PKLOBJ``.

        Returns:
            The unpickled object.

        Raises:
            AssertionError: If the header is not ``PKLOBJ`` or the received
                MD5 digest does not equal the digest of the received bytes.
        """
        if include_header:
            msg = self.recvmsg()
            assert (
                msg.strip() == "PKLOBJ"
            ), f"Incorrect header {msg} received when calling recv_object method! Please contact the developers"
        nbytes = self._recv_int()
        self.log(" Will receive pickle object with n-bytes: ", nbytes)
        bytes_received = self._recvall(nbytes)
        checksum_nbytes = self._recv_int()
        self.log(" Will receive cheksum digest of nbytes:", checksum_nbytes)
        digest_received = self._recvall(checksum_nbytes)
        digest_calc = hashlib.md5(bytes_received).digest()
        # Compare the complete digests. Comparing only the common prefix
        # would let an empty or truncated checksum pass unconditionally.
        # Raised explicitly (not via ``assert``) so the integrity check is
        # still performed when Python runs with -O.
        if digest_calc != bytes(digest_received):
            raise AssertionError(
                "MD5 checksum for the received object does not match!"
            )
        # NOTE(security): unpickling executes arbitrary code chosen by the
        # sender, and the MD5 digest only detects corruption, not tampering.
        # Only use this over sockets whose peer is trusted.
        return pickle.loads(bytes_received)

    def send_param(self, name, value):
        """Send a specific param setting to SPARC.

        This is just a test function to see how things may work.
        The peer must report ``READY``; then ``SETPARAM`` is sent, followed
        by ``str(name)`` and ``str(value)`` via ``send_string``.

        TODO:
        1) test with just 2 string values to see if SPARC can receive
        """
        self.log(f"Setup param {name}, {value}")
        msg = self.status()
        assert msg == "READY", msg
        self.sendmsg("SETPARAM")
        self.send_string(str(name))
        self.send_string(str(value))
        # After this step, socket client should return READY

    def sendinit(self):
        """Send an INIT message announcing the new SPARC protocol.

        Mimics the old sendinit method, but the initialization string is
        ``NEWPROTO`` so that the receiver can expect atoms and params to
        follow. How the calculator is (re)-initialized depends on the
        receiving side.
        """
        self.log(" New sendinit for SPARC protocol")
        self.sendmsg("INIT")
        self.send(0, np.int32)  # fallback
        msg_chars = [ord(c) for c in "NEWPROTO"]
        self.send(len(msg_chars), np.int32)
        self.send(msg_chars, np.byte)  # initialization string

    def recvinit(self):
        """Fallback recvinit method (delegates to the parent class)."""
        return super().recvinit()

    def calculate_new_protocol(self, atoms, params):
        """Run one calculation step, sending atoms and params on NEEDINIT.

        Args:
            atoms: Atoms object; a copy with its calculator removed is used,
                so the caller's object is not modified.
            params: Parameters sent together with the atoms.

        Returns:
            A tuple ``(results, moredata)``: ``results`` is a dict with the
            keys ``energy``, ``forces``, ``virial`` and ``morebytes``;
            ``moredata`` is the extra object received afterwards.
        """
        atoms = atoms.copy()
        atoms.calc = None
        self.log(" calculate with new protocol")
        msg = self.status()
        # We don't know how NEEDINIT is supposed to work, but some codes
        # seem to be okay if we skip it and send the positions instead.
        if msg == "NEEDINIT":
            self.sendinit()
            self.send_object((atoms, params))
            msg = self.status()
        cell = atoms.get_cell()
        positions = atoms.get_positions()  # Original order
        assert msg == "READY", msg
        icell = np.linalg.pinv(cell).transpose()
        self.sendposdata(cell, icell, positions)
        msg = self.status()
        assert msg == "HAVEDATA", msg
        e, forces, virial, morebytes = self.sendrecv_force()
        r = dict(energy=e, forces=forces, virial=virial, morebytes=morebytes)
        # Additional data (e.g. parsed from file output)
        moredata = self.recv_object()
        return r, moredata
163# TODO: make sure both calc are ok
class SPARCSocketServer(SocketServer):
    """Socket server that talks to its client through ``SPARCProtocol``.

    We only implement the unix socket version due to simplicity

    parent: the SPARC parent calculator
    """

    def __init__(
        self,
        port=None,
        unixsocket=None,
        timeout=None,
        log=None,
        parent=None,
        # launch_client=None,
    ):
        """Create the server.

        Args:
            port, unixsocket, timeout, log: Forwarded unchanged to the
                parent class constructor.
            parent: The SPARC parent calculator; its ``process`` attribute
                is what the ``proc`` property reports.
        """
        # Set ``parent`` first: the ``proc`` property reads it, so it must
        # exist before any code in the base initializer touches ``proc``.
        self.parent = parent
        super().__init__(port=port, unixsocket=unixsocket, timeout=timeout, log=log)

    # TODO: guard cases for non-unix sockets
    @property
    def socket_filename(self):
        """Return ``getsockname()`` of the server socket."""
        return self.serversocket.getsockname()

    @property
    def proc(self):
        """Return ``parent.process``, or ``None`` when there is no parent."""
        parent = self.parent
        return parent.process if parent is not None else None

    @proc.setter
    def proc(self, value):
        # Assignments are deliberately ignored: the value is always derived
        # from ``self.parent`` (presumably so assignments made by the base
        # class do not override it -- confirm against ASE's SocketServer).
        return

    def _accept(self):
        """Accept a connection, then use the SPARCProtocol instead."""
        super()._accept()
        # Swap the protocol created by the base class for a SPARCProtocol
        # bound to the same client socket.
        if self.protocol:
            self.protocol = SPARCProtocol(self.clientsocket, txt=self.log)

    def send_atoms_and_params(self, atoms, params):
        """Update the atoms and parameters for the SPARC calculator

        The params should be assignable to SPARC.set

        The calc for atoms is stripped for simplicity. The stripping is done
        on a copy, so the caller's atoms object keeps its calculator.
        """
        atoms = atoms.copy()
        atoms.calc = None
        pair = (atoms, dict(params))
        self.protocol.send_object(pair)

    def calculate_origin_protocol(self, atoms):
        """Send geometry to client and return calculated things as dict.

        This will block until client has established connection, then
        wait for the client to finish the calculation."""
        assert not self._closed

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate(atoms.positions, atoms.cell)

    def calculate_new_protocol(self, atoms, params=None):
        """Run one calculation through ``protocol.calculate_new_protocol``.

        Args:
            atoms: Atoms object to calculate.
            params: Parameters passed along with the atoms; ``None`` (the
                default) is treated as an empty dict.

        Returns:
            Whatever ``self.protocol.calculate_new_protocol`` returns.
        """
        assert not self._closed
        # A fresh dict per call avoids sharing one mutable default object
        # between calls.
        if params is None:
            params = {}

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate_new_protocol(atoms, params)
class SPARCSocketClient(SocketClient):
    """Socket client that speaks ``SPARCProtocol`` and drives a calculator."""

    def __init__(
        self,
        host="localhost",
        port=None,
        unixsocket=None,
        timeout=None,
        log=None,
        parent_calc=None,
        # use_v2_protocol=True # If we should use the v2 SPARC protocol
    ):
        """Reload the socket client and use SPARCProtocol

        Args:
            host, port, unixsocket, timeout, log: Forwarded unchanged to the
                parent class constructor.
            parent_calc: The actual calculator; attached to atoms that have
                no calculator and used for the new-protocol steps.
        """
        super().__init__(
            host=host,
            port=port,
            unixsocket=unixsocket,
            timeout=timeout,
            log=log,
        )
        # Re-wrap the socket opened by the base class in a SPARCProtocol.
        sock = self.protocol.socket
        self.protocol = SPARCProtocol(sock, txt=log)
        self.parent_calc = parent_calc  # Track the actual calculator
        # TODO: make sure the client is compatible with the default socketclient

        # We shall make NEEDINIT to be the default state
        # self.state = "NEEDINIT"

    def calculate(self, atoms, use_stress):
        """Use the calculator instance

        Attaches ``self.parent_calc`` when ``atoms`` has no calculator, then
        delegates to the parent class ``calculate``.
        """
        if atoms.calc is None:
            atoms.calc = self.parent_calc
        return super().calculate(atoms, use_stress)

    def irun(self, atoms, use_stress=True):
        """Reimplement single step calculation

        We're free to implement the INIT method in socket protocol as most
        calculators do not involve using these. We can let the C-SPARC to spit out
        error about needinit error.

        This is a generator: it yields once after every calculation
        triggered by a POSDATA message and returns when EXIT is received or
        the server closes the socket. The client is closed when the
        generator finishes, whether normally or through an exception.

        Raises:
            KeyError: If an unknown message is received.
        """
        # Discard positions received from POSDATA
        # if the server has send positions through recvinit method
        discard_posdata = False
        new_protocol = False
        try:
            while True:
                try:
                    msg = self.protocol.recvmsg()
                except SocketClosed:
                    # Server closed the connection, but we want to
                    # exit gracefully anyway
                    msg = "EXIT"

                if msg == "EXIT":
                    # Send stop signal to clients:
                    self.comm.broadcast(np.ones(1, bool), 0)
                    # (When otherwise exiting, things crashed and we should
                    # let MPI_ABORT take care of the mess instead of trying
                    # to synchronize the exit)
                    return
                elif msg == "STATUS":
                    self.protocol.sendmsg(self.state)
                elif msg == "POSDATA":
                    assert self.state == "READY"
                    assert (
                        atoms is not None
                    ), "Your SPARCSocketClient isn't properly initialized!"
                    cell, icell, positions = self.protocol.recvposdata()
                    if not discard_posdata:
                        atoms.cell[:] = cell
                        atoms.positions[:] = positions

                    # At this stage, we should only rely on self.calculate
                    # to continue the socket calculation or restart
                    self.comm.broadcast(np.zeros(1, bool), 0)
                    energy, forces, virial = self.calculate(atoms, use_stress)

                    self.state = "HAVEDATA"
                    yield
                elif msg == "GETFORCE":
                    assert self.state == "HAVEDATA", self.state
                    self.protocol.sendforce(energy, forces, virial)
                    if new_protocol:
                        # TODO: implement more raw results
                        raw_results = self.parent_calc.raw_results
                        self.protocol.send_object(raw_results)
                    self.state = "NEEDINIT"
                elif msg == "INIT":
                    assert self.state == "NEEDINIT"
                    # Fall back to the default socketio
                    _bead_index, initbytes = self.protocol.recvinit()
                    # The parts below use the new sparc protocol
                    self.protocol.log(" Init bytes: ", initbytes)
                    init_msg = "".join(chr(d) for d in initbytes)
                    # Both flags follow the most recent INIT, so a plain
                    # INIT after a NEWPROTO one goes back to using POSDATA
                    # and stops sending the extra results object.
                    new_protocol = init_msg.startswith("NEWPROTO")
                    discard_posdata = new_protocol
                    if new_protocol:
                        recv_atoms, params = self.protocol.recv_object()
                        self.protocol.log(" Received: ", recv_atoms, params)
                        if params:
                            self.parent_calc.set(**params)
                        # TODO: should we update the atoms directly or keep copy?
                        atoms = recv_atoms
                        atoms.calc = self.parent_calc
                    self.state = "READY"
                else:
                    raise KeyError("Bad message", msg)
        finally:
            self.close()

    def run(self, atoms=None, use_stress=False):
        """Socket mode in SPARC should allow arbitrary start

        Drives ``irun`` to completion. When no atoms are given, the client
        starts in the ``NEEDINIT`` state.
        """
        # As a default we shall start the SPARCSocketIO always in needinit mode
        if atoms is None:
            self.state = "NEEDINIT"
        for _ in self.irun(atoms=atoms, use_stress=use_stress):
            pass