You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
274 lines
9.8 KiB
Python
274 lines
9.8 KiB
Python
1 year ago
|
import asyncio
|
||
|
from loguru import logger
|
||
|
|
||
|
# --- Handshake -------------------------------------------------------------
# Expected value of the first greeting byte read in Client.connect().
GREETINGS = 0xEB

# Expected value of greeting bytes 1-2, decoded as a little-endian integer.
PROTOCOL_VERSION = 1

# --- Frame operation codes ---------------------------------------------------
OP_NOP = 0
OP_PUBLISH = 1
OP_SUBSCRIBE = 2
OP_UNSUBSCRIBE = 3
OP_MESSAGE = 0x12
OP_BROADCAST = 0x13
OP_ACK = 0xFE

# --- Result codes ------------------------------------------------------------
# RESPONSE_OK is what the handshake expects after each step; the ERR_* values
# are presumably the failure codes carried in ack frames (confirm against the
# server documentation).
RESPONSE_OK = 0x01
ERR_CLIENT_NOT_REGISTERED = 0x71
ERR_DATA = 0x72
ERR_IO = 0x73
ERR_OTHER = 0x74
ERR_NOT_SUPPORTED = 0x75
ERR_BUSY = 0x76
ERR_NOT_DELIVERED = 0x77
ERR_TIMEOUT = 0x78

# Keep-alive frame written periodically by the pinger task: nine zero bytes,
# the same size as the header produced by Client.send() (4-byte id, 1-byte
# flags, 4-byte length).
PING_FRAME = bytes(9)
|
||
|
|
||
|
|
||
|
async def on_frame_default(frame):
    """Default incoming-frame callback: ignore the frame.

    Installed as ``Client.on_frame`` until the user assigns a handler.
    """
    return None
