Source code for sparc.socketio

"""An i-PI compatible socket protocol implemented in SPARC
"""
import hashlib
import io
import os
import pickle
import random
import socket
import string

import numpy as np
from ase.calculators.socketio import (
    IPIProtocol,
    SocketClient,
    SocketClosed,
    SocketServer,
    actualunixsocketname,
)


def generate_random_socket_name(prefix="sparc_", length=6):
    """Generate a random socket name.

    Args:
        prefix: text placed in front of the random part.
        length: number of random lowercase hexadecimal characters appended.

    Returns:
        ``prefix`` followed by ``length`` characters drawn uniformly from
        ``0-9a-f``.
    """
    # Do not sample from ``string.hexdigits.lower()``: it contains "abcdef"
    # twice, which makes letters twice as likely as digits.
    alphabet = string.digits + "abcdef"
    random_chars = "".join(random.choices(alphabet, k=length))
    return prefix + random_chars
class SPARCProtocol(IPIProtocol):
    """Extending the i-PI protocol to support extra routines.

    In addition to the messages inherited from ``IPIProtocol``, this class
    exchanges length-prefixed ASCII strings and pickled Python objects
    (followed by an MD5 digest) over the same socket.
    """

    def send_string(self, msg, msglen=None):
        """Send an ASCII string preceded by its field length (int32).

        Args:
            msg: text to send; it is encoded as ASCII.
            msglen: size of the field on the wire. Defaults to ``len(msg)``.
                When larger than ``len(msg)`` the text is right-padded with
                spaces.
        """
        self.log(" send string", repr(msg))
        # assert msg in self.statements, msg
        if msglen is None:
            msglen = len(msg)
        assert msglen >= len(msg)
        payload = msg.encode("ascii").ljust(msglen)
        self.send(msglen, np.int32)
        self.socket.sendall(payload)

    def send_object(self, obj):
        """Send an object dumped into pickle, followed by its MD5 digest.

        Wire layout: ``PKLOBJ`` message, int32 payload size, payload bytes,
        int32 digest size, digest bytes.
        """
        # We can use the highest pickle protocol since the
        # python requirement is >= 3.8
        pkl_bytes = pickle.dumps(obj, protocol=5)
        nbytes = len(pkl_bytes)
        md5_checksum = hashlib.md5(pkl_bytes)
        checksum_digest = md5_checksum.digest()
        checksum_count = md5_checksum.digest_size
        self.sendmsg("PKLOBJ")  # To distinguish from other methods like INIT
        self.log(" pickle bytes to send: ", str(nbytes))
        self.send(nbytes, np.int32)
        self.log(" sending pickle object....")
        self.socket.sendall(pkl_bytes)
        self.log(" sending md5 sum of size: ", str(checksum_count))
        self.send(checksum_count, np.int32)
        self.log(" sending md5 sum..... ", md5_checksum.hexdigest())
        self.socket.sendall(checksum_digest)

    def recv_object(self, include_header=True):
        """Receive and decode an object sent by ``send_object``.

        Args:
            include_header: when true, first read a message and require it
                to be ``PKLOBJ``; when false, the caller has already
                consumed the header.

        Returns:
            The unpickled object.

        Raises:
            AssertionError: if the header is not ``PKLOBJ`` or the MD5
                digest of the payload does not match the received digest.
        """
        # The checks below are raised explicitly (not via ``assert``) so the
        # integrity verification is not stripped when running ``python -O``.
        if include_header:
            msg = self.recvmsg()
            if msg.strip() != "PKLOBJ":
                raise AssertionError(
                    f"Incorrect header {msg} received when calling recv_object method! Please contact the developers"
                )
        # ``.item()`` extracts the single value without relying on the
        # deprecated implicit array -> scalar conversion of ``int()``.
        nbytes = int(self.recv(1, np.int32).item())
        self.log(" Will receive pickle object with n-bytes: ", nbytes)
        bytes_received = self._recvall(nbytes)
        checksum_nbytes = int(self.recv(1, np.int32).item())
        self.log(" Will receive checksum digest of nbytes:", checksum_nbytes)
        digest_received = self._recvall(checksum_nbytes)
        digest_calc = hashlib.md5(bytes_received).digest()
        # Compare the complete digests: a prefix comparison would accept an
        # empty or truncated digest and make the check vacuous.
        if digest_received != digest_calc:
            raise AssertionError(
                "MD5 checksum for the received object does not match!"
            )
        # NOTE(security): unpickling can execute arbitrary code, and MD5 only
        # guards against accidental corruption. Only use this with a trusted
        # peer.
        return pickle.loads(bytes_received)

    def send_param(self, name, value):
        """Send a specific param setting to SPARC.

        This is just a test function to see how things may work.

        TODO:
        1) test with just 2 string values to see if SPARC can receive
        """
        self.log(f"Setup param {name}, {value}")
        msg = self.status()
        assert msg == "READY", msg
        # Announce the operation, then send the name and value as strings
        self.sendmsg("SETPARAM")
        self.send_string(str(name))
        self.send_string(str(value))
        # After this step, socket client should return READY

    def sendinit(self):
        """Mimic the old sendinit method but announce the new protocol.

        Sends ``INIT``, a zero int32, and the initialization string
        ``NEWPROTO`` (length as int32, then one byte per character). The
        atoms and params themselves are sent separately by the caller; how
        the calculator is (re)-initialized depends on the receiving side.
        """
        self.log(" New sendinit for SPARC protocol")
        self.sendmsg("INIT")
        self.send(0, np.int32)  # fallback
        msg_chars = [ord(c) for c in "NEWPROTO"]
        self.send(len(msg_chars), np.int32)
        self.send(msg_chars, np.byte)  # initialization string

    def recvinit(self):
        """Fallback recvinit method (delegates to the parent class)."""
        return super().recvinit()

    def calculate_new_protocol(self, atoms, params):
        """Run one calculation, sending atoms and params if the client needs init.

        Args:
            atoms: structure to calculate; a calculator-free copy is used,
                so the caller's object is not modified.
            params: parameters sent alongside the atoms when the client
                reports ``NEEDINIT``.

        Returns:
            ``(results, moredata)`` where ``results`` is a dict with keys
            ``energy``, ``forces``, ``virial`` and ``morebytes``, and
            ``moredata`` is the extra object received from the client.
        """
        atoms = atoms.copy()
        atoms.calc = None
        self.log(" calculate with new protocol")
        msg = self.status()
        # We don't know how NEEDINIT is supposed to work, but some codes
        # seem to be okay if we skip it and send the positions instead.
        if msg == "NEEDINIT":
            self.sendinit()
            self.send_object((atoms, params))
            msg = self.status()
        cell = atoms.get_cell()
        positions = atoms.get_positions()  # Original order
        assert msg == "READY", msg
        icell = np.linalg.pinv(cell).transpose()
        self.sendposdata(cell, icell, positions)
        msg = self.status()
        assert msg == "HAVEDATA", msg
        e, forces, virial, morebytes = self.sendrecv_force()
        r = dict(energy=e, forces=forces, virial=virial, morebytes=morebytes)
        # Additional data (e.g. parsed from file output)
        # NOTE(review): this blocks until the client sends a PKLOBJ payload;
        # confirm the client always does so after GETFORCE on this path.
        moredata = self.recv_object()
        return r, moredata
