"""RPC layer built on top of the bus client from ``.busrt_client``.

Frames of type ``OP_MESSAGE`` are interpreted as RPC traffic. The first
payload byte selects the kind of message:

* ``0x00`` - notification: the rest of the payload is the notification body
* ``0x01`` - request: 4-byte little-endian call id, NUL-terminated method
  name, then the parameters payload (call id 0 means "no reply expected")
* ``0x11`` - reply: 4-byte little-endian call id, then the result payload
* ``0x12`` - error reply: 4-byte little-endian call id, 2-byte signed
  little-endian error code, then the error payload
"""
import asyncio
import functools

from loguru import logger

from .busrt_client import (
    ERR_IO,
    ERR_TIMEOUT,
    OP_MESSAGE,
    RESPONSE_OK,
    on_frame_default,
)
from .msgutils import deserialize, serialize

# Frame-kind markers as bytes, used when building outgoing headers.
RPC_NOTIFICATION_HEADER = b'\x00'
RPC_REQUEST_HEADER = b'\x01'
RPC_REPLY_HEADER = b'\x11'
RPC_ERROR_REPLY_HEADER = b'\x12'

# The same markers as integers, compared against the first payload byte of
# incoming frames.
RPC_NOTIFICATION = 0x00
RPC_REQUEST = 0x01
RPC_REPLY = 0x11
RPC_ERROR = 0x12

RPC_ERROR_CODE_PARSE = -32700
RPC_ERROR_CODE_INVALID_REQUEST = -32600
RPC_ERROR_CODE_METHOD_NOT_FOUND = -32601
RPC_ERROR_CODE_INVALID_METHOD_PARAMS = -32602
RPC_ERROR_CODE_INTERNAL = -32603


async def on_call_default(event):
    """Default request handler: reject every call as "method not found"."""
    raise RpcException('RPC Engine not initialized',
                       RPC_ERROR_CODE_METHOD_NOT_FOUND)


async def on_notification_default(event):
    """Default notification handler: ignore the notification."""
    pass


def format_rpc_e_msg(e):
    """Return the payload to put into an error reply for exception *e*.

    For :class:`RpcException` the original payload (``str`` or ``bytes``) is
    returned unchanged; for any other exception its ``str()`` is used.
    """
    if isinstance(e, RpcException):
        return e.rpc_error_payload
    return str(e)


class RpcException(Exception):
    """RPC failure carrying an error code and the original error payload.

    Attributes:
        rpc_error_code: numeric error code (defaults to
            ``RPC_ERROR_CODE_INTERNAL``).
        rpc_error_payload: the message exactly as passed in (``str`` or
            ``bytes``).
    """

    def __init__(self, msg: str | bytes = '', code=RPC_ERROR_CODE_INTERNAL):
        self.rpc_error_code = code
        self.rpc_error_payload = msg
        if isinstance(msg, str):
            text = msg
        else:
            # The payload may come straight from a remote peer and is not
            # guaranteed to be valid UTF-8. A strict decode would raise while
            # the error is being constructed, which in the reply path would
            # leave the pending call uncompleted forever. The raw bytes stay
            # available in rpc_error_payload.
            text = msg.decode(errors='replace')
        super().__init__(text)

    def __str__(self):
        return super().__str__() + f' (code: {self.rpc_error_code})'