|
||
|
|
||
|
|
||
|
class Client:
    """Asynchronous BUS/RT client over a UNIX socket or a TCP connection.

    After ``connect()`` two background tasks run: a pinger that periodically
    writes ``PING_FRAME`` and a reader that dispatches acknowledgements and
    incoming frames.  Incoming frames are passed to the ``on_frame``
    coroutine function, each in its own task.

    Attributes:
        path: Socket path, or ``host:port`` for TCP.
        name: Client name sent to the server during the handshake.
        buf_size: Stream reader limit and maximum size of a single read.
        ping_interval: Seconds between keep-alive frames.
        timeout: Seconds allowed for each handshake / ping I/O step.
        on_frame: Coroutine function called with every incoming frame.
        frames: Frames awaiting an acknowledgement, keyed by frame id.
    """

    def __init__(self, path, name):
        self.path = path
        self.writer = None
        self.reader_fut = None
        self.pinger_fut = None
        self.buf_size = 8192
        self.name = name
        self.frame_id = 0
        self.ping_interval = 1
        self.on_frame = on_frame_default
        # socket_lock serializes writes; mgmt_lock serializes
        # connect / disconnect / failure handling.
        self.socket_lock = asyncio.Lock()
        self.mgmt_lock = asyncio.Lock()
        self.connected = False
        self.frames = {}
        self.timeout = 1

    async def connect(self):
        """Open the connection, perform the handshake and start the tasks.

        A path ending in ``.sock``, ``.socket`` or ``.ipc``, or starting
        with ``/``, is treated as a UNIX socket; anything else is parsed as
        ``host:port``.

        Raises:
            ValueError: UNIX sockets are not available on this platform, or
                the TCP path cannot be split into host and port.
            RuntimeError: The server greeting, protocol version or a
                handshake response is not the expected one.
            asyncio.TimeoutError: A handshake step exceeded ``timeout``.
        """
        async with self.mgmt_lock:
            is_unix = (self.path.endswith(('.sock', '.socket', '.ipc')) or
                       self.path.startswith('/'))
            if is_unix:
                if not hasattr(asyncio, 'open_unix_connection'):
                    raise ValueError('only support unix like system')
                reader, writer = await asyncio.open_unix_connection(
                    self.path, limit=self.buf_size)
            else:
                # Split on the LAST colon only, so that hosts which contain
                # colons themselves (e.g. IPv6 literals) still parse.
                host, port = self.path.rsplit(':', maxsplit=1)
                reader, writer = await asyncio.open_connection(
                    host, int(port), limit=self.buf_size)
            try:
                await self._handshake(reader, writer)
            except BaseException:
                # Do not leak the stream when the handshake fails.
                writer.close()
                raise
            self.writer = writer
            self.connected = True
            self.pinger_fut = asyncio.ensure_future(self._t_pinger())
            self.reader_fut = asyncio.ensure_future(self._t_reader(reader))

    async def _handshake(self, reader, writer):
        """Exchange the greeting and register the client name."""
        buf = await asyncio.wait_for(self.readexactly(reader, 3),
                                     timeout=self.timeout)
        if buf[0] != GREETINGS:
            raise RuntimeError('Unsupported protocol')
        if int.from_bytes(buf[1:3], 'little') != PROTOCOL_VERSION:
            raise RuntimeError('Unsupported protocol version')
        # The greeting is echoed back to the server unchanged.
        writer.write(buf)
        await asyncio.wait_for(writer.drain(), timeout=self.timeout)
        buf = await asyncio.wait_for(self.readexactly(reader, 1),
                                     timeout=self.timeout)
        if buf[0] != RESPONSE_OK:
            raise RuntimeError(f'Server response: {hex(buf[0])}')
        # Client name: 2-byte little-endian length followed by the bytes.
        name = self.name.encode()
        writer.write(len(name).to_bytes(2, 'little') + name)
        await asyncio.wait_for(writer.drain(), timeout=self.timeout)
        buf = await asyncio.wait_for(self.readexactly(reader, 1),
                                     timeout=self.timeout)
        if buf[0] != RESPONSE_OK:
            raise RuntimeError(f'Server response: {hex(buf[0])}')

    async def handle_daemon_exception(self, e):
        """Disconnect after a background task failed with ``e`` and log it.

        Scheduled as a separate task by the pinger and the reader, so the
        failed task can be cancelled by ``_disconnect()`` without
        cancelling this handler.
        """
        async with self.mgmt_lock:
            if self.connected:
                await self._disconnect()
                import traceback
                # This coroutine runs outside of the ``except`` block which
                # caught ``e``, so format the exception object explicitly;
                # ``traceback.format_exc()`` would find no active exception.
                logger.error(''.join(
                    traceback.format_exception(type(e), e, e.__traceback__)))

    async def _t_pinger(self):
        """Background task: write ``PING_FRAME`` every ``ping_interval``."""
        try:
            while True:
                await asyncio.sleep(self.ping_interval)
                async with self.socket_lock:
                    self.writer.write(PING_FRAME)
                    await asyncio.wait_for(self.writer.drain(),
                                           timeout=self.timeout)
        except Exception as e:
            asyncio.ensure_future(self.handle_daemon_exception(e))

    async def _t_reader(self, reader):
        """Background task: read frame headers and dispatch them.

        Every incoming header is 6 bytes; byte 0 is the operation code.
        ``OP_NOP`` is skipped, ``OP_ACK`` completes the pending frame whose
        id is in bytes 1-4 with the result in byte 5, and anything else is
        read as a data frame whose length is in bytes 1-4.
        """
        try:
            while True:
                header = await self.readexactly(reader, 6)
                op = header[0]
                if op == OP_NOP:
                    continue
                if op == OP_ACK:
                    op_id = int.from_bytes(header[1:5], 'little')
                    try:
                        pending = self.frames.pop(op_id)
                    except KeyError:
                        logger.warning(f'orphaned BUS/RT frame ack {op_id}')
                    else:
                        pending.result = header[5]
                        pending.completed.set()
                    continue
                try:
                    data_len = int.from_bytes(header[1:5], 'little')
                    frame = await self._read_frame(reader, op, data_len)
                except Exception as e:
                    logger.error(f'Invalid frame from the server: {e}')
                    raise
                asyncio.ensure_future(self.on_frame(frame))
        except Exception as e:
            asyncio.ensure_future(self.handle_daemon_exception(e))

    async def _read_frame(self, reader, tp, data_len):
        """Read the body of a data frame of type ``tp`` into a ``Frame``.

        ``data_len`` covers the zero-terminated sender, the zero-terminated
        topic (``OP_PUBLISH`` frames only) and the payload.
        """
        frame = Frame()
        frame.type = tp
        sender = await reader.readuntil(b'\x00')
        data_len -= len(sender)
        frame.sender = sender[:-1].decode()
        if tp == OP_PUBLISH:
            topic = await reader.readuntil(b'\x00')
            data_len -= len(topic)
            frame.topic = topic[:-1].decode()
        else:
            frame.topic = None
        frame.payload = await self.readexactly(reader, data_len)
        return frame

    async def readexactly(self, reader, data_len):
        """Read exactly ``data_len`` bytes in chunks of up to ``buf_size``.

        Returns ``b''`` when ``data_len`` is zero or negative.

        Raises:
            asyncio.IncompleteReadError: The stream ended first.  The error
                carries everything read so far as ``partial`` and
                ``data_len`` as ``expected``.  The stream reader raises
                this only at end of stream, so retrying could never
                succeed.
        """
        chunks = []
        remaining = data_len
        while remaining > 0:
            try:
                chunk = await reader.readexactly(
                    min(remaining, self.buf_size))
            except asyncio.IncompleteReadError as e:
                raise asyncio.IncompleteReadError(
                    b''.join(chunks) + e.partial, data_len) from e
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    async def disconnect(self):
        """Close the connection and stop the background tasks."""
        async with self.mgmt_lock:
            await self._disconnect()

    async def _disconnect(self):
        """Disconnect without taking ``mgmt_lock`` (the caller holds it)."""
        self.connected = False
        # The writer is None if connect() never succeeded.
        if self.writer is not None:
            self.writer.close()
        for task in (self.reader_fut, self.pinger_fut):
            if task is not None:
                task.cancel()

    async def send(self, target=None, frame=None):
        """Write ``frame`` to the server and return its ``ClientFrame``.

        Outgoing layout: 4-byte little-endian frame id, 1 flags byte
        (``frame.type | frame.qos << 6``), 4-byte little-endian length and
        the body.  For subscribe / unsubscribe frames the body is the
        topics joined with zero bytes; otherwise it is the encoded target,
        a zero byte, the optional ``frame.header`` and the payload.

        Args:
            target: Destination string; unused for (un)subscribe frames.
            frame: The ``Frame`` to send.  ``frame.payload`` may be ``str``
                (encoded before sending) or bytes.

        Returns:
            A ``ClientFrame``; when bit 0 of ``frame.qos`` is set it is
            also registered in ``self.frames`` until acknowledged.

        Raises:
            ValueError: The frame body is larger than 0xffff_ffff bytes.
        """
        frame_id = None
        try:
            async with self.socket_lock:
                self.frame_id += 1
                if self.frame_id > 0xffff_ffff:
                    self.frame_id = 1
                frame_id = self.frame_id
                o = ClientFrame(frame.qos)
                if frame.qos & 0b1 != 0:
                    self.frames[frame_id] = o
                flags = frame.type | frame.qos << 6
                prefix = (frame_id.to_bytes(4, 'little') +
                          flags.to_bytes(1, 'little'))
                if frame.type in (OP_SUBSCRIBE, OP_UNSUBSCRIBE):
                    topics = (frame.topic if isinstance(frame.topic, list)
                              else [frame.topic])
                    payload = b'\x00'.join(t.encode() for t in topics)
                    self.writer.write(
                        prefix + len(payload).to_bytes(4, 'little') +
                        payload)
                else:
                    # Encode first: the length field counts BYTES, which
                    # differs from the character count for non-ASCII text.
                    target_raw = target.encode()
                    payload = (frame.payload.encode() if isinstance(
                        frame.payload, str) else frame.payload)
                    header = (frame.header
                              if frame.header is not None else b'')
                    frame_len = (len(target_raw) + 1 + len(header) +
                                 len(payload))
                    if frame_len > 0xffff_ffff:
                        raise ValueError('frame too large')
                    self.writer.write(prefix +
                                      frame_len.to_bytes(4, 'little') +
                                      target_raw + b'\x00' + header)
                    # The payload is written separately to avoid copying
                    # it into the header buffer.
                    self.writer.write(payload)
                await self.writer.drain()
                return o
        except BaseException:
            # Unregister the pending frame on ANY failure, including
            # cancellation.  ``frame_id`` is still None if the failure
            # happened before an id was assigned.
            self.frames.pop(frame_id, None)
            raise

    def subscribe(self, topics):
        """Return the ``send()`` coroutine subscribing to ``topics``.

        ``topics`` is a single topic string or a list of them.
        """
        frame = Frame(tp=OP_SUBSCRIBE)
        frame.topic = topics
        return self.send(None, frame)

    def unsubscribe(self, topics):
        """Return the ``send()`` coroutine unsubscribing from ``topics``.

        ``topics`` is a single topic string or a list of them.
        """
        frame = Frame(tp=OP_UNSUBSCRIBE)
        frame.topic = topics
        return self.send(None, frame)

    def is_connected(self):
        """Return True while the client considers itself connected."""
        return self.connected
