Source code for pyplugin.group

from __future__ import annotations
import functools
import typing
from collections.abc import MutableSequence

from pyplugin.base import Plugin, _R, get_registered_plugin, lookup_plugin, get_aliases
from pyplugin.utils import void_args, empty
from pyplugin.exceptions import (
    PluginNotFoundError,
    PluginLoadError,
    PluginUnloadError,
    DependencyError,
)


class PluginGroup(Plugin[list[_R]], MutableSequence[typing.Union[Plugin[_R], str]]):
    """
    Groups plugins together under the following guarantees:

    - If this group is loaded, then every plugin in this group is loaded, and the
      instance of this group is the list of the plugins' instances, in order.
    - Loading this group attempts to load every plugin in it; consequently,
      changing any one of these plugins will also reload this group.
    - Unloading this group unloads every plugin in it (in reverse order).

    Note:
        Plugins in this group may still be loaded individually and separately.

    The load_ and unload_callables of a PluginGroup have a different form than
    those of a normal plugin: they are written in ``contextlib.contextmanager``
    style, with a single ``yield`` statement.

    The load callable is passed the list of plugins followed by the load args and
    load kwargs. It may yield back a ``(plugins, args, kwargs)`` triple, which is
    used to determine the load order and the load arguments. A falsy member falls
    back to the original value. A yielded value that is not such a triple is taken
    to be the plugin list. Code after the ``yield`` runs once every plugin has been
    loaded.

    The unload callable is passed the list of plugins and this group's instance
    (the list of plugin instances). It may yield back a
    ``(plugins, instance, args, kwargs)`` 4-tuple. A falsy member falls back to the
    original value, and any other yielded value is taken to be the plugin list.
    Code after the ``yield`` runs once every plugin has been unloaded.

    Attributes:
        plugins (list[Plugin | str]): The list of plugins in this group

    Arguments:
        plugins (Iterable[Plugin | str]): The plugins to initialize this group with
    """

    def __init__(
        self,
        plugin: typing.Callable = void_args,
        unload_callable: typing.Callable = void_args,
        plugins: typing.Optional[typing.Iterable[typing.Union[Plugin[_R], str]]] = None,
        **kwargs,
    ):
        # Populated before Plugin.__init__ runs.
        # NOTE(review): presumably because the base initializer may already inspect
        # the group's contents; confirm.
        # Initial plugins bypass ``_add``, so no requirement is registered for them here.
        self.plugins: list[typing.Union[Plugin[_R], str]] = list(plugins) if plugins else []
        super().__init__(
            plugin,
            unload_callable=unload_callable,
            **kwargs,
        )
        # Wrap the user callables so they are driven as contextmanager-style hooks
        # around loading/unloading each member plugin.
        self._load_callable = functools.partial(self._group_load, self._load_callable)
        self._unload_callable = functools.partial(self._group_unload, self._unload_callable)

    def _handle_enforce_type(self, instance, type_=None, is_class_type=None):
        """
        Enforce the type on every element of ``instance`` (the list of member instances).

        ``type_`` defaults to this group's type. ``is_class_type`` defaults to this
        group's setting only when it is not given, so an explicit ``False`` is honoured.
        """
        type_ = type_ if type_ else self.type
        if is_class_type is None:
            is_class_type = self.is_class_type
        if self.enforce_type and type_:
            for plugin in instance:
                super()._handle_enforce_type(plugin, type_=type_, is_class_type=is_class_type)

    def _load_dependencies(self, kwargs):
        """
        Load the dependencies that are neither members of this group nor supplied in ``kwargs``.

        Returns:
            dict: Maps each dependency's destination name to its loaded instance.
        """
        ret = {}
        for dest, plugin in self.dependencies.copy().items():
            if plugin in self or dest in kwargs:
                continue
            ret[dest] = plugin.load(conflict_strategy="keep_existing")
        return ret

    @staticmethod
    def _run_until_yield(gen):
        """Advance a contextmanager-style hook to its ``yield``; return the yielded (or returned) value."""
        if not gen:
            return None
        try:
            return next(gen)
        except StopIteration as err:
            # The hook finished without yielding; use its return value instead.
            return err.value

    @staticmethod
    def _resume_after_yield(gen):
        """Run the remainder of a contextmanager-style hook after its ``yield``."""
        if not gen:
            return
        try:
            next(gen)
        except StopIteration:
            pass

    def _resolve_member(self, plugin, error_type):
        """
        Return ``plugin`` as a Plugin object, looking it up by name when it is a string.

        Raises:
            error_type: If the name cannot be found (wrapping PluginNotFoundError).
        """
        if isinstance(plugin, Plugin):
            return plugin
        try:
            return lookup_plugin(plugin, import_lookup=self._settings["import_lookup"])
        except PluginNotFoundError as err:
            raise error_type(f"{self.get_full_name()}: Could not find plugin in group {plugin}") from err

    def _group_load(self, load_callable, *args, **kwargs) -> list[_R]:
        """
        Load every member plugin, driving ``load_callable`` as a hook around the loads.

        Returns:
            list: The member instances, in load order.

        Raises:
            PluginLoadError: If a member given by name cannot be found.
        """
        plugins = self.plugins
        gen = load_callable(self.plugins, *args, **kwargs)
        ret = self._run_until_yield(gen)
        if ret:
            # Anything other than a (plugins, args, kwargs) triple is taken as the plugin list.
            if not (isinstance(ret, typing.Sequence) and len(ret) == 3 and isinstance(ret[0], typing.Iterable)):
                ret = ret, args, kwargs
            plugins_, args_, kwargs_ = ret
            plugins = plugins_ if plugins_ else plugins
            args = args_ if args_ else args
            kwargs = kwargs_ if kwargs_ else kwargs
        # Build a fresh dict so a kwargs mapping yielded by the hook is never mutated;
        # explicitly supplied values override these defaults.
        kwargs = {"safe_args": True, "conflict_strategy": "keep_existing", **kwargs}
        instances = []
        for plugin in plugins:
            plugin = self._resolve_member(plugin, PluginLoadError)
            instance = plugin.load(*args, **kwargs)
            if not self.type and self.infer_type:
                self._set_type_from_instance(instance)
            instances.append(instance)
        self._resume_after_yield(gen)
        return instances

    def _group_unload(self, unload_callable, instance, *args, **kwargs) -> list[typing.Any]:
        """
        Unload every member plugin in reverse order, driving ``unload_callable`` as a hook.

        Returns:
            list: The results of each member's ``_unload`` call, in unload (reverse) order.

        Raises:
            PluginUnloadError: If a member given by name cannot be found.
        """
        plugins = self.plugins
        # NOTE(review): unlike the load hook, args/kwargs are not forwarded to the
        # unload hook; confirm whether they should be.
        gen = unload_callable(self.plugins, instance)
        ret = self._run_until_yield(gen)
        if ret:
            # Expect a (plugins, instance, args, kwargs) 4-tuple; anything else is
            # taken as the plugin list. (This previously checked for length 2 and then
            # unpacked four values, which raised ValueError.)
            if not (isinstance(ret, typing.Sequence) and len(ret) == 4 and isinstance(ret[0], typing.Iterable)):
                ret = ret, instance, args, kwargs
            # The yielded instance is not used by the unload loop below.
            plugins_, _instance, args_, kwargs_ = ret
            plugins = plugins_ if plugins_ else plugins
            args = args_ if args_ else args
            kwargs = kwargs_ if kwargs_ else kwargs
        results = []
        for plugin in reversed(plugins):
            plugin = self._resolve_member(plugin, PluginUnloadError)
            results.append(plugin._unload(*args, _unload_dependents=False, **kwargs))
        self._resume_after_yield(gen)
        return results

    def _set_type(self, plugin: PluginGroup = None) -> typing.Optional[typing.Type]:
        """
        Delegate to Plugin._set_type when ``plugin`` contains at least one Plugin object.

        ``plugin`` defaults to this group; an empty group is falsy and so also falls
        back to this group. Returns None when no member is a resolved Plugin object.
        """
        if not plugin:
            plugin = self
        for plugin_ in plugin:
            if isinstance(plugin_, Plugin):
                # NOTE(review): passes the whole group rather than ``plugin_``; confirm intended.
                return super()._set_type(plugin=plugin)
        return None

    def _add(self, value: typing.Union[Plugin[_R], str]):
        """Type-check ``value`` against this group (if enforced) and register it as a requirement."""
        # Names (str) carry no ``.type`` until resolved, so only Plugin objects are checked.
        if self.enforce_type and self.type and isinstance(value, Plugin) and value.type:
            super()._handle_enforce_type(value.type, type_=self.type, is_class_type=True)
        self.add_requirement(value)

    def _infer_type_from(self, value: typing.Union[Plugin[_R], str]):
        """Set this group's type from ``value`` if no type is set yet and inference is enabled."""
        if not self.type and self.infer_type and isinstance(value, Plugin):
            super()._set_type(plugin=value)

    # Alias: ``group.add(p)`` appends ``p`` (via ``insert``) to the end of the group.
    add = MutableSequence.append

    def safe_add(self, value: typing.Union[Plugin[_R], str]):
        """
        Adds the given plugin to this group, unloading first before adding and then
        reloading if it was loaded.

        Arguments:
            value (Plugin | str): The plugin to add
        """
        is_loaded = self.is_loaded()
        if is_loaded:
            self.unload()
        self.append(value)
        if is_loaded:
            self.load()

    def __getitem__(self, index: int) -> Plugin[_R]:
        return self.plugins[index]

    def __setitem__(self, index: int, value: typing.Union[Plugin[_R], str]):
        self._add(value)
        self.plugins[index] = value
        self._infer_type_from(value)

    def __delitem__(self, index: int):
        # ``del`` (unlike ``list.pop``) also accepts slices; same IndexError for ints.
        del self.plugins[index]

    def __len__(self) -> int:
        return len(self.plugins)

    def insert(self, index: int, value: typing.Union[Plugin[_R], str]):
        """Insert ``value`` before ``index``, registering it as a requirement of this group."""
        self._add(value)
        self.plugins.insert(index, value)
        self._infer_type_from(value)

    def __contains__(self, plugin: typing.Union[Plugin, str]) -> bool:
        """
        Return True if ``plugin`` is a member, directly, by one of its aliases, or
        (for a name) via the registered plugin of that name.
        """
        # super(MutableSequence, ...) skips to Sequence.__contains__ further down the
        # MRO, which scans the members by identity/equality.
        if super(MutableSequence, self).__contains__(plugin):
            return True
        if isinstance(plugin, Plugin) and any(
            super(MutableSequence, self).__contains__(alias) for alias in get_aliases(plugin)
        ):
            return True
        if not isinstance(plugin, str):
            return False
        try:
            plugin = get_registered_plugin(plugin)
        except PluginNotFoundError:
            return False
        else:
            return super(MutableSequence, self).__contains__(plugin)