class Rpc:
    """RPC endpoint bound to a bus client.

    On construction the instance installs itself as the client's ``on_frame``
    handler. Incoming behaviour is customised by replacing the ``on_call``,
    ``on_notification`` and ``on_frame`` attributes with coroutine functions.

    Any attribute that does not otherwise exist is treated as a remote method
    name, see :meth:`__getattr__` and :meth:`params_call`.
    """

    def __init__(self, client):
        self.client = client
        self.client.on_frame = self._handle_frame
        # Last allocated call id; allocation is guarded by call_lock.
        self.call_id = 0
        self.call_lock = asyncio.Lock()
        # Pending calls awaiting a reply: call id -> RpcCallEvent.
        self.calls = {}
        # Handler for frames whose type is not OP_MESSAGE.
        self.on_frame = on_frame_default
        self.on_call = on_call_default
        self.on_notification = on_notification_default

    def is_connected(self):
        """Return the ``connected`` attribute of the underlying client."""
        return self.client.connected

    def notify(self, target, notification):
        """Send *notification* to *target*; returns ``client.send(...)``."""
        return self.client.send(target, notification)

    def call0(self, target, request):
        """Send *request* with call id 0, i.e. without expecting a reply.

        Returns whatever ``client.send(...)`` returns.
        """
        request.header = (RPC_REQUEST_HEADER + b'\x00\x00\x00\x00' +
                          request.method + b'\x00')
        return self.client.send(target, request)

    def _fail_call(self, call_id, call_event, error):
        """Drop the pending call *call_id* and complete it with *error*."""
        self.calls.pop(call_id, None)
        call_event.error = error
        call_event.completed.set()

    async def call(self, target, request):
        """Send *request* to *target* and register a pending call.

        Returns:
            An :class:`RpcCallEvent`. Await its ``wait_completed()`` to get
            the reply or the error. Delivery failures (non-OK send result,
            send timeout, any other exception while sending) do not raise
            here; they are stored in the event as an :class:`RpcException`
            with code ``-32000 - <bus error code>``.
        """
        async with self.call_lock:
            call_id = self.call_id + 1
            # Restart the counter once the id reaches 0xffff_ffff (that id
            # itself is still used for this call).
            self.call_id = 0 if call_id == 0xffff_ffff else call_id
        call_event = RpcCallEvent()
        self.calls[call_id] = call_event
        request.header = (RPC_REQUEST_HEADER +
                          call_id.to_bytes(4, 'little') + request.method +
                          b'\x00')
        try:
            sent = await self.client.send(target, request)
            code = await sent.wait_completed(timeout=self.client.timeout)
            if code != RESPONSE_OK:
                self._fail_call(
                    call_id, call_event,
                    RpcException('RPC error', code=-32000 - code))
        except asyncio.TimeoutError:
            self._fail_call(
                call_id, call_event,
                RpcException('RPC timeout', code=-32000 - ERR_TIMEOUT))
        except Exception as e:
            self._fail_call(call_id, call_event,
                            RpcException(str(e), code=-32000 - ERR_IO))
        return call_event

    def __getattr__(self, method):
        """Expose unknown attributes as remote methods.

        ``rpc.some_method(target, a=1)`` is equivalent to
        ``rpc.params_call(target, method='some_method', a=1)``.
        """
        return functools.partial(self.params_call, method=method)

    async def params_call(self, target, method: str, **kwargs):
        """Call remote *method* on *target* with keyword parameters.

        Args:
            target: destination passed through to the client.
            method: remote method name. A trailing ``'0'`` is stripped and
                turns the call into a no-reply call (see :meth:`call0`).
            **kwargs: call parameters. If ``_raw`` is present, its value is
                used as the parameters instead of the keyword dict. Empty /
                falsy parameters are sent as ``None``.

        Returns:
            ``None`` for no-reply calls, otherwise the deserialized reply
            payload.

        Raises:
            RpcException: if the call failed or the peer replied with an
                error.
        """
        params = kwargs
        if '_raw' in kwargs:
            params = kwargs.pop('_raw')
        params = params or None
        no_reply = method.endswith('0')
        if no_reply:
            method = method[:-1]
        request = Request(method, serialize(params))
        if no_reply:
            await self.call0(target, request)
            return None
        result = await self.call(target, request)
        return deserialize((await result.wait_completed()).get_payload())

    async def _handle_request(self, frame):
        """Dispatch an incoming request to ``on_call`` and send the reply."""
        sender = frame.sender
        call_id_b = frame.payload[1:5]
        call_id = int.from_bytes(call_id_b, 'little')
        # The method name runs from byte 5 up to the first NUL.
        method = frame.payload[5:5 + frame.payload[5:].index(b'\x00')]
        event = Event(RPC_REQUEST, frame, 6 + len(method))
        event.call_id = call_id
        event.method = method
        if call_id == 0:
            # No reply expected: errors propagate to _handle_frame's logger.
            await self.on_call(event)
            return
        reply = Reply()
        try:
            reply.payload = await self.on_call(event)
            if reply.payload is None:
                reply.payload = b''
            reply.header = RPC_REPLY_HEADER + call_id_b
        except Exception as e:
            code = getattr(e, 'rpc_error_code', RPC_ERROR_CODE_INTERNAL)
            try:
                code_b = code.to_bytes(2, 'little', signed=True)
            except (AttributeError, OverflowError, TypeError):
                # The code does not fit the 2-byte wire field (or is not an
                # int). Without a fallback no reply would be sent at all.
                code_b = RPC_ERROR_CODE_INTERNAL.to_bytes(2,
                                                          'little',
                                                          signed=True)
            reply.header = RPC_ERROR_REPLY_HEADER + call_id_b + code_b
            reply.payload = format_rpc_e_msg(e)
        await self.client.send(sender, reply)

    def _handle_reply(self, frame):
        """Complete the pending call matching an incoming (error) reply."""
        call_id = int.from_bytes(frame.payload[1:5], 'little')
        call_event = self.calls.pop(call_id, None)
        if call_event is None:
            logger.warning(f'orphaned RPC response: {call_id}')
            return
        call_event.frame = frame
        if frame.payload[0] == RPC_ERROR:
            err_code = int.from_bytes(frame.payload[5:7],
                                      'little',
                                      signed=True)
            call_event.error = RpcException(frame.payload[7:], code=err_code)
        call_event.completed.set()

    async def _handle_frame(self, frame):
        """Client ``on_frame`` callback: route *frame* by type and RPC kind.

        Never raises: any exception is logged with its traceback.
        """
        try:
            if frame.type != OP_MESSAGE:
                await self.on_frame(frame)
                return
            kind = frame.payload[0]
            if kind == RPC_NOTIFICATION:
                await self.on_notification(Event(RPC_NOTIFICATION, frame, 1))
            elif kind == RPC_REQUEST:
                await self._handle_request(frame)
            elif kind == RPC_REPLY or kind == RPC_ERROR:
                self._handle_reply(frame)
            else:
                logger.error(f'Invalid RPC frame code: {frame.payload[0]}')
        except Exception:
            import traceback
            logger.error(traceback.format_exc())