|
||
|
|
||
|
|
||
|
class ClientFrame:
    """Handle returned by ``Client.send()`` for tracking one sent frame.

    When bit 0 of ``qos`` is set the frame expects an acknowledgement:
    ``completed`` is an ``asyncio.Event`` that the reader task sets after
    storing the acknowledgement code in ``result``.  Otherwise ``completed``
    is ``None`` and the frame counts as completed immediately.
    """

    def __init__(self, qos):
        self.qos = qos
        self.result = 0
        self.completed = asyncio.Event() if qos & 0b1 else None

    def is_completed(self):
        """Return True if no acknowledgement is pending for this frame."""
        if self.qos & 0b1 == 0:
            return True
        return self.completed.is_set()

    async def wait_completed(self, timeout=None):
        """Wait for the acknowledgement and return its result code.

        Frames without the acknowledgement bit return ``RESPONSE_OK``
        immediately.  A falsy ``timeout`` (``None`` or ``0``) waits
        indefinitely; otherwise ``asyncio.TimeoutError`` is raised when the
        acknowledgement does not arrive within ``timeout`` seconds.
        """
        if self.qos & 0b1 == 0:
            return RESPONSE_OK
        waiter = self.completed.wait()
        if timeout:
            await asyncio.wait_for(waiter, timeout=timeout)
        else:
            await waiter
        return self.result
