You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
278 lines
8.2 KiB
Python
278 lines
8.2 KiB
Python
1 year ago
|
import abc
|
||
|
import asyncio
|
||
|
import dataclasses
|
||
|
import functools
|
||
|
import signal
|
||
|
from collections import UserList
|
||
|
from enum import Enum, auto
|
||
|
from typing import Callable, Coroutine, Dict, List, Optional
|
||
|
|
||
|
from loguru import logger
|
||
|
|
||
|
from busrtworker.boostrap import RpcBoot
|
||
|
from busrtworker.busrt import OP_PUBLISH, Client, Frame, Rpc, serialize
|
||
|
from busrtworker.scheduler import ScheduledTaskRunner
|
||
|
from busrtworker.tree import RadixTree
|
||
|
|
||
|
|
||
|
@dataclasses.dataclass
|
||
|
class ConnectionInfo:
|
||
|
name: str = dataclasses.field()
|
||
|
uri: str = dataclasses.field()
|
||
|
client_prefix: str = dataclasses.field()
|
||
|
static: bool = dataclasses.field()
|
||
|
topic: str | None = dataclasses.field(default=None)
|
||
|
init_rpc: bool = dataclasses.field(default=True)
|
||
|
bus: Client = dataclasses.field(default=None, init=False)
|
||
|
rpc: Rpc = dataclasses.field(default=None, init=False)
|
||
|
final_name: str = dataclasses.field(default=None, init=False)
|
||
|
|
||
|
def __getattr__(self, item):
|
||
|
if not self.init_rpc:
|
||
|
raise ValueError('must be init rpc client could call')
|
||
|
return getattr(self.rpc, item)
|
||
|
|
||
|
async def send(self, topic, data=None, decode=True):
|
||
|
bus: Client = self.bus
|
||
|
await bus.send(topic, Frame(serialize(data) if decode else data, tp=OP_PUBLISH))
|
||
|
|
||
|
|
||
|
class Router:
    """Topic router combining exact-match and dynamic routes.

    Static topics live in a plain dict (``table``); topics inserted with
    ``dynamic=True`` go into a ``RadixTree`` under the ``'RPC'`` key.
    """

    def __init__(self):
        # Per-instance storage. These used to be class attributes, which made
        # every Router (one is created per connection) share a single table and
        # tree: handlers leaked across connections and identical topics on two
        # connections raised a spurious "conflict route" error.
        self.table: dict = {}
        self.tree: 'RadixTree' = RadixTree()

    def insert(self, path, handler, dynamic=False):
        """Register ``handler`` for ``path``.

        Args:
            path: topic string.
            handler: value stored for the route (returned by ``get``).
            dynamic: store in the radix tree instead of the exact-match table.

        Raises:
            ValueError: if a static ``path`` is already registered.
        """
        if not dynamic:
            if path in self.table:
                raise ValueError(f'conflict route {path}')
            self.table[path] = handler
        else:
            self.tree.insert(path, handler, ['RPC'])

    def get(self, path):
        """Look up ``path``, trying exact-match routes first.

        Returns:
            ``(True, handler, {})`` for a static match; otherwise whatever
            ``RadixTree.get(path, 'RPC')`` returns (presumably the same
            ``(found, handler, params)`` shape — confirm in busrtworker.tree).
        """
        if path in self.table:
            return True, self.table[path], {}
        return self.tree.get(path, 'RPC')
