You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

237 lines
8.0 KiB
Python

import functools
import threading
from loguru import logger
from .client import OP_MESSAGE
from .base import ERR_IO, ERR_TIMEOUT, RESPONSE_OK, RPC_NOTIFICATION_HEADER, RPC_REQUEST_HEADER, RPC_REPLY_HEADER, \
RPC_ERROR_REPLY_HEADER, RPC_NOTIFICATION, RPC_REQUEST, RPC_REPLY, RPC_ERROR, RPC_ERROR_CODE_METHOD_NOT_FOUND, \
RPC_ERROR_CODE_INTERNAL, serialize, deserialize,RpcException
from .client import on_frame_default
def on_call_default(event):
    """Placeholder ``on_call`` handler that rejects every incoming call.

    Always raises ``RpcException`` with the message
    'RPC Engine not initialized' and the "method not found" error code;
    the ``event`` argument is ignored.
    """
    error = RpcException('RPC Engine not initialized',
                         RPC_ERROR_CODE_METHOD_NOT_FOUND)
    raise error
def on_notification_default(event):
    """Placeholder ``on_notification`` handler: ignore the event."""
    return None
def format_rpc_e_msg(e):
    """Return the value to send as the payload of an RPC error reply.

    For an ``RpcException`` this is its ``rpc_error_payload`` attribute;
    for any other exception it is ``str(e)``.
    """
    if not isinstance(e, RpcException):
        return str(e)
    return e.rpc_error_payload
