You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
96 lines
3.9 KiB
Python
96 lines
3.9 KiB
Python
import asyncio
|
|
import functools
|
|
import os
|
|
import traceback
|
|
from typing import TYPE_CHECKING
|
|
|
|
from .busrt import Client, Rpc, RpcException, deserialize, on_frame_default, serialize
|
|
|
|
if TYPE_CHECKING: # pragma: no cover
|
|
from .app import App # noqa: F401
|
|
|
|
|
|
class RpcBoot:
|
|
async def startup(self, app: 'App'):
|
|
for i, (name, connection) in enumerate(app.connections.items()):
|
|
uri = connection.uri
|
|
static = connection.static
|
|
init_rpc = connection.init_rpc
|
|
prefix = connection.client_prefix
|
|
topic = connection.topic
|
|
connection.final_name = f'{prefix}.{name}' if static else f'{prefix}.{os.urandom(4).hex()}'
|
|
bus = Client(uri, connection.final_name)
|
|
connection.bus = bus
|
|
|
|
app.closeable.append(bus.disconnect)
|
|
await bus.connect()
|
|
if topic:
|
|
await bus.subscribe(topic)
|
|
|
|
async def on_frame(frame, router):
|
|
success, route, params = router.get(frame.topic)
|
|
if not success:
|
|
return await app.on_frame_default(frame)
|
|
else:
|
|
func, auto_decode, is_async, raw = route
|
|
if auto_decode:
|
|
payload = deserialize(frame.payload)
|
|
if raw:
|
|
payload['frame'] = frame
|
|
if is_async:
|
|
return await func(**payload, **params)
|
|
return func(**payload, **params)
|
|
else:
|
|
if is_async:
|
|
return await func(frame, **params)
|
|
return func(frame, **params)
|
|
|
|
on_frame_func = functools.partial(on_frame, router=app.subscribes.get(name, None))
|
|
else:
|
|
on_frame_func = on_frame_default
|
|
if init_rpc:
|
|
|
|
async def on_call(event, caller):
|
|
try:
|
|
method = event.method.decode('utf-8')
|
|
caller_tuple = caller.get(method, None)
|
|
if not caller_tuple:
|
|
return serialize(await app.on_call_default(event))
|
|
call, auto_decode, is_async, raw = caller_tuple
|
|
if auto_decode:
|
|
payload = deserialize(event.get_payload())
|
|
if raw:
|
|
payload['event'] = event
|
|
if is_async:
|
|
return serialize(await call(**payload))
|
|
return serialize(call(**payload))
|
|
else:
|
|
if is_async:
|
|
return serialize(await call(event))
|
|
return serialize(call(event))
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
raise RpcException(str(e), 11)
|
|
|
|
rpc_client = Rpc(bus)
|
|
connection.rpc = rpc_client
|
|
if topic:
|
|
rpc_client.on_frame = functools.partial(on_frame_func, router=app.subscribes.get(name, None))
|
|
rpc_client.on_call = functools.partial(on_call, caller=app.callers.get(name, None))
|
|
elif topic:
|
|
bus.on_frame = functools.partial(on_frame_func, router=app.subscribes.get(name, None))
|
|
|
|
async def shutdown(self, app: 'App'):
|
|
gathers = []
|
|
|
|
for func in app.closeable:
|
|
if asyncio.iscoroutinefunction(func):
|
|
gathers.append(func())
|
|
else:
|
|
try:
|
|
func()
|
|
except:
|
|
traceback.print_exc()
|
|
if gathers:
|
|
await asyncio.gather(*gathers, return_exceptions=True)
|