|
||
|
|
||
|
|
||
|
class ServiceEntrypoint:
    """Decorator factory registering handlers for one named connection.

    RPC handlers are stored in ``app.callers[name]`` and topic subscriptions
    in the ``Router`` at ``app.subscribes[name]``.
    """

    def __init__(self, connection: 'ConnectionInfo', app: 'App'):
        self.name = connection.name
        self.app = app
        # Create the per-connection registries only once, so several
        # entrypoints for the same connection name share them.
        if self.name not in self.app.callers:
            self.app.callers[self.name] = {}
        if self.name not in self.app.subscribes:
            self.app.subscribes[self.name] = Router()

    def on_call(self, method=None, auto_decode=True, raw=False):
        """Return a decorator registering a function as an RPC method handler.

        Args:
            method: RPC method name; defaults to the function's ``__name__``
                (the wrapped function's name for a ``functools.partial``).
            auto_decode: flag stored alongside the handler.
            raw: flag stored alongside the handler.

        The registry entry is ``(f, auto_decode, is_coroutine_function, raw)``;
        the decorated function is returned unchanged.
        """
        def _wrap(f):
            target = method or (f.func.__name__ if isinstance(f, functools.partial) else f.__name__)
            self.app.callers[self.name][target] = (f, auto_decode, asyncio.iscoroutinefunction(f), raw)
            return f

        return _wrap

    def subscribe(self, topic, auto_decode=True, raw=False):
        """Return a decorator registering a function for messages on ``topic``.

        Topics containing ``'/:'`` are inserted as dynamic (radix-tree) routes;
        all others are exact-match routes.

        Raises:
            TypeError: if ``topic`` is not a ``str``.
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(topic, str):
            raise TypeError(f'topic must be str, got {type(topic).__name__}')
        dynamic = '/:' in topic

        def _wrap(f):
            self.app.subscribes[self.name].insert(
                topic, (f, auto_decode, asyncio.iscoroutinefunction(f), raw), dynamic
            )
            return f

        return _wrap
|
||
|
|
||
|
|
||
|
class Freezable(metaclass=abc.ABCMeta):
    """Base class carrying a one-way "frozen" flag.

    Objects start unfrozen; awaiting ``freeze()`` sets the flag, and nothing
    here ever clears it. Read the state through the ``frozen`` property.
    """

    def __init__(self):
        self._frozen: bool = False

    @property
    def frozen(self) -> bool:
        """True once ``freeze()`` has been awaited."""
        return self._frozen

    async def freeze(self):
        """Set the frozen flag."""
        self._frozen = True
|
||
|
|
||
|
|
||
|
class Signal(UserList, asyncio.Event):
    """
    Coroutine-based signal implementation that behaves as an `asyncio.Event`.

    To connect a callback to a signal, use any list method.

    Signals are fired using the send() coroutine, which forwards its
    positional and keyword arguments to every receiver. A signal fires at
    most once; after that it is frozen, its owner is frozen, and the event
    is set.
    """

    def __init__(self, owner: 'Freezable') -> None:
        UserList.__init__(self)
        asyncio.Event.__init__(self)
        self._owner = owner
        self.frozen = False

    def __repr__(self):
        return "<Signal owner={}, frozen={}, {!r}>".format(
            self._owner, self.frozen, list(self)
        )

    async def send(self, *args, **kwargs):
        """
        Sends data to all registered receivers, awaiting them in list order.

        Raises:
            RuntimeError: if the signal was already sent or is being sent.
            Any exception raised by a receiver propagates; remaining receivers
            are skipped, but the signal is still frozen and the event set.
        """
        if self.frozen:
            raise RuntimeError("Cannot send on frozen signal.")

        # Freeze before awaiting any receiver so that a concurrent second
        # send() (e.g. a repeated SIGINT while shutdown is still running) is
        # rejected instead of running every receiver a second time.
        self.frozen = True
        try:
            for receiver in self:
                await receiver(*args, **kwargs)
        finally:
            # Always release waiters, even if a receiver failed; otherwise
            # anything blocked on wait() would hang forever.
            await self._owner.freeze()
            self.set()
|
||
|
|
||
|
|
||
|
async def on_frame_default(app, frame):
    """Fallback frame handler: debug-log the frame's type, sender, topic and payload.

    ``app`` is not used here.
    """
    # With lazy=True loguru only calls the lambda (building the text) when the
    # DEBUG record is actually emitted. Every field is interpolated; previously
    # sender/topic/payload were emitted as literal text.
    logger.opt(lazy=True).debug(
        '{x}',
        x=lambda: (
            f"default print 'Frame:', {hex(frame.type)}, "
            f"{frame.sender}, {frame.topic}, {frame.payload}"
        ),
    )
|
||
|
|
||
|
|
||
|
async def on_call_default(app, event):
    """Fallback RPC handler: debug-log the sender, method and payload of ``event``.

    ``app`` is not used here.
    """
    def _describe():
        return f"default print 'Rpc:', {event.frame.sender}, {event.method}, {event.get_payload()}"

    # Passed as a callable: with lazy=True loguru only renders it when the
    # DEBUG record is actually emitted.
    logger.opt(lazy=True).debug('{x}', x=_describe)
|
||
|
|
||
|
|
||
|
def entrypoint(f):
    """Turn a coroutine function into a blocking call that runs it to completion.

    The coroutine runs on the thread's current event loop. If there is none
    (``asyncio.get_event_loop`` raises ``RuntimeError``) or it is closed, a new
    loop is created and installed as current. ``KeyboardInterrupt`` is
    swallowed, in which case the call returns ``None``.
    """
    def _usable_loop():
        # Return the current event loop, or install a fresh one when none is
        # set (always the case on Python 3.14+ before a loop is set, and after
        # set_event_loop(None) / asyncio.run() on earlier versions) or the
        # current one has been closed.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop

    @functools.wraps(f)
    def _(*args, **kwargs):
        loop = _usable_loop()
        try:
            return loop.run_until_complete(f(*args, **kwargs))
        except KeyboardInterrupt:
            pass

    return _
|
||
|
|
||
|
|
||
|
class AutoNameEnum(str, Enum):
    """String enum whose ``auto()`` values are the lower-cased member names.

    >>> class Color(AutoNameEnum):
    ...     DARK_RED = auto()
    >>> Color.DARK_RED.value
    'dark_red'
    """

    def _generate_next_value_(  # type: ignore
        name: str, start: int, count: int, last_values: List[str]
    ) -> str:
        # Only the member name matters; start/count/last_values are ignored.
        return str.lower(name)
|
||
|
|
||
|
|
||
|
class Options(AutoNameEnum):
    """Recognised keys for the ``options`` mapping accepted by ``App.run_every``."""

    # Spelled out explicitly; identical to what auto() generates on AutoNameEnum.
    MAX_CONCURRENCY = 'max_concurrency'
|
||
|
|
||
|
|
||
|
class DefaultValues:
    """Fallback values used when a caller does not supply an option."""

    # Used by App.run_every when its options lack Options.MAX_CONCURRENCY.
    RUN_EVERY_MAX_CONCURRENCY: int = 1
|
||
|
|
||
|
|
||
|
class App(Freezable):
    """Application object: connections, handler registries, lifecycle, periodic tasks.

    Typical flow: ``registry()`` each ``ConnectionInfo`` to get a
    ``ServiceEntrypoint`` for registering handlers, then call ``run()``.

    NOTE(review): ``callers``, ``subscribes``, ``connections``, ``closeable``
    and ``task_runners`` are mutable *class* attributes, so every ``App``
    instance in the process shares them. Left as-is because code outside this
    file may rely on it; confirm before moving them into ``__init__``.
    """

    class _MissingConnection(AttributeError, KeyError):
        """Raised by ``App.__getattr__`` for a name that is neither an attribute nor a connection.

        It is an ``AttributeError`` so ``hasattr`` / ``getattr(app, name, default)``
        and protocol probes behave normally, and a ``KeyError`` so code that
        expects a failed connection lookup to raise ``KeyError`` still works.
        """

    # connection name -> {method name: (handler, auto_decode, is_coroutine, raw)}
    callers: Dict[str, Dict[str, tuple]] = {}
    # connection name -> Router holding that connection's topic subscriptions
    subscribes: Dict[str, Router] = {}
    # connection name -> ConnectionInfo (filled by registry())
    connections: Dict[str, ConnectionInfo] = {}
    # not read or written in this class; presumably used by RpcBoot — TODO confirm
    closeable: list = []
    # Module-level functions stored on the class: read through an instance they
    # become bound methods, so the app is passed as their first argument.
    on_frame_default: Callable = on_frame_default
    on_call_default: Callable = on_call_default
    # ScheduledTaskRunner objects created by run_every()
    task_runners: list = []

    def __init__(self):
        Freezable.__init__(self)
        self.boot = RpcBoot()
        self._on_startup: Signal = Signal(self)
        self._on_shutdown: Signal = Signal(self)
        # Registered first, so the boot hooks run before anything added later
        # via run_on_startup / run_on_shutdown / run_every.
        self._on_startup.append(self.boot.startup)
        self._on_shutdown.append(self.boot.shutdown)
        # signal.signal() only works in the main thread (ValueError otherwise),
        # so the App must be constructed there.
        signal.signal(signal.SIGINT, self.shutdown)
        signal.signal(signal.SIGTERM, self.shutdown)

    def registry(self, connection: ConnectionInfo):
        """Register ``connection`` and return a ``ServiceEntrypoint`` for it.

        A connection with the same name replaces the previous one.

        Raises:
            RuntimeError: if the application has already been started.
        """
        if self.frozen:
            raise RuntimeError(
                "You shouldn't change the state of a started application"
            )
        self.connections[connection.name] = connection
        return ServiceEntrypoint(connection, self)

    def set_on_frame_default(self, on_frame: Callable):
        """Replace the fallback frame handler for this instance.

        NOTE(review): the replacement is stored on the instance, so unlike the
        class-level default it is not bound and will not receive ``app``
        implicitly. Confirm against how RpcBoot invokes ``on_frame_default``.
        """
        self.on_frame_default = on_frame

    def set_on_call_default(self, on_call: Callable):
        """Replace the fallback RPC handler for this instance.

        NOTE(review): same binding caveat as ``set_on_frame_default``.
        """
        self.on_call_default = on_call

    @entrypoint
    async def run(self):
        """Start the application and block until the shutdown signal has fired.

        Wrapped by ``entrypoint``, so ``app.run()`` is a blocking call that
        drives the event loop itself.

        Raises:
            RuntimeError: if the application has already been started.
        """
        if self.frozen:
            raise RuntimeError(
                "You shouldn't change the state of a started application"
            )
        logger.debug({"event": "Booting App..."})
        await self.startup()

        await self._on_shutdown.wait()

    async def startup(self):
        """
        Fires the on_startup signal.

        Each registered startup coroutine is awaited in registration order
        with this app as its argument. Must be awaited inside the running
        event loop.
        """
        await self._on_startup.send(self)

    def shutdown(self, *args) -> asyncio.Future:
        """
        Schedules the on_shutdown signal.

        Is called automatically when the application receives a SIGINT or
        SIGTERM; the extra positional arguments (signal number and frame
        passed by the ``signal`` module) are ignored.

        Returns:
            The task dispatching the shutdown signal.
        """
        logger.debug('do shutdown')
        return asyncio.ensure_future(self._on_shutdown.send(self))

    def run_on_startup(self, coro: Callable[["App"], Coroutine]) -> None:
        """
        Registers a coroutine to be awaited for during app startup
        """
        self._on_startup.append(coro)

    def run_on_shutdown(self, coro: Callable[["App"], Coroutine]) -> None:
        """
        Registers a coroutine to be awaited for during app shutdown
        """
        self._on_shutdown.append(coro)

    def __getattr__(self, name):
        """Expose registered connections as attributes (``app.<connection name>``).

        Only called when normal attribute lookup fails.

        Raises:
            AttributeError: (also a ``KeyError``) if no connection has that name.
        """
        try:
            return self.connections[name]
        except KeyError:
            raise self._MissingConnection(
                f"{type(self).__name__!r} object has no attribute or connection {name!r}"
            ) from None

    def run_every(self, seconds: int, options: Optional[Dict] = None):
        """
        Registers a coroutine to be called with a given interval.

        Args:
            seconds: interval passed to ``ScheduledTaskRunner``.
            options: optional mapping; ``Options.MAX_CONCURRENCY`` sets the
                runner's ``max_concurrency`` (default
                ``DefaultValues.RUN_EVERY_MAX_CONCURRENCY``).

        Returns:
            A decorator that registers the coroutine function and returns it
            unchanged. The runner is started on startup and stopped on shutdown.
        """
        if options is None:
            options = {}

        max_concurrency = options.get(
            Options.MAX_CONCURRENCY, DefaultValues.RUN_EVERY_MAX_CONCURRENCY
        )

        def wrapper(task: Callable[..., Coroutine]):
            runner = ScheduledTaskRunner(
                seconds=seconds,
                task=task,
                max_concurrency=max_concurrency,
            )
            self._on_startup.append(runner.start)
            self._on_shutdown.append(runner.stop)
            self.task_runners.append(runner)

            return task

        return wrapper

    def rpc_running(self, name):
        """Return ``rpc.is_connected()`` for connection ``name``.

        Returns ``False`` if no such connection is registered or it has no
        ``rpc`` client yet.
        """
        if connection := self.connections.get(name, None):
            if rpc := connection.rpc:
                return rpc.is_connected()
        return False
|