You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

267 lines
8.5 KiB
Python

import asyncio
import functools
from loguru import logger
from .busrt_client import (
ERR_IO,
ERR_TIMEOUT,
OP_MESSAGE,
RESPONSE_OK,
on_frame_default,
)
from .msgutils import deserialize, serialize
# Frame kind codes: the first byte of every RPC message payload.
RPC_NOTIFICATION = 0x00
RPC_REQUEST = 0x01
RPC_REPLY = 0x11
RPC_ERROR = 0x12

# The same codes as one-byte prefixes, ready to be concatenated into the
# header of an outgoing frame.
RPC_NOTIFICATION_HEADER = bytes((RPC_NOTIFICATION,))
RPC_REQUEST_HEADER = bytes((RPC_REQUEST,))
RPC_REPLY_HEADER = bytes((RPC_REPLY,))
RPC_ERROR_REPLY_HEADER = bytes((RPC_ERROR,))

# RPC error codes (the values match the JSON-RPC 2.0 predefined errors).
RPC_ERROR_CODE_PARSE = -32700
RPC_ERROR_CODE_INVALID_REQUEST = -32600
RPC_ERROR_CODE_METHOD_NOT_FOUND = -32601
RPC_ERROR_CODE_INVALID_METHOD_PARAMS = -32602
RPC_ERROR_CODE_INTERNAL = -32603
async def on_call_default(event):
    """Placeholder RPC call handler that rejects every request.

    Args:
        event: the incoming call event (ignored).

    Raises:
        RpcException: always, with code RPC_ERROR_CODE_METHOD_NOT_FOUND.
    """
    not_ready = RpcException('RPC Engine not initialized',
                             RPC_ERROR_CODE_METHOD_NOT_FOUND)
    raise not_ready
async def on_notification_default(event):
    """Placeholder notification handler: accepts the event and does nothing.

    Args:
        event: the incoming notification event (ignored).
    """
    return None
def format_rpc_e_msg(e):
    """Return the payload to put into an RPC error reply for exception *e*.

    An RpcException contributes its original ``rpc_error_payload`` unchanged
    (str or bytes, as it was given to the constructor); any other exception
    is rendered with ``str()``.
    """
    return e.rpc_error_payload if isinstance(e, RpcException) else str(e)
class RpcException(Exception):
    """Error raised for a failed RPC call.

    Attributes:
        rpc_error_code: integer error code of the failure.
        rpc_error_payload: the message exactly as passed to the constructor
            (str or bytes), kept unmodified so it can be sent in a reply.
    """

    def __init__(self, msg: str | bytes = '', code=RPC_ERROR_CODE_INTERNAL):
        """Create the exception.

        Args:
            msg: error message; either text or a bytes-like payload.
            code: error code, RPC_ERROR_CODE_INTERNAL by default.
        """
        self.rpc_error_code = code
        self.rpc_error_payload = msg
        if isinstance(msg, str):
            text = msg
        else:
            # The payload may be arbitrary binary data (e.g. taken straight
            # from a received error frame). A strict decode would raise
            # UnicodeDecodeError here and the real RPC error would be lost,
            # so undecodable bytes are replaced instead.
            text = bytes(msg).decode('utf-8', errors='replace')
        super().__init__(text)

    def __str__(self):
        """Return the message followed by the error code."""
        return super().__str__() + f' (code: {self.rpc_error_code})'
