You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
246 lines
7.5 KiB
Python
246 lines
7.5 KiB
Python
import asyncio
|
|
import sys
|
|
from abc import ABC
|
|
from functools import wraps
|
|
from inspect import Parameter as InspectParameter, isclass, signature
|
|
from typing import Any, Callable, Dict, ForwardRef, NewType, Tuple, Type, TypeVar, Union # type: ignore
|
|
|
|
from typing_extensions import Protocol
|
|
|
|
from .container import Container, di
|
|
from .errors import ExecutionError
|
|
|
|
T = TypeVar("T")
|
|
S = TypeVar("S")
|
|
|
|
ServiceDefinition = Union[Type[S], Callable]
|
|
ServiceResult = Union[S, Callable]
|
|
|
|
Undefined = NewType("Undefined", int)
|
|
|
|
|
|
class _ProtocolInit(Protocol):
|
|
pass
|
|
|
|
|
|
_no_init = _ProtocolInit.__init__
|
|
|
|
|
|
def _resolve_forward_reference(module: Any, ref: Union[str, ForwardRef]) -> Any:
|
|
if isinstance(ref, str):
|
|
name = ref
|
|
else:
|
|
name = ref.__forward_arg__
|
|
|
|
if name in sys.modules[module].__dict__:
|
|
return sys.modules[module].__dict__[name]
|
|
|
|
return None
|
|
|
|
|
|
class Parameter:
|
|
type: Any
|
|
name: str
|
|
default: Any
|
|
|
|
def __init__(self, name: str, type: Any = Any, default: Any = Undefined):
|
|
self.name = name
|
|
self.type = type
|
|
self.default = default
|
|
|
|
|
|
def _inspect_function_arguments(
|
|
function: Callable,
|
|
) -> Tuple[Tuple[str, ...], Dict[str, Parameter]]:
|
|
parameters_name: Tuple[str, ...] = tuple(signature(function).parameters.keys())
|
|
parameters = {}
|
|
|
|
for name, parameter in signature(function).parameters.items():
|
|
|
|
if isinstance(parameter.annotation, (str, ForwardRef)) and hasattr(function, "__module__"):
|
|
annotation = _resolve_forward_reference(function.__module__, parameter.annotation)
|
|
else:
|
|
annotation = parameter.annotation
|
|
|
|
parameters[name] = Parameter(
|
|
parameter.name,
|
|
annotation,
|
|
parameter.default if parameter.default is not InspectParameter.empty else Undefined,
|
|
)
|
|
|
|
return parameters_name, parameters
|
|
|
|
|
|
def _resolve_function_kwargs(
|
|
alias_map: Dict[str, str],
|
|
parameters_name: Tuple[str, ...],
|
|
parameters: Dict[str, Parameter],
|
|
container: Container,
|
|
) -> Dict[str, Any]:
|
|
resolved_kwargs = {}
|
|
for name in parameters_name:
|
|
if name in alias_map and alias_map[name] in container:
|
|
resolved_kwargs[name] = container[alias_map[name]]
|
|
continue
|
|
|
|
if name in container:
|
|
resolved_kwargs[name] = container[name]
|
|
continue
|
|
|
|
if parameters[name].type in container:
|
|
resolved_kwargs[name] = container[parameters[name].type]
|
|
continue
|
|
|
|
if parameters[name].default is not Undefined:
|
|
resolved_kwargs[name] = parameters[name].default
|
|
|
|
return resolved_kwargs
|
|
|
|
|
|
def _decorate(binding: Dict[str, Any], service: ServiceDefinition, container: Container) -> ServiceResult:
|
|
# ignore abstract class initialiser and protocol initialisers
|
|
if (
|
|
service in [ABC.__init__, _no_init] or service.__name__ == "_no_init"
|
|
): # FIXME: fix this when typing_extensions library gets fixed
|
|
return service
|
|
|
|
# Add class definition to dependency injection
|
|
parameters_name, parameters = _inspect_function_arguments(service)
|
|
|
|
def _resolve_kwargs(args, kwargs) -> dict:
|
|
# attach named arguments
|
|
passed_kwargs = {**kwargs}
|
|
|
|
# resolve positional arguments
|
|
if args:
|
|
for key, value in enumerate(args):
|
|
passed_kwargs[parameters_name[key]] = value
|
|
|
|
# prioritise passed kwargs and args resolving
|
|
if len(passed_kwargs) == len(parameters_name):
|
|
return passed_kwargs
|
|
|
|
resolved_kwargs = _resolve_function_kwargs(binding, parameters_name, parameters, container)
|
|
|
|
all_kwargs = {**resolved_kwargs, **passed_kwargs}
|
|
|
|
if len(all_kwargs) < len(parameters_name):
|
|
missing_parameters = [arg for arg in parameters_name if arg not in all_kwargs]
|
|
raise ExecutionError(
|
|
"Cannot execute function without required parameters. "
|
|
+ f"Did you forget to bind the following parameters: `{'`, `'.join(missing_parameters)}`?"
|
|
)
|
|
|
|
return all_kwargs
|
|
|
|
@wraps(service)
|
|
def _decorated(*args, **kwargs):
|
|
# all arguments were passed
|
|
if len(args) == len(parameters_name):
|
|
return service(*args, **kwargs)
|
|
|
|
if parameters_name == tuple(kwargs.keys()):
|
|
return service(**kwargs)
|
|
|
|
all_kwargs = _resolve_kwargs(args, kwargs)
|
|
return service(**all_kwargs)
|
|
|
|
@wraps(service)
|
|
async def _async_decorated(*args, **kwargs):
|
|
# all arguments were passed
|
|
if len(args) == len(parameters_name):
|
|
return await service(*args)
|
|
|
|
if parameters_name == tuple(kwargs.keys()):
|
|
return await service(**kwargs)
|
|
|
|
all_kwargs = _resolve_kwargs(args, kwargs)
|
|
return await service(**all_kwargs)
|
|
|
|
if asyncio.iscoroutinefunction(service):
|
|
return _async_decorated
|
|
|
|
return _decorated
|
|
|
|
|
|
def inject(
|
|
_service: ServiceDefinition = None,
|
|
alias: Any = None,
|
|
bind: Dict[str, Any] = None,
|
|
container: Container = di,
|
|
use_factory: bool = False,
|
|
) -> Union[ServiceResult, Callable[[ServiceDefinition], ServiceResult]]:
|
|
def _decorator(_service: ServiceDefinition) -> ServiceResult:
|
|
if isclass(_service):
|
|
setattr(
|
|
_service,
|
|
"__init__",
|
|
_decorate(bind or {}, getattr(_service, "__init__"), container),
|
|
)
|
|
if use_factory:
|
|
container.factories[_service] = lambda _di: _service()
|
|
if alias:
|
|
container.add_alias(alias, _service)
|
|
else:
|
|
container[_service] = lambda _di: _service()
|
|
if alias:
|
|
container.add_alias(alias, _service)
|
|
|
|
return _service
|
|
|
|
service_function = _decorate(bind or {}, _service, container)
|
|
container[service_function.__name__] = service_function
|
|
if alias:
|
|
container.add_alias(alias, service_function.__name__)
|
|
|
|
return service_function
|
|
|
|
if _service is None:
|
|
return _decorator
|
|
|
|
return _decorator(_service)
|
|
|
|
|
|
def provider(
|
|
_service: ServiceDefinition = None,
|
|
alias: Any = None,
|
|
bind: Dict[str, Any] = None,
|
|
container: Container = di,
|
|
process: bool = False,
|
|
) -> Union[ServiceResult, Callable[[ServiceDefinition], ServiceResult]]:
|
|
def _decorator(_service: ServiceDefinition) -> ServiceResult:
|
|
if isclass(_service):
|
|
setattr(
|
|
_service,
|
|
"__init__",
|
|
_decorate(bind or {}, getattr(_service, "__init__"), container),
|
|
)
|
|
if process:
|
|
container.factories[_service] = lambda _di: _service()
|
|
if alias:
|
|
container.add_alias(alias, _service)
|
|
else:
|
|
container[_service] = lambda _di: _service()
|
|
if alias:
|
|
container.add_alias(alias, _service)
|
|
|
|
return _service
|
|
|
|
service_function = _decorate(bind or {}, _service, container)
|
|
if process:
|
|
container.factories[service_function.__name__] = lambda x: service_function()
|
|
else:
|
|
container[service_function.__name__] = lambda x: service_function()
|
|
if alias:
|
|
container.add_alias(alias, service_function.__name__)
|
|
|
|
return service_function
|
|
|
|
if _service is None:
|
|
return _decorator
|
|
|
|
return _decorator(_service)
|
|
|
|
|
|
__all__ = ["inject", "provider"]
|