You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
237 lines
8.0 KiB
Python
237 lines
8.0 KiB
Python
import functools
|
|
import threading
|
|
from loguru import logger
|
|
|
|
from .client import OP_MESSAGE
|
|
from .base import ERR_IO, ERR_TIMEOUT, RESPONSE_OK, RPC_NOTIFICATION_HEADER, RPC_REQUEST_HEADER, RPC_REPLY_HEADER, \
|
|
RPC_ERROR_REPLY_HEADER, RPC_NOTIFICATION, RPC_REQUEST, RPC_REPLY, RPC_ERROR, RPC_ERROR_CODE_METHOD_NOT_FOUND, \
|
|
RPC_ERROR_CODE_INTERNAL, serialize, deserialize,RpcException
|
|
from .client import on_frame_default
|
|
|
|
|
|
def on_call_default(event):
    """Placeholder RPC call handler used until a real one is assigned.

    Always raises ``RpcException`` with the "method not found" error code,
    regardless of the event passed in.
    """
    not_initialized = RpcException('RPC Engine not initialized',
                                   RPC_ERROR_CODE_METHOD_NOT_FOUND)
    raise not_initialized
|
|
|
|
|
|
def on_notification_default(event):
    """Placeholder notification handler: ignore the event, return None."""
    return None
|
|
|
|
|
|
def format_rpc_e_msg(e):
    """Return the payload to put into an RPC error reply for exception *e*.

    For an ``RpcException`` this is its ``rpc_error_payload`` attribute; for
    any other exception it is ``str(e)``.
    """
    return e.rpc_error_payload if isinstance(e, RpcException) else str(e)
