You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
136 lines
3.7 KiB
Python
136 lines
3.7 KiB
Python
import abc
|
|
from collections import UserDict
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Coroutine,
|
|
Dict,
|
|
Iterable,
|
|
List,
|
|
Optional,
|
|
Type,
|
|
Union,
|
|
)
|
|
|
|
from cached_property import cached_property
|
|
from pydantic import BaseModel, Extra, root_validator, validator
|
|
|
|
from amqpworker import conf
|
|
from amqpworker.connections import AMQPConnection, Connection
|
|
from amqpworker.options import Actions, DefaultValues, RouteTypes
|
|
|
|
# Type of a route handler: any callable (arguments unconstrained, hence `...`)
# whose call returns a coroutine. Used as the key type of RoutesRegistry.
RouteHandler = Callable[..., Coroutine]
|
|
|
|
|
|
class Model(BaseModel, abc.ABC):
    """
    An abstract pydantic BaseModel that also behaves like a Mapping
    """

    # Mapping behaviour is layered on top of attribute access:
    # ``keys()`` + ``__getitem__`` let the model be consumed where a mapping
    # is expected (e.g. ``dict(model)``), and ``__len__`` / ``__contains__``
    # agree with ``keys()`` (the declared pydantic fields).

    def __getitem__(self, item):
        """Return attribute ``item``.

        Raises:
            KeyError: carrying ``item`` when the attribute does not exist.
        """
        try:
            return getattr(self, item)
        except AttributeError as e:
            # Name the missing key, as dict does, so the error is diagnosable.
            raise KeyError(item) from e

    def __setitem__(self, key, value):
        """Assign ``value`` to ``key`` through the model's ``__setattr__``.

        Raises:
            KeyError: carrying ``key`` when the assignment raises
                AttributeError or ValueError.
        """
        try:
            self.__setattr__(key, value)
        except (AttributeError, ValueError) as e:
            raise KeyError(key) from e

    def __contains__(self, key):
        """Return True if ``key`` is one of the model's declared fields.

        Without this, ``in`` would fall back to BaseModel iteration, which
        does not yield plain field names.
        """
        return key in self.__fields__

    def __eq__(self, other):
        """Compare against a plain dict via ``self.dict()``; otherwise defer
        to BaseModel equality."""
        if isinstance(other, dict):
            return self.dict() == other
        return super().__eq__(other)

    def __len__(self):
        """Number of declared fields."""
        return len(self.__fields__)

    def keys(self):
        """Names of the declared fields."""
        return self.__fields__.keys()

    def get(self, key, default=None):
        """Return ``self[key]``, or ``default`` when it raises KeyError."""
        try:
            return self[key]
        except KeyError:
            return default
|
|
|
|
|
|
class _RouteOptions(Model):
    # Common base for the per-route-type option models (e.g. the AMQP one).
    # It declares no fields itself. A comment is used instead of a docstring
    # so the class __doc__ stays None, exactly as before.
    ...
|
|
|
|
|
|
class Route(Model, abc.ABC):
    """
    An abstract Model that acts like a route factory
    """

    # Fields shared by every concrete route type.
    type: RouteTypes
    handler: Any
    routes: List[str]
    connection: Optional[Connection]
    options: _RouteOptions = _RouteOptions()

    @staticmethod
    def factory(data: Dict) -> "Route":
        """Build the concrete Route subclass selected by ``data["type"]``.

        ``data`` is left unmodified. Every key except ``"type"`` is passed
        as a keyword argument to the subclass constructor.

        Raises:
            ValueError: if ``data`` has no ``"type"`` key, or the type is not
                supported (only ``RouteTypes.AMQP_RABBITMQ`` is).
        """
        # Shallow copy so popping "type" does not mutate the caller's mapping.
        kwargs = dict(data)
        try:
            type_ = kwargs.pop("type")
        except KeyError as e:
            raise ValueError("Routes must have a type") from e

        if type_ == RouteTypes.AMQP_RABBITMQ:
            return AMQPRoute(**kwargs)
        # Every other value, RouteTypes.HTTP included, is rejected.
        raise ValueError(f"'{type_}' is an invalid RouteType.")