class Rpc:
    """RPC layer on top of a bus client.

    The instance installs itself as the client's ``on_frame`` handler and
    dispatches incoming frames to three replaceable async callbacks:

    * ``on_call(event)`` - RPC requests; the returned value becomes the
      reply payload, a raised exception becomes an error reply,
    * ``on_notification(event)`` - RPC notifications,
    * ``on_frame(frame)`` - every frame whose type is not OP_MESSAGE.

    Message payload layout, as produced and parsed by this class:

    * notification: ``0x00`` + payload
    * request: ``0x01`` + call id (4 bytes LE) + method + ``0x00`` + params
    * reply: ``0x11`` + call id (4 bytes LE) + payload
    * error reply: ``0x12`` + call id (4 bytes LE) + code (2 bytes LE,
      signed) + message

    Any attribute that is not defined on the instance is turned into an RPC
    method stub, so ``await rpc.some_method(target, a=1)`` is the same as
    ``await rpc.params_call(target, method='some_method', a=1)``.
    """

    def __init__(self, client):
        """Attach the RPC layer to *client* and take over its frame handler."""
        self.client = client
        self.client.on_frame = self._handle_frame
        # Last call id handed out and the lock protecting its increment.
        self.call_id = 0
        self.call_lock = asyncio.Lock()
        # Pending calls waiting for a reply: call id -> RpcCallEvent.
        self.calls = {}
        self.on_frame = on_frame_default
        self.on_call = on_call_default
        self.on_notification = on_notification_default

    def is_connected(self):
        """Return the ``connected`` attribute of the underlying client."""
        return self.client.connected

    def notify(self, target, notification):
        """Send *notification* to *target*; returns what ``client.send`` returns."""
        return self.client.send(target, notification)

    def call0(self, target, request):
        """Send *request* to *target* without expecting a reply.

        The request is sent with call id 0, which the receiving side treats
        as "do not reply". Returns what ``client.send`` returns.
        """
        request.header = (RPC_REQUEST_HEADER + b'\x00\x00\x00\x00' +
                          request.method + b'\x00')
        return self.client.send(target, request)

    async def call(self, target, request):
        """Send *request* to *target* and return the pending RpcCallEvent.

        This coroutine only waits for the send operation to complete (using
        ``client.timeout``); the reply itself is awaited by the caller with
        ``RpcCallEvent.wait_completed()``.

        Send failures are not raised from here: they are stored in the
        returned event, so ``wait_completed()`` raises an RpcException whose
        code is ``-32000 - <client error code>``.
        """
        async with self.call_lock:
            call_id = self.call_id + 1
            # The id is sent as 4 bytes; after handing out 0xffff_ffff start
            # again from 1 (0 is reserved for calls without a reply).
            self.call_id = 0 if call_id == 0xffff_ffff else call_id
        call_event = RpcCallEvent()
        self.calls[call_id] = call_event
        request.header = (RPC_REQUEST_HEADER + call_id.to_bytes(4, 'little') +
                          request.method + b'\x00')
        try:
            op = await self.client.send(target, request)
            code = await op.wait_completed(timeout=self.client.timeout)
            if code != RESPONSE_OK:
                self._fail_call(call_id, call_event, 'RPC error',
                                -32000 - code)
        except asyncio.TimeoutError:
            self._fail_call(call_id, call_event, 'RPC timeout',
                            -32000 - ERR_TIMEOUT)
        except Exception as e:
            self._fail_call(call_id, call_event, str(e), -32000 - ERR_IO)
        return call_event

    def _fail_call(self, call_id, call_event, msg, code):
        """Forget pending call *call_id* and complete its event with an error."""
        # The call may already have been removed by a reply that arrived in
        # the meantime, hence the default.
        self.calls.pop(call_id, None)
        call_event.error = RpcException(msg, code=code)
        call_event.completed.set()

    def __getattr__(self, method):
        """Turn an unknown attribute name into an RPC method stub.

        Dunder names are excluded: protocol probes such as
        ``getattr(obj, '__deepcopy__', None)`` must see a missing attribute
        instead of receiving an RPC stub.
        """
        if method.startswith('__') and method.endswith('__'):
            raise AttributeError(method)
        return functools.partial(self.params_call, method=method)

    async def params_call(self, target, method: str, **kwargs):
        """Call RPC *method* on *target* with keyword parameters.

        Args:
            target: destination of the call.
            method: method name. A trailing ``0`` is stripped and the request
                is sent with ``call0`` (no reply is expected).
            **kwargs: call parameters. If ``_raw`` is present, its value is
                used as the parameters instead of the keyword dict. Empty
                parameters are serialized as ``None``.

        Returns:
            The deserialized reply payload, or ``None`` for a ``call0``.

        Raises:
            RpcException: if the call failed or an error reply was received.
        """
        params = kwargs
        if '_raw' in kwargs:
            params = kwargs.pop('_raw')
        params = params or None
        c0 = False
        if method.endswith('0'):
            c0 = True
            method = method[:-1]
        request = Request(method, serialize(params))
        if c0:
            await self.call0(target, request)
            return None
        result = await self.call(target, request)
        # NOTE(review): no timeout is passed here, so this waits until a
        # reply for the call arrives.
        return deserialize((await result.wait_completed()).get_payload())

    async def _handle_frame(self, frame):
        """Dispatch an incoming client frame; never raises.

        Any exception (including ones raised by user callbacks) is logged
        with its traceback so that it does not propagate into the client.
        """
        try:
            if frame.type != OP_MESSAGE:
                await self.on_frame(frame)
                return
            kind = frame.payload[0]
            if kind == RPC_NOTIFICATION:
                event = Event(RPC_NOTIFICATION, frame, 1)
                await self.on_notification(event)
            elif kind == RPC_REQUEST:
                await self._handle_request(frame)
            elif kind == RPC_REPLY or kind == RPC_ERROR:
                self._handle_reply(frame)
            else:
                logger.error(f'Invalid RPC frame code: {frame.payload[0]}')
        except Exception:
            import traceback
            logger.error(traceback.format_exc())

    async def _handle_request(self, frame):
        """Run ``on_call`` for a request frame and send the reply if required."""
        sender = frame.sender
        call_id_b = frame.payload[1:5]
        call_id = int.from_bytes(call_id_b, 'little')
        # The method name is NUL-terminated and starts right after the id.
        method = frame.payload[5:frame.payload.index(b'\x00', 5)]
        event = Event(RPC_REQUEST, frame, 6 + len(method))
        event.call_id = call_id
        event.method = method
        if call_id == 0:
            # Call id 0: the caller does not expect a reply.
            await self.on_call(event)
            return
        reply = Reply()
        try:
            reply.payload = await self.on_call(event)
            if reply.payload is None:
                reply.payload = b''
            reply.header = RPC_REPLY_HEADER + call_id_b
        except Exception as e:
            code = getattr(e, 'rpc_error_code', RPC_ERROR_CODE_INTERNAL)
            try:
                code_b = code.to_bytes(2, 'little', signed=True)
            except (AttributeError, OverflowError):
                # The code is not an int or does not fit the 2-byte field.
                # Report an internal error rather than failing to reply,
                # which would leave the caller without any response.
                code_b = RPC_ERROR_CODE_INTERNAL.to_bytes(
                    2, 'little', signed=True)
            reply.header = RPC_ERROR_REPLY_HEADER + call_id_b + code_b
            reply.payload = format_rpc_e_msg(e)
        await self.client.send(sender, reply)

    def _handle_reply(self, frame):
        """Complete the pending call matching a reply or error-reply frame."""
        call_id = int.from_bytes(frame.payload[1:5], 'little')
        try:
            call_event = self.calls.pop(call_id)
        except KeyError:
            logger.warning(f'orphaned RPC response: {call_id}')
            return
        call_event.frame = frame
        if frame.payload[0] == RPC_ERROR:
            err_code = int.from_bytes(frame.payload[5:7], 'little',
                                      signed=True)
            call_event.error = RpcException(frame.payload[7:], code=err_code)
        call_event.completed.set()