class RpcCallEvent:
    """Handle for a pending RPC call.

    Attributes:
        frame: the reply frame, set when a reply arrives.
        error: an :class:`RpcException` if the call failed, else ``None``.
        completed: ``asyncio.Event`` set once the call has finished.
    """

    def __init__(self):
        self.frame = None
        self.error = None
        self.completed = asyncio.Event()

    def is_completed(self):
        """Return ``True`` once the call has finished (reply or error)."""
        return self.completed.is_set()

    async def wait_completed(self, timeout=None):
        """Wait until the call finishes and return ``self``.

        Args:
            timeout: seconds to wait. A falsy value (``None`` or ``0``) waits
                indefinitely.

        Raises:
            asyncio.TimeoutError: if *timeout* expires first.
            RpcException: if the call completed with an error.
        """
        if timeout:
            # wait_for raises on timeout by itself; Event.wait() only ever
            # returns True, so there is nothing to check afterwards.
            await asyncio.wait_for(self.completed.wait(), timeout=timeout)
        else:
            await self.completed.wait()
        if self.error is not None:
            raise self.error
        return self

    def get_payload(self):
        """Return the reply payload without the 5-byte reply header."""
        return self.frame.payload[5:]

    def is_empty(self):
        """Return ``True`` if the reply carries no payload after the header."""
        return len(self.frame.payload) <= 5


class Event:
    """Incoming notification or request.

    ``tp`` is ``RPC_NOTIFICATION`` or ``RPC_REQUEST``. For requests the RPC
    engine additionally sets ``call_id`` and ``method`` (bytes) on the
    instance before passing it to ``on_call``.
    """

    def __init__(self, tp, frame, payload_pos):
        self.tp = tp
        self.frame = frame
        # Offset in frame.payload where the message body starts.
        self._payload_pos = payload_pos

    def get_payload(self):
        """Return the message body (frame payload past the RPC header)."""
        return self.frame.payload[self._payload_pos:]


class Notification:
    """Outgoing RPC notification frame."""

    def __init__(self, payload=b''):
        self.payload = payload
        self.type = OP_MESSAGE
        self.qos = 1
        self.header = RPC_NOTIFICATION_HEADER


class Request:
    """Outgoing RPC request frame.

    ``method`` is stored as bytes (``str`` names are encoded). ``header`` is
    filled in by :meth:`Rpc.call` / :meth:`Rpc.call0` right before sending.
    """

    def __init__(self, method, params=b''):
        self.payload = b'' if params is None else params
        self.type = OP_MESSAGE
        self.qos = 1
        self.method = method.encode() if isinstance(method, str) else method
        self.header = None


class Reply:
    """Outgoing RPC reply frame; ``header`` is set by the RPC engine."""

    def __init__(self, result=b''):
        self.payload = result
        self.type = OP_MESSAGE
        self.qos = 1
        self.header = None


__all__ = ["Rpc", "Request", "Reply", "RpcCallEvent", "Event", "RpcException"]