import abc
import orjson as json
import logging
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto
from functools import wraps
from typing import (
Any,
Callable,
Dict,
Generic,
Optional,
Type,
TypeVar,
Union,
)
from amqpstorm import Message
from amqpworker.easyqueue.connection import AMQPConnection
from amqpworker.easyqueue.exceptions import UndecodableMessageException
from amqpworker.easyqueue.message import AMQPMessage
class DeliveryModes:
NON_PERSISTENT = 1
PERSISTENT = 2
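    # Usage sketch (an assumption, not from the original source): these values
    # mirror AMQP's `delivery_mode` basic property, so a persistent publish
    # could pass properties={"delivery_mode": DeliveryModes.PERSISTENT} to
    # JsonQueue.put() defined below.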
class ConnType(Enum):
CONSUME = auto()
WRITE = auto()
class BaseQueue(metaclass=abc.ABCMeta):
def __init__(
self,
host: str,
username: str,
password: str,
port: int = 5672,
virtual_host: str = "/",
heartbeat: int = 60,
) -> None:
self.host = host
self.port = port
self.username = username
self.password = password
self.virtual_host = virtual_host
self.heartbeat = heartbeat
@abc.abstractmethod
def serialize(self, body: Any, **kwargs) -> str:
raise NotImplementedError
@abc.abstractmethod
def deserialize(self, body: bytes) -> Any:
raise NotImplementedError
def _parse_message(self, content) -> Dict[str, Any]:
"""
        Takes the raw message body as input, handles deserialization and
        returns the deserialized content.
:param content: The raw message body
"""
try:
return self.deserialize(content)
except TypeError:
return self.deserialize(content.decode())
        except json.JSONDecodeError as e:
            raise UndecodableMessageException(
                '"{content}" can\'t be decoded as JSON'.format(content=content)
            ) from e
class BaseJsonQueue(BaseQueue):
content_type = "application/json"
def serialize(self, body: Any, **kwargs) -> str:
return json.dumps(body, **kwargs).decode('utf-8')
def deserialize(self, body: bytes) -> Any:
return json.loads(body)
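# Round-trip sketch for BaseJsonQueue (illustrative only; the broker credentials
# below are placeholders, and the constructor merely stores them without opening
# a connection). orjson.dumps() returns bytes, hence the .decode() in serialize():
#
#     q = BaseJsonQueue(host="localhost", username="guest", password="guest")
#     q.serialize({"id": 1})       # -> '{"id":1}'
#     q.deserialize(b'{"id":1}')   # -> {"id": 1}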
def _ensure_conn_is_ready(conn_type: ConnType):
def _ensure_connected(coro: Callable[..., Any]):
@wraps(coro)
def wrapper(self: "JsonQueue", *args, **kwargs):
conn = self.conn_types[conn_type]
retries = 0
while self.is_running and not conn.has_channel_ready():
try:
conn._connect()
break
except Exception as e:
time.sleep(self.seconds_between_conn_retry)
retries += 1
if self.connection_fail_callback:
self.connection_fail_callback(e, retries)
if self.logger:
self.logger.error(
{
"event": "reconnect-failure",
"retry_count": retries,
"exc_traceback": traceback.format_tb(
e.__traceback__
),
}
)
return coro(self, *args, **kwargs)
return wrapper
return _ensure_connected
T = TypeVar("T")
class _ConsumptionHandler:
def __init__(
self,
delegate: "QueueConsumerDelegate",
queue: "JsonQueue",
queue_name: str,
) -> None:
self.delegate = delegate
self.queue = queue
self.queue_name = queue_name
self.consumer_tag: Optional[str] = None
def _handle_callback(self, callback, **kwargs):
"""
        Wraps the callback in a try/except and calls
        `on_message_handle_error` in case of failure, avoiding unhandled
        exceptions.
        :param callback: The delegate callback to invoke
        :param kwargs: Keyword arguments forwarded to the callback
        :return: The callback's return value, or the error handler's result
"""
try:
return callback(**kwargs)
except Exception as e:
return self.delegate.on_message_handle_error(
handler_error=e, **kwargs
)
def handle_message(
self,
message: Message
):
msg = AMQPMessage(
connection=self.queue.connection,
channel=message.channel,
queue=self.queue,
properties=message.properties,
delivery_tag=message.delivery_tag,
deserialization_method=self.queue.deserialize,
queue_name=self.queue_name,
serialized_data=message.body,
)
callback = self._handle_callback(
self.delegate.on_queue_message, msg=msg # type: ignore
)
return callback
class JsonQueue(BaseQueue, Generic[T]):
def __init__(
self,
host: str,
username: str,
password: str,
port: int = 5672,
delegate_class: Optional[Type["QueueConsumerDelegate"]] = None,
delegate: Optional["QueueConsumerDelegate"] = None,
virtual_host: str = "/",
heartbeat: int = 60,
prefetch_count: int = 100,
seconds_between_conn_retry: int = 1,
logger: Optional[logging.Logger] = None,
connection_fail_callback: Optional[
Callable[[Exception, int], None]
] = None,
) -> None:
super().__init__(host, username, password, port, virtual_host, heartbeat)
if delegate is not None and delegate_class is not None:
            raise ValueError("Can't provide both delegate and delegate_class")
if delegate_class is not None:
self.delegate = delegate_class()
else:
self.delegate = delegate # type: ignore
self.prefetch_count = prefetch_count
on_error = self.delegate.on_connection_error if self.delegate else None
self.connection = AMQPConnection(
host=host,
port=port,
username=username,
password=password,
virtual_host=virtual_host,
heartbeat=heartbeat,
on_error=on_error,
)
self._write_connection = AMQPConnection(
host=host,
port=port,
username=username,
password=password,
virtual_host=virtual_host,
heartbeat=heartbeat,
on_error=on_error,
)
self.conn_types = {
ConnType.CONSUME: self.connection,
ConnType.WRITE: self._write_connection,
}
self.seconds_between_conn_retry = seconds_between_conn_retry
self.is_running = True
self.logger = logger
self.connection_fail_callback = connection_fail_callback
def serialize(self, body: T, **kwargs) -> str:
return json.dumps(body, **kwargs).decode('utf-8')
def deserialize(self, body: bytes) -> T:
return json.loads(body)
def conn_for(self, type: ConnType) -> AMQPConnection:
return self.conn_types[type]
@_ensure_conn_is_ready(ConnType.WRITE)
def put(
self,
routing_key: str,
data: Any = None,
serialized_data: Union[str, bytes] = "",
exchange: str = "",
properties: Optional[dict] = None,
mandatory: bool = False,
immediate: bool = False,
):
"""
        :param data: Serializable data that will be serialized before
            publishing
        :param serialized_data: A payload to be published as is
        :param exchange: The exchange to publish the message to
        :param routing_key: The routing key to publish the message with
        :param properties: Optional AMQP message properties, e.g. content_type
        :param mandatory: AMQP mandatory publish flag
        :param immediate: AMQP immediate publish flag
"""
if data and serialized_data:
            raise ValueError("Only one of data or serialized_data should be specified")
if data:
if isinstance(data, (str, bytes)):
serialized_data = data
else:
serialized_data = self.serialize(data)
if properties is None:
properties = {'content_type': 'application/json'}
        elif not properties.get('content_type'):
properties['content_type'] = 'application/json'
if not isinstance(serialized_data, bytes):
serialized_data = serialized_data.encode()
return self._write_connection.channel.basic.publish(
body=serialized_data,
exchange=exchange,
routing_key=routing_key,
properties=properties,
mandatory=mandatory,
immediate=immediate,
)
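    # Usage sketch for put() (illustrative only; connection parameters, routing
    # key and payload below are placeholders and require a reachable broker):
    #
    #     queue = JsonQueue("localhost", "guest", "guest")
    #     queue.put(routing_key="example", data={"id": 1})               # serialized here
    #     queue.put(routing_key="example", serialized_data=b'{"id":1}')  # published as is
    #
    # Passing both `data` and `serialized_data` raises ValueError.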
@_ensure_conn_is_ready(ConnType.CONSUME)
def consume(
self,
queue_name: str,
pool: ThreadPoolExecutor,
delegate: "QueueConsumerDelegate",
consumer_name: str = "",
) -> str:
"""
Connects the client if needed and starts queue consumption, sending
`on_before_start_consumption` and `on_consumption_start` notifications
to the delegate object
        :param queue_name: Queue name to consume from
        :param pool: Thread pool that runs the consumption loop
        :param delegate: Delegate object that receives consumption callbacks
:param consumer_name: An optional name to be used as a consumer
identifier. If one isn't provided, a random one is generated by the
broker
:return: The consumer tag. Useful for cancelling/stopping consumption
"""
# todo: Implement a consumer tag generator
handler = _ConsumptionHandler(
delegate=delegate, queue=self, queue_name=queue_name
)
delegate.on_before_start_consumption(
queue_name=queue_name, queue=self
)
self.connection.channel.basic.qos(
prefetch_count=self.prefetch_count,
prefetch_size=0,
global_=False,
)
consumer_tag = self.connection.channel.basic.consume(
callback=handler.handle_message,
consumer_tag=consumer_name,
queue=queue_name,
)
delegate.on_consumption_start(
consumer_tag=consumer_tag, queue=self
)
def start():
try:
self.connection.channel.start_consuming(auto_decode=False)
            except Exception:
traceback.print_exc()
pool.submit(start)
handler.consumer_tag = consumer_tag
return consumer_tag
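    # Usage sketch for consume() (illustrative only; the queue name and delegate
    # are placeholders): the consumption loop runs on the supplied
    # ThreadPoolExecutor, so the call returns the consumer tag immediately while
    # messages are handled in the background:
    #
    #     pool = ThreadPoolExecutor(max_workers=1)
    #     tag = queue.consume("example", pool, delegate=MyDelegate())
    #
    # where MyDelegate is a QueueConsumerDelegate subclass such as the one
    # sketched at the end of this module.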
class QueueConsumerDelegate(metaclass=abc.ABCMeta):
def on_before_start_consumption(
self, queue_name: str, queue: JsonQueue
):
"""
        Callback called before queue consumption starts. May be overridden to
        implement custom initialization.
:param queue_name: Queue name that will be consumed
:type queue_name: str
        :param queue: The JsonQueue instance being consumed
:type queue: JsonQueue
"""
pass
def on_consumption_start(self, consumer_tag: str, queue: JsonQueue):
"""
        Callback called once consumption has started.
"""
@abc.abstractmethod
def on_queue_message(self, msg: AMQPMessage[Any]):
"""
        Callback called every time a new, valid and deserialized message
is ready to be handled.
:param msg: the consumed message
"""
raise NotImplementedError
def on_message_handle_error(self, handler_error: Exception, **kwargs):
"""
        Callback called when an uncaught exception was raised during the
        message handling stage.
        :param handler_error: The exception raised by the message handler
        :param kwargs: Arguments used to call the callback that handled
            the message
:return:
"""
pass
def on_connection_error(self, exception: Exception):
"""
Called when the connection fails
"""
pass
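# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (illustrative only; broker address, credentials and
# queue name are placeholder assumptions, and the exact attributes exposed by
# AMQPMessage are not relied upon, so the handler simply prints the wrapper).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _PrintingDelegate(QueueConsumerDelegate):
        def on_queue_message(self, msg: AMQPMessage[Any]):
            # Only on_queue_message is abstract; the other callbacks keep their
            # default no-op implementations.
            print("received message wrapper:", msg)

    example_queue: JsonQueue = JsonQueue("localhost", "guest", "guest")
    example_queue.put(routing_key="example", data={"hello": "world"})
    example_pool = ThreadPoolExecutor(max_workers=1)
    example_queue.consume("example", example_pool, delegate=_PrintingDelegate())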