You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

98 lines
3.2 KiB
Python

from types import LambdaType
from typing import Any, Callable, Dict, List, Type, Union
from .errors.service_error import ServiceError
# Sentinel for "not registered"; distinct from a service whose value is None.
_MISSING_SERVICE = object()


class Container:
    """A small dependency-injection container.

    Keys are strings or types. A key can be bound to:

    * a plain value via ``container[key] = value``, returned as-is;
    * an anonymous ``lambda`` taking the container, called on first lookup
      with its result memoized (named functions are stored and returned as
      values, not called);
    * a factory registered in :attr:`factories`, called on every lookup.

    :meth:`add_alias` maps a name to one or more target keys:
    ``container[name]`` resolves the first target, and
    ``container[List[name]]`` resolves all of them in the order they were added.
    """

    def __init__(self):
        # Cached results: evaluated lambdas and resolved List[...] alias lookups.
        self._memoized_services: Dict[Union[str, Type], Any] = {}
        # Registrations made through __setitem__ (plain values or lazy lambdas).
        self._services: Dict[Union[str, Type], Any] = {}
        # Providers invoked on every lookup; their results are never cached.
        self._factories: Dict[Union[str, Type], Callable[["Container"], Any]] = {}
        # Alias name -> target keys, in insertion order.
        self._aliases: Dict[Union[str, Type], List[Union[str, Type]]] = {}

    def __setitem__(self, key: Union[str, Type], value: Any) -> None:
        """Register ``value`` under ``key`` and drop cached results that depend on it.

        A ``lambda`` value is evaluated lazily, on first lookup.
        """
        self._services[key] = value
        self._memoized_services.pop(key, None)
        # A memoized List[name] holds the previous value of each of its targets,
        # so every alias list that includes ``key`` must be rebuilt.
        for name, targets in self._aliases.items():
            if key in targets:
                self._forget_alias_lists(name)

    def add_alias(self, name: Union[str, Type], target: Union[str, Type]) -> None:
        """Make ``target`` one of the keys that ``name`` resolves to.

        The first target added is what ``container[name]`` returns; every
        target appears in ``container[List[name]]``. Adding the same target
        twice lists it twice.
        """
        self._aliases.setdefault(name, []).append(target)
        # Alias lists are cached under List[name] (the alias), not List[target].
        self._forget_alias_lists(name)

    def __getitem__(self, key: Union[str, Type]) -> Any:
        """Resolve ``key``.

        Lookup order: a factory for ``key``; a value/lambda registered (or
        memoized) for ``key``; the first target of alias ``key``; all targets
        of alias ``X`` when ``key`` is ``List[X]``.

        Raises:
            ServiceError: if nothing resolves ``key``, or, for ``List[X]``,
                one of X's targets is not registered.
        """
        if key in self._factories:
            return self._factories[key](self)

        service = self._get(key)
        if service is not _MISSING_SERVICE:
            return service

        if key in self._aliases:
            # A plain lookup of an alias resolves its first target.
            service = self._resolve_target(self._aliases[key][0])
            if service is not _MISSING_SERVICE:
                return service

        if self._has_alias_list_for(key):
            return self._resolve_alias_list(key)

        raise ServiceError(f"Service {key} is not registered.")

    def _resolve_target(self, target: Union[str, Type]) -> Any:
        """Resolve one alias target via its factory or registration; else _MISSING_SERVICE."""
        if target in self._factories:
            return self._factories[target](self)
        return self._get(target)

    def _resolve_alias_list(self, key: Any) -> List[Any]:
        """Build (and, when safe, memoize) the services behind ``List[X]``."""
        targets = self._aliases[self._list_item(key)]
        result = []
        for target in targets:
            service = self._resolve_target(target)
            if service is _MISSING_SERVICE:
                # Fail loudly instead of handing the private sentinel to callers.
                raise ServiceError(f"Service {target} is not registered.")
            result.append(service)
        # Factory products must be fresh on every lookup, so a list containing
        # one is never cached.
        if not any(target in self._factories for target in targets):
            self._memoized_services[key] = result
        return result

    def _get(self, key: Union[str, Type]) -> Any:
        """Return the memoized or registered service for ``key``, or ``_MISSING_SERVICE``.

        A registered ``lambda`` is called with the container on first access
        and its result cached; any other value (including named functions) is
        returned unchanged.
        """
        if key in self._memoized_services:
            return self._memoized_services[key]
        if key not in self._services:
            return _MISSING_SERVICE
        value = self._services[key]
        if isinstance(value, LambdaType) and value.__name__ == "<lambda>":
            self._memoized_services[key] = value(self)
            return self._memoized_services[key]
        return value

    def __contains__(self, key) -> bool:
        """Return True if ``key`` is registered, a factory, an alias, or ``List[alias]``."""
        if key in self._services or key in self._factories or key in self._aliases:
            return True
        return self._has_alias_list_for(key)

    @staticmethod
    def _list_item(key: Any) -> Any:
        """Return ``X`` when ``key`` is ``List[X]``/``list[X]``; otherwise ``_MISSING_SERVICE``."""
        if getattr(key, "__origin__", None) is not list:
            return _MISSING_SERVICE
        args = getattr(key, "__args__", ())
        return args[0] if args else _MISSING_SERVICE

    def _has_alias_list_for(self, key: Union[str, Type]) -> bool:
        """Return True if ``key`` is ``List[X]`` and ``X`` is a registered alias name."""
        item = self._list_item(key)
        return item is not _MISSING_SERVICE and item in self._aliases

    def _forget_alias_lists(self, name: Union[str, Type]) -> None:
        """Drop memoized ``List[name]`` (or ``list[name]``) results so they are rebuilt."""
        stale = [
            cached_key
            for cached_key in self._memoized_services
            if self._list_item(cached_key) == name
        ]
        for cached_key in stale:
            del self._memoized_services[cached_key]

    @property
    def factories(self) -> Dict[Union[str, Type], Callable[["Container"], Any]]:
        """The live factory registry; entries are called with the container on every lookup."""
        return self._factories

    def clear_cache(self) -> None:
        """Discard all memoized results; lambdas and alias lists are rebuilt on next lookup."""
        self._memoized_services = {}
# Public names exported by ``from <module> import *``.
__all__ = ["Container", "di"]

# Module-level default container instance, created once at import time.
di: Container = Container()