|
||
|
|
||
|
|
||
|
class Frame:
    """A bus frame, either built by the user for sending or received.

    Attributes:
        payload: Frame data; ``str`` payloads are encoded by
            ``Client.send()`` before they are written.
        header: Optional bytes written between the target and the payload
            (used for zero-copy sending).
        type: Operation code (one of the ``OP_*`` constants).
        qos: QoS value; bit 0 requests an acknowledgement.
        topic: Topic name, or the topic(s) of a subscribe / unsubscribe
            request; ``None`` when not applicable.
        sender: Name of the sending client for received frames; ``None``
            for frames built locally.
    """

    def __init__(self, payload=None, tp=OP_MESSAGE, qos=0):
        self.payload = payload
        # used for zero-copy
        self.header = None
        self.type = tp
        self.qos = qos
        # Always define the attributes which the client reader fills in and
        # which Client.send() reads, so that accessing them on a freshly
        # built frame never raises AttributeError.
        self.topic = None
        self.sender = None
|
||
|
|
||
|
|
||
|
# Names exported by ``from <module> import *``.
__all__ = [
    "Client",
    "Frame",
    "OP_PUBLISH",
    "OP_MESSAGE",
    "OP_BROADCAST",
    "ERR_IO",
    "ERR_TIMEOUT",
    "RESPONSE_OK",
    "on_frame_default",
]
|