Coverage for sparc/socketio.py: 23%
199 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-12 01:13 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-12 01:13 +0000
1"""An i-PI compatible socket protocol implemented in SPARC
2"""
3import hashlib
4import io
5import os
6import pickle
7import random
8import socket
9import string
11import numpy as np
12from ase.calculators.socketio import (
13 IPIProtocol,
14 SocketClient,
15 SocketClosed,
16 SocketServer,
17 actualunixsocketname,
18)
def generate_random_socket_name(prefix="sparc_", length=6):
    """Return ``prefix`` followed by ``length`` random lowercase hex characters.

    prefix: string placed in front of the random part
    length: number of random hexadecimal characters to append
    """
    # ``string.hexdigits.lower()`` evaluates to "0123456789abcdefabcdef":
    # the letters a-f appear twice, so sampling from it picks letters twice
    # as often as digits. Use a de-duplicated alphabet for a uniform draw.
    hex_alphabet = "0123456789abcdef"
    random_chars = "".join(random.choices(hex_alphabet, k=length))
    return prefix + random_chars
class SPARCProtocol(IPIProtocol):
    """Extending the i-PI protocol to support extra routines

    On top of the inherited messages this class can exchange
    length-prefixed ASCII strings and pickled Python objects (guarded by an
    MD5 digest), and provides an INIT variant tagged with "NEWPROTO".
    """

    def _recv_int32(self):
        """Receive one int32 value and return it as a Python ``int``."""
        # ``int()`` applied directly to an array with ndim > 0 is deprecated
        # since NumPy 1.25, so pull the scalar out explicitly.
        return int(np.asarray(self.recv(1, np.int32)).item())

    def send_string(self, msg, msglen=None):
        """Send ``msg`` as a length-prefixed ASCII string.

        msg: text to send; it is encoded as ASCII
        msglen: number of bytes put on the wire; defaults to ``len(msg)``.
            When larger than the message, the encoded bytes are
            right-padded with spaces.
        """
        self.log(" send string", repr(msg))
        if msglen is None:
            msglen = len(msg)
        assert msglen >= len(msg)
        msg = msg.encode("ascii").ljust(msglen)
        self.send(msglen, np.int32)
        self.socket.sendall(msg)
        return

    def send_object(self, obj):
        """Send an object dumped into pickle

        Wire order: the "PKLOBJ" message, the payload size (int32), the
        pickled payload, the digest size (int32) and the MD5 digest.

        Raises ValueError if the pickled payload does not fit in an int32.
        """
        # We can use the highest protocol since the
        # python requirement >= 3.8
        pkl_bytes = pickle.dumps(obj, protocol=5)
        nbytes = len(pkl_bytes)
        # The size travels as an int32; refuse oversized payloads *before*
        # anything is written so the two ends cannot get out of sync.
        if nbytes > np.iinfo(np.int32).max:
            raise ValueError(
                f"Pickled object of {nbytes} bytes is too large to be sent "
                "with an int32 length prefix"
            )
        md5_checksum = hashlib.md5(pkl_bytes)
        checksum_digest, checksum_count = (
            md5_checksum.digest(),
            md5_checksum.digest_size,
        )
        self.sendmsg("PKLOBJ")  # To distinguish from other methods like INIT
        self.log(" pickle bytes to send: ", str(nbytes))
        self.send(nbytes, np.int32)
        self.log(" sending pickle object....")
        self.socket.sendall(pkl_bytes)
        self.log(" sending md5 sum of size: ", str(checksum_count))
        self.send(checksum_count, np.int32)
        self.log(" sending md5 sum..... ", str(checksum_count))
        self.socket.sendall(checksum_digest)
        return

    def recv_object(self, include_header=True):
        """Receive a pickled object sent by ``send_object`` and return it

        include_header: should we receive the "PKLOBJ" header or not

        Raises AssertionError when the header is not "PKLOBJ" or when the
        MD5 digest of the received payload does not match the digest sent
        by the peer.
        """
        # The checks below raise explicitly (instead of using ``assert``)
        # so that they are still enforced when Python runs with -O.
        if include_header:
            msg = self.recvmsg()
            if msg.strip() != "PKLOBJ":
                raise AssertionError(
                    f"Incorrect header {msg} received when calling recv_object method! Please contact the developers"
                )
        nbytes = self._recv_int32()
        self.log(" Will receive pickle object with n-bytes: ", nbytes)
        bytes_received = self._recvall(nbytes)
        checksum_nbytes = self._recv_int32()
        self.log(" Will receive cheksum digest of nbytes:", checksum_nbytes)
        digest_received = self._recvall(checksum_nbytes)
        digest_calc = hashlib.md5(bytes_received).digest()
        # Compare the complete digests: comparing only a common prefix would
        # accept an empty or truncated digest as a match.
        if bytes(digest_received) != digest_calc:
            raise AssertionError(
                "MD5 checksum for the received object does not match!"
            )
        # SECURITY NOTE: unpickling can execute arbitrary code, and MD5 only
        # detects accidental corruption. Use this only with a trusted peer.
        obj = pickle.loads(bytes_received)
        return obj

    def send_param(self, name, value):
        """Send a specific param setting to SPARC
        This is just a test function to see how things may work

        name, value: both are converted with ``str`` and sent one after the
        other following a "SETPARAM" message. The peer must report the
        "READY" status beforehand.

        TODO:
        1) test with just 2 string values to see if SPARC can receive
        """
        self.log(f"Setup param {name}, {value}")
        msg = self.status()
        assert msg == "READY", msg
        # Send message
        self.sendmsg("SETPARAM")
        # Send name
        self.send_string(str(name))
        # Send value
        self.send_string(str(value))
        # After this step, socket client should return READY
        return

    def sendinit(self):
        """Mimic the old sendinit method but to provide atoms and params
        to the calculator instance.
        The actual behavior regarding how the calculator would be
        (re)-initialized depends on the implementation of recvinit
        """
        self.log(" New sendinit for SPARC protocol")
        self.sendmsg("INIT")
        self.send(0, np.int32)  # fallback
        # The initialization string "NEWPROTO" marks the extended protocol
        msg_chars = [ord(c) for c in "NEWPROTO"]
        len_msg = len(msg_chars)
        self.send(len_msg, np.int32)
        self.send(msg_chars, np.byte)  # initialization string
        return

    def recvinit(self):
        """Fallback recvinit method"""
        return super().recvinit()

    def calculate_new_protocol(self, atoms, params):
        """Run one calculation step through the extended protocol.

        atoms: structure to compute; a copy without calculator is used, the
            caller's object is left untouched
        params: parameters sent along with the atoms when the peer reports
            "NEEDINIT"

        Returns a tuple ``(r, moredata)`` where ``r`` is a dict with the
        keys energy, forces, virial and morebytes, and ``moredata`` is the
        extra object received from the peer afterwards.
        """
        atoms = atoms.copy()
        atoms.calc = None
        self.log(" calculate with new protocol")
        msg = self.status()
        # We don't know how NEEDINIT is supposed to work, but some codes
        # seem to be okay if we skip it and send the positions instead.
        if msg == "NEEDINIT":
            self.sendinit()
            self.send_object((atoms, params))
            msg = self.status()
        cell = atoms.get_cell()
        positions = atoms.get_positions()  # Original order
        assert msg == "READY", msg
        icell = np.linalg.pinv(cell).transpose()
        self.sendposdata(cell, icell, positions)
        msg = self.status()
        assert msg == "HAVEDATA", msg
        e, forces, virial, morebytes = self.sendrecv_force()
        r = dict(energy=e, forces=forces, virial=virial, morebytes=morebytes)
        # Additional data (e.g. parsed from file output)
        moredata = self.recv_object()
        return r, moredata