class Rpc:
    """RPC layer on top of a bus client.

    The instance installs itself as ``client.on_frame`` and handles every
    incoming frame in a fresh thread.  Outgoing calls are correlated with
    their replies through a 4-byte little-endian call id stored in
    ``self.calls``.

    Three callbacks can be replaced by the user:

    * ``on_notification(event)`` -- incoming RPC notifications,
    * ``on_call(event)`` -- incoming RPC requests; the return value is sent
      back as the reply payload, a raised exception as an error reply,
    * ``on_frame(frame)`` -- frames whose type is not ``OP_MESSAGE``.

    Any attribute that is not found on the instance (and is not a
    ``__dunder__`` name) is resolved to a remote-method stub, so
    ``rpc.some_method(target, a=1)`` is ``rpc.params_call(target,
    method='some_method', a=1)``.
    """

    def __init__(self, client):
        self.client = client
        self.client.on_frame = self._handle_frame
        self.call_id = 0
        self.call_lock = threading.Lock()
        # call id -> RpcCallEvent of calls still waiting for a reply
        self.calls = {}
        self.on_frame = on_frame_default
        self.on_call = on_call_default
        self.on_notification = on_notification_default

    def is_connected(self):
        """Return the ``connected`` attribute of the underlying client."""
        return self.client.connected

    def notify(self, target, notification):
        """Send ``notification`` to ``target``; return what the client's
        ``send`` returns."""
        return self.client.send(target, notification)

    def call0(self, target, request):
        """Send a request that expects no reply.

        The header carries call id 0 (four zero bytes); the receiving side
        of this class does not build a reply for such requests.  Returns
        what the client's ``send`` returns.
        """
        request.header = (RPC_REQUEST_HEADER + b'\x00\x00\x00\x00' +
                          request.method + b'\x00')
        return self.client.send(target, request)

    def call(self, target, request):
        """Send a request and return the ``RpcCallEvent`` tracking it.

        This method does not raise on delivery problems.  A non-OK send
        result, a send timeout or any other exception is stored on the
        returned event as an ``RpcException`` and the event is marked
        completed, so the error surfaces from
        ``RpcCallEvent.wait_completed``.
        """
        with self.call_lock:
            call_id = self.call_id + 1
            # The id must fit the 4-byte header field; 0 is never issued
            # because it marks no-reply requests (see call0).
            self.call_id = 0 if call_id == 0xffff_ffff else call_id
        call_event = RpcCallEvent()
        self.calls[call_id] = call_event
        request.header = (RPC_REQUEST_HEADER + call_id.to_bytes(4, 'little') +
                          request.method + b'\x00')
        try:
            code = self.client.send(target, request).wait_completed(
                timeout=self.client.timeout)
            if code != RESPONSE_OK:
                self._fail_call(call_id, call_event,
                                RpcException('RPC error', code=-32000 - code))
        except TimeoutError:
            self._fail_call(
                call_id, call_event,
                RpcException('RPC timeout', code=-32000 - ERR_TIMEOUT))
        except Exception as e:
            self._fail_call(call_id, call_event,
                            RpcException(str(e), code=-32000 - ERR_IO))
        return call_event

    def _fail_call(self, call_id, call_event, error):
        """Unregister a pending call and complete its event with ``error``."""
        # The reply handler may already have removed the entry.
        self.calls.pop(call_id, None)
        call_event.error = error
        call_event.completed.set()

    def __getattr__(self, method):
        """Resolve an unknown attribute to a remote-method stub.

        Dunder names are refused with ``AttributeError``: they are protocol
        probes (``copy``, ``pickle``, ``hasattr`` checks, ...) and must not
        be turned into remote calls.
        """
        if method.startswith('__') and method.endswith('__'):
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {method!r}')
        return functools.partial(self.params_call, method=method)

    def params_call(self, target: str, method: str, **kwargs):
        """Call ``method`` on ``target`` with keyword parameters.

        * ``_raw=<value>`` replaces the keyword dict as the parameters.
        * Empty / falsy parameters are serialized as ``None``.
        * A method name ending in ``'0'`` is sent without the suffix via
          ``call0`` (no reply expected) and ``None`` is returned.

        Otherwise blocks until the call completes and returns the
        deserialized reply payload; errors stored on the call event are
        raised by ``wait_completed``.
        """
        params = kwargs.pop('_raw') if '_raw' in kwargs else kwargs
        params = params or None
        no_reply = method.endswith('0')
        if no_reply:
            method = method[:-1]
        request = Request(method, serialize(params))
        if no_reply:
            self.call0(target, request)
            return None
        result = self.call(target, request)
        return deserialize(result.wait_completed().get_payload())

    def _handle_frame(self, frame):
        """Client frame callback: process the frame in a separate thread."""
        self.spawn(self._t_handler, frame)

    def spawn(self, f, *args, **kwargs):
        """Run ``f(*args, **kwargs)`` in a newly started thread."""
        threading.Thread(target=f, args=args, kwargs=kwargs).start()

    def _t_handler(self, frame):
        """Dispatch one incoming frame; runs in its own thread.

        Nothing propagates out of this method: any exception is logged
        with its traceback.
        """
        try:
            if frame.type != OP_MESSAGE:
                self.on_frame(frame)
                return
            # First payload byte selects the RPC frame kind.
            kind = frame.payload[0]
            if kind == RPC_NOTIFICATION:
                self.on_notification(Event(RPC_NOTIFICATION, frame, 1))
            elif kind == RPC_REQUEST:
                self._handle_request(frame)
            elif kind in (RPC_REPLY, RPC_ERROR):
                self._handle_reply(frame)
            else:
                logger.error(f'Invalid RPC frame code: {kind}')
        except Exception:
            import traceback
            logger.error(traceback.format_exc())

    def _handle_request(self, frame):
        """Run ``on_call`` for an incoming request and send the reply.

        Payload layout as parsed here: kind byte, 4-byte little-endian call
        id, NUL-terminated method name, then the parameters.
        """
        payload = frame.payload
        sender = frame.sender
        call_id_b = payload[1:5]
        call_id = int.from_bytes(call_id_b, 'little')
        method = payload[5:payload.index(b'\x00', 5)]
        # Parameters start after kind (1) + call id (4) + method + NUL (1).
        event = Event(RPC_REQUEST, frame, 6 + len(method))
        event.call_id = call_id
        event.method = method
        if call_id == 0:
            # No-reply request: run the handler, send nothing back.
            self.on_call(event)
            return
        reply = Reply()
        try:
            reply.payload = self.on_call(event)
            if reply.payload is None:
                reply.payload = b''
            reply.header = RPC_REPLY_HEADER + call_id_b
        except Exception as e:
            # Exceptions without an rpc_error_code map to "internal error".
            code = getattr(e, 'rpc_error_code', RPC_ERROR_CODE_INTERNAL)
            reply.header = (RPC_ERROR_REPLY_HEADER + call_id_b +
                            code.to_bytes(2, 'little', signed=True))
            reply.payload = format_rpc_e_msg(e)
        self.client.send(sender, reply)

    def _handle_reply(self, frame):
        """Complete the pending call that a reply / error frame answers.

        Error frames carry a signed 2-byte little-endian code in payload
        bytes 5-6 followed by the error message.
        """
        payload = frame.payload
        call_id = int.from_bytes(payload[1:5], 'little')
        call_event = self.calls.pop(call_id, None)
        if call_event is None:
            logger.warning(f'orphaned RPC response: {call_id}')
            return
        call_event.frame = frame
        if payload[0] == RPC_ERROR:
            err_code = int.from_bytes(payload[5:7], 'little', signed=True)
            call_event.error = RpcException(payload[7:], code=err_code)
        call_event.completed.set()