class RpcCallEvent:
    """State of one outgoing RPC call.

    Attributes:
        frame: the reply frame, or None while no reply has been stored.
        error: exception to raise from wait_completed(), or None.
        completed: asyncio.Event set once the call has finished either way.
    """

    def __init__(self):
        self.frame = None
        self.error = None
        self.completed = asyncio.Event()

    def is_completed(self):
        """Return True if the call has finished (successfully or not)."""
        return self.completed.is_set()

    async def wait_completed(self, timeout=None):
        """Wait until the call has finished and return this event.

        Args:
            timeout: maximum time to wait in seconds; a falsy value (None
                or 0) means wait without a limit.

        Returns:
            This RpcCallEvent, when the call finished without an error.

        Raises:
            asyncio.TimeoutError: the call did not finish within *timeout*.
            Exception: the stored ``error``, when the call failed.
        """
        if timeout:
            # wait_for() raises on timeout by itself and Event.wait() can
            # only return True, so the result needs no additional check.
            await asyncio.wait_for(self.completed.wait(), timeout=timeout)
        else:
            await self.completed.wait()
        if self.error is not None:
            raise self.error
        return self

    def get_payload(self):
        """Return the reply payload without the 5-byte reply header."""
        return self.frame.payload[5:]

    def is_empty(self):
        """Return True if the reply frame carries nothing after its header."""
        return len(self.frame.payload) <= 5
class Event:
    """Incoming RPC event handed to the call and notification callbacks.

    Attributes:
        tp: RPC frame kind code the event was created for.
        frame: the received frame.
    """

    def __init__(self, tp, frame, payload_pos):
        """Store the frame and the offset at which its user payload starts."""
        self.tp = tp
        self.frame = frame
        self._payload_pos = payload_pos

    def get_payload(self):
        """Return the part of the frame payload that follows the RPC header."""
        start = self._payload_pos
        raw = self.frame.payload
        return raw[start:]
class Notification:
    """Outgoing RPC notification message.

    Attributes:
        payload: notification body.
        type: frame type, always OP_MESSAGE.
        qos: QoS level, 1 by default.
        header: bytes placed in front of the payload
            (RPC_NOTIFICATION_HEADER).
    """

    def __init__(self, payload=b''):
        self.type = OP_MESSAGE
        self.qos = 1
        self.header = RPC_NOTIFICATION_HEADER
        self.payload = payload
class Request:
    """Outgoing RPC request message.

    Attributes:
        method: method name as bytes (a str name is encoded).
        payload: serialized call parameters; ``b''`` when params is None.
        type: frame type, always OP_MESSAGE.
        qos: QoS level, 1 by default.
        header: None on creation; assigned by the code that sends the
            request.
    """

    def __init__(self, method, params=b''):
        if isinstance(method, str):
            method = method.encode()
        if params is None:
            params = b''
        self.method = method
        self.payload = params
        self.type = OP_MESSAGE
        self.qos = 1
        self.header = None
class Reply:
    """Outgoing RPC reply message.

    Attributes:
        payload: reply body.
        type: frame type, always OP_MESSAGE.
        qos: QoS level, 1 by default.
        header: None on creation; assigned by the code that sends the reply.
    """

    def __init__(self, result=b''):
        self.type = OP_MESSAGE
        self.qos = 1
        self.header = None
        self.payload = result
# Names exported by ``from <module> import *``. Notification is listed
# together with the other message classes (Request, Reply): callers need it
# to build the argument of Rpc.notify().
__all__ = [
    "Rpc",
    "Notification",
    "Request",
    "Reply",
    "RpcCallEvent",
    "Event",
    "RpcException",
]