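# NOTE(review): the following lines are residue copied from a web code-viewer
# page (repository topic hint and file metadata); they are not part of the module.
# You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 274 lines
# 9.8 KiB
# Python
#
# 1 year ago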
import asyncio
from loguru import logger
# --- Handshake -------------------------------------------------------------
# The server opens with the GREETINGS byte followed by a 2-byte little-endian
# protocol version, which the client checks against PROTOCOL_VERSION.
GREETINGS = 0xEB
PROTOCOL_VERSION = 1

# --- Frame opcodes (the low bits of the flags byte) -------------------------
OP_NOP = 0x00
OP_PUBLISH = 0x01
OP_SUBSCRIBE = 0x02
OP_UNSUBSCRIBE = 0x03
OP_MESSAGE = 0x12
OP_BROADCAST = 0x13
OP_ACK = 0xFE

# --- Server response / ack result codes -------------------------------------
RESPONSE_OK = 0x01
ERR_CLIENT_NOT_REGISTERED = 0x71
ERR_DATA = 0x72
ERR_IO = 0x73
ERR_OTHER = 0x74
ERR_NOT_SUPPORTED = 0x75
ERR_BUSY = 0x76
ERR_NOT_DELIVERED = 0x77
ERR_TIMEOUT = 0x78

# Keep-alive written periodically by the client: nine zero bytes, the size of
# an outgoing frame header (4-byte id, 1-byte flags, 4-byte length).
PING_FRAME = bytes(9)
async def on_frame_default(frame):
    """Default ``Client.on_frame`` handler: discard the incoming frame."""
    del frame  # intentionally unused
    return None
class Client:
    """Asynchronous BUS/RT client.

    ``path`` selects the transport. A value ending in ``.sock``, ``.socket``
    or ``.ipc``, or starting with ``/``, is opened as a UNIX socket. Anything
    else is parsed as ``host:port`` and opened over TCP. ``name`` is the
    client name sent to the server during the handshake.
    """

    def __init__(self, path, name):
        self.path = path
        self.writer = None
        self.reader_fut = None
        self.pinger_fut = None
        # chunk size for payload reads and the StreamReader buffer limit
        self.buf_size = 8192
        self.name = name
        self.frame_id = 0
        # seconds between keep-alive PING_FRAME writes
        self.ping_interval = 1
        # coroutine function scheduled (as a new task) for each incoming frame
        self.on_frame = on_frame_default
        # serializes writes to the socket
        self.socket_lock = asyncio.Lock()
        # serializes connect / disconnect / failure handling
        self.mgmt_lock = asyncio.Lock()
        self.connected = False
        # frame_id -> ClientFrame waiting for the server's OP_ACK
        self.frames = {}
        # seconds allowed for each handshake step and each ping drain
        self.timeout = 1

    async def connect(self):
        """Open the connection, run the handshake and start background tasks.

        Raises:
            ValueError: a UNIX socket path was given on a platform without
                ``asyncio.open_unix_connection``, or a TCP port is not an int.
            RuntimeError: the server speaks an unsupported protocol/version
                or rejected the handshake.
            asyncio.TimeoutError: a handshake step exceeded ``self.timeout``.
            asyncio.IncompleteReadError: the server closed the connection
                during the handshake.
        """
        async with self.mgmt_lock:
            is_unix = (self.path.endswith(('.sock', '.socket', '.ipc')) or
                       self.path.startswith('/'))
            if is_unix:
                if not hasattr(asyncio, 'open_unix_connection'):
                    raise ValueError('only support unix like system')
                reader, writer = await asyncio.open_unix_connection(
                    self.path, limit=self.buf_size)
            else:
                # split on the LAST colon only, so the host part may itself
                # contain colons
                host, port = self.path.rsplit(':', maxsplit=1)
                reader, writer = await asyncio.open_connection(
                    host, int(port), limit=self.buf_size)
            try:
                await self._handshake(reader, writer)
            except BaseException:
                # do not leak the socket when the handshake fails, times out
                # or is cancelled
                writer.close()
                raise
            self.writer = writer
            self.connected = True
            self.pinger_fut = asyncio.ensure_future(self._t_pinger())
            self.reader_fut = asyncio.ensure_future(self._t_reader(reader))

    async def _handshake(self, reader, writer):
        """Run the greeting, version and client-name exchange on a new connection."""
        buf = await asyncio.wait_for(self.readexactly(reader, 3),
                                     timeout=self.timeout)
        if buf[0] != GREETINGS:
            raise RuntimeError('Unsupported protocol')
        if int.from_bytes(buf[1:3], 'little') != PROTOCOL_VERSION:
            raise RuntimeError('Unsupported protocol version')
        # send the received greeting header back to the server
        writer.write(buf)
        await asyncio.wait_for(writer.drain(), timeout=self.timeout)
        await self._expect_ok(reader)
        name = self.name.encode()
        writer.write(len(name).to_bytes(2, 'little') + name)
        await asyncio.wait_for(writer.drain(), timeout=self.timeout)
        await self._expect_ok(reader)

    async def _expect_ok(self, reader):
        """Read a one-byte handshake response and raise unless it is RESPONSE_OK."""
        buf = await asyncio.wait_for(self.readexactly(reader, 1),
                                     timeout=self.timeout)
        if buf[0] != RESPONSE_OK:
            raise RuntimeError(f'Server response: {hex(buf[0])}')

    async def handle_daemon_exception(self, e):
        """Disconnect after a background task failed, then log ``e``'s traceback."""
        import traceback
        async with self.mgmt_lock:
            if self.connected:
                await self._disconnect()
            # This coroutine runs in its own task, outside the except block
            # that caught ``e``. traceback.format_exc() would therefore
            # print "NoneType: None", so format the exception object itself.
            logger.error(''.join(
                traceback.format_exception(type(e), e, e.__traceback__)))

    async def _t_pinger(self):
        """Background task: write PING_FRAME every ``ping_interval`` seconds."""
        try:
            while True:
                await asyncio.sleep(self.ping_interval)
                async with self.socket_lock:
                    self.writer.write(PING_FRAME)
                    await asyncio.wait_for(self.writer.drain(),
                                           timeout=self.timeout)
        except asyncio.CancelledError:
            # cancellation comes from _disconnect(); it is not a failure
            # (needed on Python < 3.8, where CancelledError is an Exception)
            raise
        except Exception as e:
            asyncio.ensure_future(self.handle_daemon_exception(e))

    async def _t_reader(self, reader):
        """Background task: read server frames and dispatch acks and messages.

        Each frame starts with a 6-byte header: opcode, a 4-byte little-endian
        id/length field, and one more byte. For OP_ACK that last byte is the
        result code.
        """
        try:
            while True:
                buf = await self.readexactly(reader, 6)
                op = buf[0]
                if op == OP_NOP:
                    continue
                if op == OP_ACK:
                    op_id = int.from_bytes(buf[1:5], 'little')
                    o = self.frames.pop(op_id, None)
                    if o is None:
                        logger.warning(f'orphaned BUS/RT frame ack {op_id}')
                    else:
                        o.result = buf[5]
                        o.completed.set()
                    continue
                try:
                    data_len = int.from_bytes(buf[1:5], 'little')
                    frame = await self._read_frame(reader, op, data_len)
                except Exception as e:
                    logger.error(f'Invalid frame from the server: {e}')
                    raise
                asyncio.ensure_future(self.on_frame(frame))
        except asyncio.CancelledError:
            raise
        except Exception as e:
            asyncio.ensure_future(self.handle_daemon_exception(e))

    async def _read_frame(self, reader, tp, data_len):
        """Read the body of an incoming frame whose header was already consumed.

        The body is the NUL-terminated sender, then the NUL-terminated topic
        (OP_PUBLISH only), then the payload. ``data_len`` is the body's total
        length in bytes.
        """
        frame = Frame()
        frame.type = tp
        sender = await reader.readuntil(b'\x00')
        data_len -= len(sender)
        frame.sender = sender[:-1].decode()
        if tp == OP_PUBLISH:
            topic = await reader.readuntil(b'\x00')
            data_len -= len(topic)
            frame.topic = topic[:-1].decode()
        else:
            frame.topic = None
        # raises IncompleteReadError on EOF instead of looping forever
        frame.payload = await self.readexactly(reader, data_len)
        return frame

    async def readexactly(self, reader, data_len):
        """Read exactly ``data_len`` bytes, in chunks of at most ``self.buf_size``.

        Raises:
            asyncio.IncompleteReadError: EOF was reached first; ``partial``
                holds every byte this call received.
        """
        chunks = []
        remaining = data_len
        while remaining > 0:
            try:
                chunk = await reader.readexactly(min(remaining, self.buf_size))
            except asyncio.IncompleteReadError as e:
                # EOF is permanent, so retrying can never succeed; report the
                # bytes received so far instead of dropping them
                raise asyncio.IncompleteReadError(
                    b''.join(chunks) + e.partial, data_len)
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    async def disconnect(self):
        """Close the connection and stop the background tasks."""
        async with self.mgmt_lock:
            await self._disconnect()

    async def _disconnect(self):
        """Close the socket and cancel background tasks; callers hold ``mgmt_lock``."""
        self.connected = False
        # writer is None when connect() never completed
        if self.writer is not None:
            self.writer.close()
        for fut in (self.reader_fut, self.pinger_fut):
            if fut is not None:
                fut.cancel()

    async def send(self, target=None, frame=None):
        """Write ``frame`` to the server.

        For OP_SUBSCRIBE / OP_UNSUBSCRIBE, ``frame.topic`` (a str or a list of
        str) is sent and ``target`` is ignored. For other types, ``target``
        (str) addresses the frame, and the optional ``frame.header`` (bytes)
        followed by ``frame.payload`` (str or bytes) forms the body.

        Returns:
            A ClientFrame. When ``frame.qos & 1`` is set, it is kept in
            ``self.frames`` until the matching OP_ACK arrives.

        Raises:
            ValueError: the encoded frame is larger than 0xffff_ffff bytes.
        """
        frame_id = None
        try:
            async with self.socket_lock:
                self.frame_id += 1
                if self.frame_id > 0xffff_ffff:
                    self.frame_id = 1
                frame_id = self.frame_id
                o = ClientFrame(frame.qos)
                if frame.qos & 0b1 != 0:
                    self.frames[frame_id] = o
                flags = frame.type | frame.qos << 6
                if frame.type in (OP_SUBSCRIBE, OP_UNSUBSCRIBE):
                    topics = (frame.topic if isinstance(frame.topic, list)
                              else [frame.topic])
                    payload = b'\x00'.join(t.encode() for t in topics)
                    self.writer.write(
                        frame_id.to_bytes(4, 'little') +
                        flags.to_bytes(1, 'little') +
                        len(payload).to_bytes(4, 'little') + payload)
                else:
                    # encode before measuring: the length field counts bytes,
                    # while len() of a non-ASCII str counts characters
                    target_b = target.encode()
                    payload = (frame.payload.encode()
                               if isinstance(frame.payload, str) else
                               frame.payload)
                    header = frame.header if frame.header is not None else b''
                    frame_len = len(target_b) + 1 + len(header) + len(payload)
                    if frame_len > 0xffff_ffff:
                        raise ValueError('frame too large')
                    self.writer.write(
                        frame_id.to_bytes(4, 'little') +
                        flags.to_bytes(1, 'little') +
                        frame_len.to_bytes(4, 'little') + target_b +
                        b'\x00' + header)
                    # payload is written separately so it is not concatenated
                    # into the header bytes
                    self.writer.write(payload)
                await self.writer.drain()
                return o
        except BaseException:
            # frame_id stays None if we failed or were cancelled before
            # taking the lock
            if frame_id is not None:
                self.frames.pop(frame_id, None)
            raise

    def subscribe(self, topics):
        """Subscribe to a topic or list of topics; returns the send() coroutine."""
        frame = Frame(tp=OP_SUBSCRIBE)
        frame.topic = topics
        return self.send(None, frame)

    def unsubscribe(self, topics):
        """Unsubscribe from a topic or list of topics; returns the send() coroutine."""
        frame = Frame(tp=OP_UNSUBSCRIBE)
        frame.topic = topics
        return self.send(None, frame)

    def is_connected(self):
        """Return True between a successful connect() and the next disconnect."""
        return self.connected
class ClientFrame:
    """Handle returned by ``Client.send`` for tracking server acknowledgement.

    Only frames whose QoS has the low bit set get an ``asyncio.Event``. The
    client's reader task sets ``result`` and the event when the ack arrives.
    All other frames count as complete immediately.
    """

    def __init__(self, qos):
        self.qos = qos
        self.result = 0
        self.completed = asyncio.Event() if qos & 0b1 else None

    def is_completed(self):
        """Return True once acknowledged (always True without the ack bit)."""
        if not self.qos & 0b1:
            return True
        return self.completed.is_set()

    async def wait_completed(self, timeout=None):
        """Wait for the ack and return its result code.

        Frames without the ack bit return RESPONSE_OK right away. A falsy
        ``timeout`` waits indefinitely; otherwise asyncio.wait_for raises on
        expiry.
        """
        if not self.qos & 0b1:
            return RESPONSE_OK
        waiter = self.completed.wait()
        if timeout:
            await asyncio.wait_for(waiter, timeout=timeout)
        else:
            await waiter
        return self.result
class Frame:
    """A BUS/RT frame: payload, opcode type and QoS."""

    def __init__(self, payload=None, tp=OP_MESSAGE, qos=0):
        self.type = tp
        self.qos = qos
        self.payload = payload
        # Optional bytes that Client.send writes between the target and the
        # payload, so callers can prepend data without copying it into
        # payload (zero-copy).
        self.header = None
# Names exported by ``from <module> import *``.
__all__ = [
    "Client",
    "Frame",
    "OP_PUBLISH",
    "OP_MESSAGE",
    "OP_BROADCAST",
    "ERR_IO",
    "ERR_TIMEOUT",
    "RESPONSE_OK",
    "on_frame_default",
]