# TODO: make sure both calc are ok
class SPARCSocketServer(SocketServer):
    """Socket server that speaks ``SPARCProtocol`` with its client.

    We only implement the unix socket version due to simplicity.

    parent: the SPARC parent calculator
    """

    def __init__(
        self, port=None, unixsocket=None, timeout=None, log=None, parent=None
    ):
        """Create the server socket and remember the parent calculator.

        Args:
            port, unixsocket, timeout, log: forwarded unchanged to
                ``SocketServer.__init__``.
            parent: calculator owning this server; its ``process``
                attribute is exposed through ``self.proc``.
        """
        super().__init__(port=port, unixsocket=unixsocket, timeout=timeout, log=log)
        self.parent = parent

    # TODO: guard cases for non-unix sockets
    @property
    def socket_filename(self):
        """Name the server socket is bound to (``getsockname()``)."""
        return self.serversocket.getsockname()

    @property
    def proc(self):
        """Process of the parent calculator, or ``None`` without a parent."""
        if self.parent is not None:
            return self.parent.process
        return None

    @proc.setter
    def proc(self, value):
        # Deliberately a no-op: the value is always derived from
        # ``self.parent``, so any assignment to ``proc`` is ignored.
        return

    def _accept(self):
        """Accept a client, then use SPARCProtocol instead of the default."""
        super()._accept()
        # Swap the protocol created by the parent class, if any
        if self.protocol is not None:
            self.protocol = SPARCProtocol(self.clientsocket, txt=self.log)

    def send_atoms_and_params(self, atoms, params):
        """Update the atoms and parameters for the SPARC calculator.

        The params should be assignable to SPARC.set.
        The atoms are sent without a calculator; a copy is used so the
        caller's ``atoms`` keeps its own calculator.
        """
        atoms = atoms.copy()
        atoms.calc = None
        pair = (atoms, dict(params))
        self.protocol.send_object(pair)

    def calculate_origin_protocol(self, atoms):
        """Send geometry to client and return calculated things as dict.

        This will block until client has established connection, then
        wait for the client to finish the calculation."""
        assert not self._closed

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate(atoms.positions, atoms.cell)

    def calculate_new_protocol(self, atoms, params=None):
        """Run a calculation through ``SPARCProtocol.calculate_new_protocol``.

        Args:
            atoms: structure to calculate.
            params: parameters forwarded to the client; ``None`` (the
                default) is treated as an empty dict.

        Returns:
            Whatever ``SPARCProtocol.calculate_new_protocol`` returns.
        """
        assert not self._closed
        # A fresh dict per call avoids sharing one mutable default object
        if params is None:
            params = {}

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate_new_protocol(atoms, params)
class SPARCSocketClient(SocketClient):
    """Socket client that speaks ``SPARCProtocol`` and drives a calculator."""

    def __init__(
        self,
        host="localhost",
        port=None,
        unixsocket=None,
        timeout=None,
        log=None,
        parent_calc=None,
    ):
        """Reload the socket client and use SPARCProtocol.

        Args:
            host, port, unixsocket, timeout, log: forwarded unchanged to
                ``SocketClient.__init__``.
            parent_calc: calculator attached to atoms that have none and
                updated with the params received through ``INIT``.
        """
        super().__init__(
            host=host,
            port=port,
            unixsocket=unixsocket,
            timeout=timeout,
            log=log,
        )
        # Reuse the socket opened by the parent class with our protocol
        sock = self.protocol.socket
        self.protocol = SPARCProtocol(sock, txt=log)
        self.parent_calc = parent_calc  # Track the actual calculator
        # TODO: make sure the client is compatible with the default socketclient

        # We shall make NEEDINIT to be the default state
        # self.state = "NEEDINIT"

    def calculate(self, atoms, use_stress):
        """Use the calculator instance.

        Atoms without a calculator get ``self.parent_calc`` attached before
        the parent-class ``calculate`` is called.
        """
        if atoms.calc is None:
            atoms.calc = self.parent_calc
        return super().calculate(atoms, use_stress)

    def irun(self, atoms, use_stress=True):
        """Reimplement single step calculation.

        Generator that serves messages from the server until ``EXIT`` (or a
        closed socket) and yields once after each completed calculation.
        The socket is closed when the generator finishes.

        We're free to implement the INIT method in socket protocol as most
        calculators do not involve using these. We can let the C-SPARC to spit
        out error about needinit error.

        Args:
            atoms: structure updated in place from ``POSDATA``; may be
                ``None`` if the server provides atoms through ``INIT``.
            use_stress: forwarded to ``self.calculate``.

        Raises:
            KeyError: on an unknown message.
        """
        # Discard positions received from POSDATA
        # if the server has sent positions through the recvinit method
        discard_posdata = False
        new_protocol = False
        try:
            while True:
                try:
                    msg = self.protocol.recvmsg()
                except SocketClosed:
                    # Server closed the connection, but we want to
                    # exit gracefully anyway
                    msg = "EXIT"

                if msg == "EXIT":
                    # Send stop signal to clients:
                    self.comm.broadcast(np.ones(1, bool), 0)
                    # (When otherwise exiting, things crashed and we should
                    # let MPI_ABORT take care of the mess instead of trying
                    # to synchronize the exit)
                    return
                elif msg == "STATUS":
                    self.protocol.sendmsg(self.state)
                elif msg == "POSDATA":
                    assert self.state == "READY"
                    assert (
                        atoms is not None
                    ), "Your SPARCSocketClient isn't properly initialized!"
                    cell, icell, positions = self.protocol.recvposdata()
                    if not discard_posdata:
                        atoms.cell[:] = cell
                        atoms.positions[:] = positions

                    # At this stage, we should only rely on self.calculate
                    # to continue the socket calculation or restart
                    self.comm.broadcast(np.zeros(1, bool), 0)
                    energy, forces, virial = self.calculate(atoms, use_stress)

                    self.state = "HAVEDATA"
                    yield
                elif msg == "GETFORCE":
                    assert self.state == "HAVEDATA", self.state
                    self.protocol.sendforce(energy, forces, virial)
                    if new_protocol:
                        # TODO: implement more raw results
                        raw_results = self.parent_calc.raw_results
                        self.protocol.send_object(raw_results)
                    self.state = "NEEDINIT"
                elif msg == "INIT":
                    assert self.state == "NEEDINIT"
                    # Fall back to the default socketio
                    bead_index, initbytes = self.protocol.recvinit()
                    # The parts below use the new sparc protocol
                    self.protocol.log(" Init bytes: ", initbytes)
                    init_msg = "".join(chr(d) for d in initbytes)
                    if init_msg.startswith("NEWPROTO"):
                        new_protocol = True
                        recv_atoms, params = self.protocol.recv_object()
                        self.protocol.log(" Received atoms and params: ", recv_atoms, params)
                        # Nothing to apply for an empty (or missing) params
                        if params:
                            self.parent_calc.set(**params)
                        # TODO: should we update the atoms directly or keep copy?
                        atoms = recv_atoms
                        atoms.calc = self.parent_calc
                        # The received atoms are authoritative, so the next
                        # POSDATA payload must not overwrite them
                        discard_posdata = True
                    self.state = "READY"
                else:
                    raise KeyError("Bad message", msg)
        finally:
            self.close()

    def run(self, atoms=None, use_stress=False):
        """Socket mode in SPARC should allow arbitrary start.

        Drives ``irun`` to completion. Without ``atoms`` the client starts
        in the ``NEEDINIT`` state so that the server must send them.
        """
        # As a default we shall start the SPARCSocketIO always in needinit mode
        if atoms is None:
            self.state = "NEEDINIT"
        for _ in self.irun(atoms=atoms, use_stress=use_stress):
            pass