|
|
|
|
|
|
class AMQPRouteOptions(_RouteOptions):
    # Option model for AMQP routes (the type of AMQPRoute.options).
    # Field order is significant for pydantic and is kept as declared.

    # Batch sizing / flushing and worker count; defaults come from DefaultValues.
    bulk_size: int = DefaultValues.BULK_SIZE
    max_workers: int = DefaultValues.MAX_SUBMIT_WORKER_SIZE
    bulk_flush_interval: int = DefaultValues.BULK_FLUSH_INTERVAL

    # Actions selected for the success / exception outcomes.
    on_success: Actions = DefaultValues.ON_SUCCESS
    on_exception: Actions = DefaultValues.ON_EXCEPTION

    # Optional coroutine function called as callback(exception, int).
    # NOTE(review): the meaning of the int argument is not visible here.
    connection_fail_callback: Optional[Callable[[Exception, int], Coroutine]] = None

    # Either a connection object or a str; None when omitted.
    connection: Union[AMQPConnection, str, None]

    class Config:
        # Reject unknown option keys.
        extra = Extra.forbid
        arbitrary_types_allowed = False
|
|
|
|
|
|
class AMQPRoute(Route):
    # Concrete Route built by Route.factory for RouteTypes.AMQP_RABBITMQ.
    type: RouteTypes = RouteTypes.AMQP_RABBITMQ
    # NOTE(review): this default is read from settings once, when the class
    # body executes at import time; later settings changes are not seen.
    vhost: str = conf.settings.AMQP_DEFAULT_VHOST
    # Narrows Route.connection to an AMQP connection (or None).
    connection: Union[AMQPConnection, None]
    # Required here: unlike Route, no default options object is provided.
    options: AMQPRouteOptions
|
|
|
|
|
|
class RoutesRegistry(UserDict):
    """Registry mapping route handlers to their Route objects.

    Assigning a plain mapping as a value converts it with ``Route.factory``,
    with the key injected as ``"handler"``. Derived views such as
    ``amqp_routes`` are cached and invalidated whenever an entry is set or
    deleted.
    """

    def _invalidate_cache(self) -> None:
        # cached_property keeps the computed value in the instance __dict__
        # under the property's name; dropping it forces recomputation on the
        # next access, so routes added or removed later are not missed.
        self.__dict__.pop("amqp_routes", None)

    def _get_routes_for_type(self, route_type: Type) -> Iterable:
        """Return a tuple of the registered routes that are ``route_type``
        instances."""
        return tuple(r for r in self.values() if isinstance(r, route_type))

    @cached_property
    def amqp_routes(self) -> Iterable[AMQPRoute]:
        """All registered AMQPRoute instances (cached until the next change)."""
        return self._get_routes_for_type(AMQPRoute)

    def __setitem__(self, key: RouteHandler, value: Union[Dict, Route]):
        """Register ``value`` under handler ``key``.

        A non-Route ``value`` is treated as a mapping of route fields and
        built with ``Route.factory``. Keys present in ``value`` override the
        injected ``"handler"``.
        """
        if isinstance(value, Route):
            route = value
        else:
            route = Route.factory({"handler": key, **value})

        super().__setitem__(key, route)
        self._invalidate_cache()

    def __delitem__(self, key: RouteHandler) -> None:
        """Remove the route for ``key`` and invalidate cached views."""
        super().__delitem__(key)
        self._invalidate_cache()

    def add_route(self, route: Route) -> None:
        """Register ``route`` under its own handler."""
        self[route.handler] = route

    def add_amqp_route(self, route: AMQPRoute) -> None:
        """Register an AMQP ``route`` under its own handler."""
        self[route.handler] = route

    def route_for(self, handler: RouteHandler) -> Route:
        """Return the route registered for ``handler`` (KeyError if none)."""
        return self[handler]
|