|
|
|
|
|
|
|
|
class Rpc:
    """RPC layer on top of a bus client.

    On construction the instance installs ``_handle_frame`` as the client's
    ``on_frame`` callback. Every incoming frame is then processed in a
    freshly spawned thread and dispatched to one of the user-assignable
    callbacks:

    * ``on_notification(event)`` - RPC notifications
    * ``on_call(event)`` - RPC requests; the return value becomes the reply
      payload, a raised exception becomes an error reply
    * ``on_frame(frame)`` - frames whose type is not ``OP_MESSAGE``

    Any attribute that is not found on the instance is turned into a remote
    method proxy (see ``__getattr__``), so ``rpc.some_method(target, a=1)``
    performs ``params_call(target, method='some_method', a=1)``.
    """

    def __init__(self, client):
        """Bind the RPC layer to *client* and take over its frame handler."""
        self.client = client
        # all frames received by the client are routed through this object
        self.client.on_frame = self._handle_frame
        # last issued call id, guarded by call_lock
        self.call_id = 0
        self.call_lock = threading.Lock()
        # call id -> RpcCallEvent for calls still waiting for a reply
        self.calls = {}
        self.on_frame = on_frame_default
        self.on_call = on_call_default
        self.on_notification = on_notification_default

    def is_connected(self):
        """Return the ``connected`` attribute of the underlying client."""
        return self.client.connected

    def notify(self, target, notification):
        """Send *notification* to *target*; return what ``client.send`` returns."""
        return self.client.send(target, notification)

    def call0(self, target, request):
        """Send *request* with call id 0, i.e. without tracking a reply.

        The request header is overwritten in place. Returns whatever
        ``client.send`` returns.
        """
        request.header = RPC_REQUEST_HEADER + b'\x00\x00\x00\x00' + \
            request.method + b'\x00'
        return self.client.send(target, request)

    def call(self, target, request):
        """Send *request* to *target* and return the pending ``RpcCallEvent``.

        The method never raises for send failures: they are stored in the
        returned event's ``error`` and surface from its ``wait_completed()``.
        The stored ``RpcException`` code is ``-32000 - <send result code>``
        for a non-OK send result, ``-32000 - ERR_TIMEOUT`` when waiting for
        the send result times out and ``-32000 - ERR_IO`` for any other
        exception.
        """
        with self.call_lock:
            call_id = self.call_id + 1
            # restart the counter once 0xffff_ffff has been handed out
            if call_id == 0xffff_ffff:
                self.call_id = 0
            else:
                self.call_id = call_id
        call_event = RpcCallEvent()
        # register before sending so that the reply can always be matched
        self.calls[call_id] = call_event
        request.header = RPC_REQUEST_HEADER + call_id.to_bytes(
            4, 'little') + request.method + b'\x00'
        try:
            code = self.client.send(
                target, request).wait_completed(timeout=self.client.timeout)
            if code != RESPONSE_OK:
                self._fail_call(call_id, call_event,
                                RpcException('RPC error', code=-32000 - code))
        except TimeoutError:
            self._fail_call(
                call_id, call_event,
                RpcException('RPC timeout', code=-32000 - ERR_TIMEOUT))
        except Exception as e:
            self._fail_call(call_id, call_event,
                            RpcException(str(e), code=-32000 - ERR_IO))
        return call_event

    def _fail_call(self, call_id, call_event, error):
        """Unregister a pending call and complete its event with *error*."""
        # the entry may already have been removed by the reply handler
        self.calls.pop(call_id, None)
        call_event.error = error
        call_event.completed.set()

    def __getattr__(self, method):
        """Turn an unknown attribute into a remote method proxy.

        Returns ``functools.partial(self.params_call, method=method)``.

        Dunder names are refused with ``AttributeError``: Python and the
        standard library probe objects for optional protocol hooks with
        ``getattr``/``hasattr``, and handing out an RPC proxy for such a
        probe would make the hook look implemented (and fire a remote call
        when it is invoked).
        """
        if method.startswith('__') and method.endswith('__'):
            raise AttributeError(method)
        return functools.partial(self.params_call, method=method)

    def params_call(self, target: str, method: str, **kwargs):
        """Call remote *method* on *target* with keyword parameters.

        The keyword arguments are serialized as the call parameters. If a
        ``_raw`` keyword is present, its value is serialized instead. Empty
        (falsy) parameters are replaced with ``None`` before serialization.

        A method name ending with ``'0'`` is sent via ``call0`` with the
        trailing ``'0'`` stripped; in that case ``None`` is returned
        immediately. Otherwise the call blocks until the reply arrives and
        returns the deserialized reply payload. Errors stored in the call
        event are raised by its ``wait_completed()``.
        """
        params = kwargs
        if '_raw' in kwargs:
            params = kwargs.pop('_raw')
        params = params or None
        c0 = False
        if method.endswith('0'):
            c0 = True
            method = method[:-1]
        request = Request(method, serialize(params))
        if c0:
            self.call0(target, request)
            return None
        result = self.call(target, request)
        return deserialize((result.wait_completed()).get_payload())

    def _handle_frame(self, frame):
        """Client ``on_frame`` callback: process *frame* in a new thread."""
        self.spawn(self._t_handler, frame)

    def spawn(self, f, *args, **kwargs):
        """Start a new thread running ``f(*args, **kwargs)``."""
        threading.Thread(target=f, args=args, kwargs=kwargs).start()

    def _t_handler(self, frame):
        """Decode one incoming frame and dispatch it.

        Layout read from ``frame.payload`` for ``OP_MESSAGE`` frames:
        byte 0 is the RPC frame kind; for requests, replies and errors
        bytes 1-4 are the little-endian call id. Requests continue with the
        NUL-terminated method name; errors carry a signed little-endian
        16-bit code in bytes 5-6.

        Any exception escaping the dispatch is logged, never propagated.
        """
        try:
            if frame.type == OP_MESSAGE:
                if frame.payload[0] == RPC_NOTIFICATION:
                    event = Event(RPC_NOTIFICATION, frame, 1)
                    self.on_notification(event)
                elif frame.payload[0] == RPC_REQUEST:
                    sender = frame.sender
                    call_id_b = frame.payload[1:5]
                    call_id = int.from_bytes(call_id_b, 'little')
                    method = frame.payload[5:5 +
                                           frame.payload[5:].index(b'\x00')]
                    # parameters start after: kind (1) + call id (4) +
                    # method name + NUL terminator (1)
                    event = Event(RPC_REQUEST, frame, 6 + len(method))
                    event.call_id = call_id
                    event.method = method
                    if call_id == 0:
                        # call id 0: no reply is sent back
                        self.on_call(event)
                    else:
                        reply = Reply()
                        try:
                            reply.payload = self.on_call(event)
                            if reply.payload is None:
                                reply.payload = b''
                            reply.header = RPC_REPLY_HEADER + call_id_b
                        except Exception as e:
                            # exceptions without an RPC code are reported
                            # as internal errors
                            code = getattr(e, 'rpc_error_code',
                                           RPC_ERROR_CODE_INTERNAL)
                            reply.header = (
                                RPC_ERROR_REPLY_HEADER + call_id_b +
                                code.to_bytes(2, 'little', signed=True))
                            reply.payload = format_rpc_e_msg(e)
                        self.client.send(sender, reply)
                elif frame.payload[0] == RPC_REPLY or frame.payload[
                        0] == RPC_ERROR:
                    call_id = int.from_bytes(frame.payload[1:5], 'little')
                    # only the lookup decides whether the response is
                    # orphaned; later failures go to the outer handler
                    call_event = self.calls.pop(call_id, None)
                    if call_event is None:
                        logger.warning(f'orphaned RPC response: {call_id}')
                    else:
                        call_event.frame = frame
                        if frame.payload[0] == RPC_ERROR:
                            err_code = int.from_bytes(frame.payload[5:7],
                                                      'little',
                                                      signed=True)
                            call_event.error = RpcException(frame.payload[7:],
                                                            code=err_code)
                        call_event.completed.set()
                else:
                    logger.error(f'Invalid RPC frame code: {frame.payload[0]}')
            else:
                self.on_frame(frame)
        except Exception:
            import traceback
            logger.error(traceback.format_exc())