149# TODO: make sure both calc are ok
class SPARCSocketServer(SocketServer):
    """We only implement the unix socket version due to simplicity

    parent: the SPARC parent calculator
    """

    def __init__(
        self,
        port=None,
        unixsocket=None,
        timeout=None,
        log=None,
        parent=None,
    ):
        """Create the server and remember the parent calculator.

        port, unixsocket, timeout, log: forwarded unchanged to the base
            class constructor
        parent: the SPARC parent calculator; its ``process`` attribute is
            exposed through the read-only ``proc`` property
        """
        super().__init__(port=port, unixsocket=unixsocket, timeout=timeout, log=log)
        self.parent = parent
        # No assignment to ``self.proc`` here: the property setter ignores
        # the value and the getter always derives it from ``self.parent``.

    # TODO: guard cases for non-unix sockets
    @property
    def socket_filename(self):
        """Name the server socket is bound to (``getsockname()``)."""
        return self.serversocket.getsockname()

    @property
    def proc(self):
        """Process of the parent calculator, or None without a parent."""
        # ``getattr`` keeps the property usable even if it is read before
        # ``__init__`` has stored the parent.
        parent = getattr(self, "parent", None)
        if parent is not None:
            return parent.process
        return None

    @proc.setter
    def proc(self, value):
        # Deliberately a no-op: assignments to ``proc`` are discarded so
        # that the value always follows the parent calculator.
        return

    def _accept(self):
        """Use the SPARCProtocol instead"""
        super()._accept()
        # Swap the protocol created by the base class (if any) for a
        # SPARCProtocol talking over the same client socket.
        if self.protocol:
            self.protocol = SPARCProtocol(self.clientsocket, txt=self.log)
        return

    def send_atoms_and_params(self, atoms, params):
        """Update the atoms and parameters for the SPARC calculator
        The params should be assignable to SPARC.set

        The calc is stripped from a copy of the atoms, so the caller's
        object keeps its calculator.
        """
        # Work on a copy, as SPARCProtocol.calculate_new_protocol does;
        # clearing ``.calc`` on the caller's object would silently detach
        # its calculator.
        atoms = atoms.copy()
        atoms.calc = None
        params = dict(params)
        pair = (atoms, params)
        self.protocol.send_object(pair)
        return

    def calculate_origin_protocol(self, atoms):
        """Send geometry to client and return calculated things as dict.

        This will block until client has established connection, then
        wait for the client to finish the calculation."""
        assert not self._closed

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate(atoms.positions, atoms.cell)

    def calculate_new_protocol(self, atoms, params=None):
        """Send atoms and params to the client via the extended protocol.

        atoms: structure to compute
        params: calculator parameters; ``None`` (default) means no
            parameters, i.e. an empty dict is sent

        Returns whatever ``SPARCProtocol.calculate_new_protocol`` returns.
        """
        assert not self._closed
        # Avoid a shared mutable default argument
        if params is None:
            params = {}

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate_new_protocol(atoms, params)
class SPARCSocketClient(SocketClient):
    """Socket client that talks SPARCProtocol and drives a parent calculator."""

    def __init__(
        self,
        host="localhost",
        port=None,
        unixsocket=None,
        timeout=None,
        log=None,
        parent_calc=None,
    ):
        """Reload the socket client and use SPARCProtocol

        host, port, unixsocket, timeout, log: forwarded unchanged to the
            base class constructor
        parent_calc: calculator attached to atoms that carry none, and
            updated with the params received through the new protocol
        """
        super().__init__(
            host=host,
            port=port,
            unixsocket=unixsocket,
            timeout=timeout,
            log=log,
        )
        # Re-wrap the socket opened by the base class in a SPARCProtocol
        sock = self.protocol.socket
        self.protocol = SPARCProtocol(sock, txt=log)
        self.parent_calc = parent_calc  # Track the actual calculator
        # TODO: make sure the client is compatible with the default socketclient

        # We shall make NEEDINIT to be the default state
        # self.state = "NEEDINIT"

    def calculate(self, atoms, use_stress):
        """Use the calculator instance

        Atoms without a calculator get ``self.parent_calc`` attached before
        the base class ``calculate`` is called.
        """
        if atoms.calc is None:
            atoms.calc = self.parent_calc
        return super().calculate(atoms, use_stress)

    def irun(self, atoms, use_stress=True):
        """Reimplement single step calculation

        We're free to implement the INIT method in socket protocol as most
        calculators do not involve using these. We can let the C-SPARC to spit out
        error about needinit error.

        This is a generator: it yields once after every calculation
        triggered by a POSDATA message and returns on EXIT (or when the
        socket is closed by the server). The client is closed on exit.

        Raises KeyError for an unknown message.
        """
        # Discard positions received from POSDATA
        # if the server has send positions through recvinit method
        discard_posdata = False
        new_protocol = False
        try:
            while True:
                try:
                    msg = self.protocol.recvmsg()
                except SocketClosed:
                    # Server closed the connection, but we want to
                    # exit gracefully anyway
                    msg = "EXIT"

                if msg == "EXIT":
                    # Send stop signal to clients:
                    self.comm.broadcast(np.ones(1, bool), 0)
                    # (When otherwise exiting, things crashed and we should
                    # let MPI_ABORT take care of the mess instead of trying
                    # to synchronize the exit)
                    return
                elif msg == "STATUS":
                    self.protocol.sendmsg(self.state)
                elif msg == "POSDATA":
                    assert self.state == "READY"
                    assert (
                        atoms is not None
                    ), "Your SPARCSocketClient isn't properly initialized!"
                    cell, icell, positions = self.protocol.recvposdata()
                    if not discard_posdata:
                        atoms.cell[:] = cell
                        atoms.positions[:] = positions

                    # At this stage, we should only rely on self.calculate
                    # to continue the socket calculation or restart
                    self.comm.broadcast(np.zeros(1, bool), 0)
                    energy, forces, virial = self.calculate(atoms, use_stress)

                    self.state = "HAVEDATA"
                    yield
                elif msg == "GETFORCE":
                    assert self.state == "HAVEDATA", self.state
                    self.protocol.sendforce(energy, forces, virial)
                    if new_protocol:
                        # TODO: implement more raw results
                        raw_results = self.parent_calc.raw_results
                        self.protocol.send_object(raw_results)
                    self.state = "NEEDINIT"
                elif msg == "INIT":
                    assert self.state == "NEEDINIT"
                    # Fall back to the default socketio
                    bead_index, initbytes = self.protocol.recvinit()
                    # The parts below use the new sparc protocol.
                    # Diagnostics go through the protocol logger rather
                    # than being printed to stdout.
                    self.protocol.log(" Init bytes: ", initbytes)
                    init_msg = "".join([chr(d) for d in initbytes])
                    if init_msg.startswith("NEWPROTO"):
                        new_protocol = True
                        recv_atoms, params = self.protocol.recv_object()
                        self.protocol.log(" Received atoms and params: ", recv_atoms, params)
                        if params != {}:
                            self.parent_calc.set(**params)
                        # TODO: should we update the atoms directly or keep copy?
                        atoms = recv_atoms
                        atoms.calc = self.parent_calc
                        # The atoms just received already carry the geometry
                        discard_posdata = True
                    self.state = "READY"
                else:
                    raise KeyError("Bad message", msg)
        finally:
            self.close()

    def run(self, atoms=None, use_stress=False):
        """Socket mode in SPARC should allow arbitrary start

        Drives ``irun`` to completion. Without atoms the client starts in
        the "NEEDINIT" state.
        """
        # As a default we shall start the SPARCSocketIO always in needinit mode
        if atoms is None:
            self.state = "NEEDINIT"
        for _ in self.irun(atoms=atoms, use_stress=use_stress):
            pass