class RpcCallEvent:
    """Handle for a single outgoing RPC call.

    ``frame`` is the reply frame (``None`` until one is attached),
    ``error`` is the exception to raise to the waiter (``None`` when there
    is none) and ``completed`` is the ``threading.Event`` that is set once
    the call has finished either way.
    """

    def __init__(self):
        self.completed = threading.Event()
        self.error = None
        self.frame = None

    def is_completed(self):
        """Return True once the call has been marked completed."""
        return self.completed.is_set()

    def wait_completed(self, *args, **kwargs):
        """Block until the call completes and return this event.

        Arguments are passed through to ``threading.Event.wait``.  Raises
        ``TimeoutError`` when the wait expires first, or the stored
        ``error`` when the call failed.
        """
        finished = self.completed.wait(*args, **kwargs)
        if not finished:
            raise TimeoutError
        if self.error is not None:
            raise self.error
        return self

    def get_payload(self):
        """Return the reply payload without its first 5 bytes."""
        reply_payload = self.frame.payload
        return reply_payload[5:]

    def is_empty(self):
        """Return True when the reply payload is at most 5 bytes long."""
        return not len(self.frame.payload) > 5
class Event:
    """Incoming RPC event wrapping the frame it arrived in.

    ``tp`` is the event type passed by the creator, ``frame`` the source
    frame, and ``payload_pos`` the offset in ``frame.payload`` at which the
    event's own payload starts.
    """

    def __init__(self, tp, frame, payload_pos):
        self.tp, self.frame, self._payload_pos = tp, frame, payload_pos

    def get_payload(self):
        """Return the frame payload from the stored offset to the end."""
        start = self._payload_pos
        return self.frame.payload[start:]
class Notification:
    """Outgoing RPC notification message.

    Carries ``payload`` as given, with the message type set to
    ``OP_MESSAGE``, ``qos`` set to 1 and the notification header preset.
    """

    def __init__(self, payload=b''):
        self.header = RPC_NOTIFICATION_HEADER
        self.type = OP_MESSAGE
        self.qos = 1
        self.payload = payload
class Request:
    """Outgoing RPC request message.

    ``method`` may be given as ``str`` (stored encoded) or as an already
    encoded value (stored unchanged).  ``params`` becomes the payload; a
    ``None`` value is replaced by empty bytes.  ``header`` starts as
    ``None`` and is assigned by the sender.
    """

    def __init__(self, method, params=b''):
        if params is None:
            params = b''
        if isinstance(method, str):
            method = method.encode()
        self.payload = params
        self.type = OP_MESSAGE
        self.qos = 1
        self.method = method
        self.header = None
class Reply:
    """Outgoing RPC reply message.

    ``result`` becomes the payload.  The message type is ``OP_MESSAGE``,
    ``qos`` is 1 and ``header`` starts as ``None`` until the sender
    assigns one.
    """

    def __init__(self, result=b''):
        self.header = None
        self.type = OP_MESSAGE
        self.qos = 1
        self.payload = result