|
|
|
|
|
|
class RpcCallEvent:
    """Handle for one outstanding RPC call.

    ``frame`` holds the reply frame once it has been assigned, ``error``
    holds an exception to be raised to the waiter, and ``completed`` is the
    ``threading.Event`` that is set when the call is finished.
    """

    def __init__(self):
        self.completed = threading.Event()
        self.error = None
        self.frame = None

    def is_completed(self):
        """Return True once the completion flag has been set."""
        return self.completed.is_set()

    def wait_completed(self, *args, **kwargs):
        """Block until the call is completed and return this event.

        Arguments are passed through to ``threading.Event.wait``.

        Raises ``TimeoutError`` if the wait ends before completion, or the
        stored ``error`` if one has been set.
        """
        finished = self.completed.wait(*args, **kwargs)
        if not finished:
            raise TimeoutError
        if self.error is not None:
            raise self.error
        return self

    def get_payload(self):
        """Return the reply frame payload without its first 5 bytes."""
        raw = self.frame.payload
        return raw[5:]

    def is_empty(self):
        """Return True when the reply payload has no more than 5 bytes."""
        return not len(self.frame.payload) > 5
|
|
|
|
|
|
class Event:
    """Incoming RPC event: its kind, the source frame and the payload offset.

    ``tp`` is the kind passed in by the creator, ``frame`` the frame the
    event was built from, and ``payload_pos`` the index in ``frame.payload``
    at which the event's own payload starts.
    """

    def __init__(self, tp, frame, payload_pos):
        self.tp, self.frame, self._payload_pos = tp, frame, payload_pos

    def get_payload(self):
        """Return ``frame.payload`` from the payload offset to the end."""
        start = self._payload_pos
        return self.frame.payload[start:]
|
|
|
|
|
|
class Notification:
    """Outgoing RPC notification frame.

    Carries *payload* as an ``OP_MESSAGE`` frame with ``qos`` 1 and the
    notification header.
    """

    def __init__(self, payload=b''):
        self.header = RPC_NOTIFICATION_HEADER
        self.qos = 1
        self.type = OP_MESSAGE
        self.payload = payload
|
|
|
|
|
|
class Request:
    """Outgoing RPC request frame.

    *method* may be ``str`` (it is encoded to bytes) or is stored as given
    otherwise. *params* is the already serialized payload; ``None`` is
    stored as empty bytes. ``header`` starts as ``None`` and is filled in
    by whoever sends the request.
    """

    def __init__(self, method, params=b''):
        if params is None:
            params = b''
        if isinstance(method, str):
            method = method.encode()
        self.payload = params
        self.type = OP_MESSAGE
        self.qos = 1
        self.method = method
        self.header = None
|
|
|
|
|
|
class Reply:
    """Outgoing RPC reply frame.

    Carries *result* as the payload of an ``OP_MESSAGE`` frame with ``qos``
    1. ``header`` starts as ``None`` and is set by the code that builds the
    reply.
    """

    def __init__(self, result=b''):
        self.header = None
        self.qos = 1
        self.type = OP_MESSAGE
        self.payload = result
|