# Source code for components.impl.component.base

"""Implementation of component base classes.

To easily integrate with disnake-ext-components, it is recommended to inherit
from any of these base classes. In any case, it is very much recommended to at
least use the `ComponentMeta` metaclass. Without this, a lot of internal
functionality will have to be manually re-implemented.
"""

from __future__ import annotations

import sys
import typing

import attr
import disnake
import typing_extensions
from disnake.ext.components import fields as fields
from disnake.ext.components.api import component as component_api
from disnake.ext.components.api import factory as factory_api
from disnake.ext.components.impl import custom_id as custom_id_impl
from disnake.ext.components.impl import factory as factory_impl

# Only the component base class is part of this module's public API.
__all__: typing.Sequence[str] = ("ComponentBase",)


_T = typing.TypeVar("_T")

# Either a plain `_T`, or a coroutine that resolves to a `_T`.
MaybeCoroutine = typing.Union[_T, typing.Coroutine[None, None, _T]]


# Runtime type of a not-yet-processed `attr.field()` definition. It is obtained
# by instantiating a field and taking its type (presumably because attrs does
# not export this class publicly). `_apply_overrides` uses it to detect field
# specifiers that were placed directly in a class namespace.
_CountingAttr: type[typing.Any] = type(attr.field())
# Alias for an attrs attribute of any type, written as a string so that it is
# only evaluated lazily by type checkers.
_AnyAttr: typing_extensions.TypeAlias = "attr.Attribute[typing.Any]"


def _extract_custom_id(interaction: disnake.Interaction) -> str:
    """Return the custom id carried by a modal or message-component interaction.

    Modal interactions expose the custom id directly, while message
    interactions expose it through the component that was interacted with.
    Any other kind of interaction raises a :class:`TypeError`.
    """
    if not isinstance(
        interaction, (disnake.ModalInteraction, disnake.MessageInteraction)
    ):
        msg = "The provided interaction object does not have a custom id."
        raise TypeError(msg)

    if isinstance(interaction, disnake.ModalInteraction):
        return interaction.custom_id

    # A message interaction always stems from a component with a custom id,
    # so the Optional reported by disnake's typing can be cast away.
    return typing.cast(str, interaction.component.custom_id)


def _is_attrs_pass(namespace: dict[str, typing.Any]) -> bool:
    """Tell whether ``namespace`` belongs to a class attrs has already processed.

    The raw class namespace is inspected instead of calling `attr.has`.
    `attr.has` also reports `True` for any class that merely inherits from an
    attrs class. That would make the two passes inside
    `ComponentMeta.__new__` indistinguishable: the user's own class
    definition, and the class attrs builds from it.
    """
    attrs_attributes = namespace.get("__attrs_attrs__")
    if attrs_attributes is None:
        return False
    return True


def _is_protocol(cls: type[typing.Any]) -> bool:
    """Return the ``_is_protocol`` marker :class:`typing.Protocol` sets on *cls*.

    Classes that never derived from :class:`typing.Protocol` do not have the
    attribute at all, and are reported as non-protocols.
    """
    try:
        return cls._is_protocol  # pyright: ignore[reportPrivateUsage]
    except AttributeError:
        return False


def _finalise_custom_id(component: type[ComponentBase]) -> None:
    """Turn a string, auto id, or custom id into a fully-fledged custom id.

    The result is stored back onto ``component.custom_id``:

    - an ``AutoID`` is expanded into a concrete ``CustomID`` for the component;
    - a ``CustomID`` is kept as-is, after validating it against the component;
    - a plain string is treated as a static custom id. The string itself is
      used as the name and no fields are encoded. An empty string falls back
      to the component's class name.

    Raises
    ------
    TypeError
        If the custom id is not a string or a derivative thereof.
    """
    custom_id = component.custom_id

    if isinstance(custom_id, custom_id_impl.AutoID):
        # Make concrete custom id from provided auto-id...
        component.custom_id = custom_id_impl.CustomID.from_auto_id(component, custom_id)

    elif isinstance(custom_id, custom_id_impl.CustomID):
        # User-created custom id; ensure validity...
        custom_id.validate(component)

    elif isinstance(custom_id, str):  # pyright: ignore[reportUnnecessaryIsInstance]
        # Assume static custom id-- only a name without fields.
        # The user-provided string *is* that name. Previously it was silently
        # discarded in favour of the class name; only fall back to the class
        # name when no usable name was given.
        # TODO: is this a good/"valuable" assumption?
        component.custom_id = custom_id_impl.CustomID(
            name=custom_id or component.__name__
        )

    else:
        msg = (
            "A component's custom id must be of type 'str' or any derivative"
            f" thereof, got {type(custom_id).__name__!r}."
        )
        raise TypeError(msg)


def _apply_overrides(
    cls: type[ComponentBase],
    namespace: dict[str, typing.Any],
) -> None:
    """Turn malformed overrides into valid attrs fields.

    A subclass may redefine an inherited internal or modal field in two ways:
    with a bare value (``label = "foo"``), or with a new field specifier
    (``label = fields.internal("foo")``). For every such field found in
    ``namespace``, the class attribute is replaced with a fresh
    ``attr.field``. That field keeps the inherited field's ``init``,
    ``metadata`` and ``on_setattr`` settings and only takes over the new
    default.

    ``cls`` is mutated in place: both its attributes and its
    ``__annotations__`` change. This is meant to run before the class is
    handed to ``attr.define``.

    Raises
    ------
    TypeError
        If an override given as a field specifier changes the field type, for
        example an internal field redefined as a custom id field.
    """
    if not attr.has(cls):  # Nothing to override.
        return

    # We only check pre-defined internal fields, such as label.
    for field in fields.get_fields(
        cls,
        kind=fields.FieldType.INTERNAL | fields.FieldType.MODAL,
    ):
        name = field.name
        if name not in namespace:
            continue

        # The raw override as written in the class body: either a plain value
        # or an unprocessed attrs field (`_CountingAttr`).
        new = namespace[name]

        # Ensure the new field isn't just magically an init-field now.
        if isinstance(new, _CountingAttr):
            # Emulate turning this into an Attribute so that the following checks work.
            # This may be slightly slow but it's only run once during class creation,
            # so it should be fine.
            new = typing.cast(
                "attr.Attribute[typing.Any]",
                attr.Attribute.from_counting_attr(name, new),  # pyright: ignore
            )
            new_field_type = fields.get_field_type(new)
            old_field_type = fields.get_field_type(field)

            # Ensure the field type remains unchanged.
            if new_field_type is not old_field_type:
                new_type_name = (new_field_type.name or "unknown").lower()
                old_type_name = (old_field_type.name or "unknown").lower()

                msg = (
                    f"Field '{cls.__name__}.{name}' is defined as a(n) {old_type_name} "
                    f"field, but was redefined as a(n) {new_type_name} field."
                )
                raise TypeError(msg)

            # Carry over the default value instead of the entire attribute.
            # Any other settings on the override (validators, converters, ...)
            # are intentionally dropped in favour of the inherited field's.
            new = new.default

        # Rebuild the field from the inherited definition, with the new default.
        new_field = attr.field(
            default=new,  # Update the default.
            init=field.init,
            metadata=field.metadata,
            on_setattr=field.on_setattr,
        )

        # Update the field information.
        setattr(cls, name, new_field)

        # Reapply the annotation, otherwise attrs breaks.
        # NOTE(review): on Python < 3.10, `cls.__annotations__` on a class
        #       without annotations of its own resolves to a base class's
        #       dict, so this may write into the parent's annotations.
        #       Confirm the minimum supported Python version.
        cls.__annotations__.setdefault(name, field.type)


def _build_field_transformer_with_parsers(
    factory_builder: factory_impl.ComponentFactoryBuilder,
) -> typing.Callable[[type, list[_AnyAttr]], list[_AnyAttr]]:
    """Create an attrs ``field_transformer`` that registers field parsers.

    The returned transformer feeds the component's attributes into
    ``factory_builder``. Any parser the builder picks is stored back into the
    attribute's metadata. Once attrs is done creating the class, the caller
    builds the full factory from the same ``factory_builder``.

    Protocol classes use the parser-less ``_field_transformer`` instead: they
    cannot be instantiated, so there is no reason to build parsers for them.
    """

    def _evolve_metadata(
        attribute: _AnyAttr, updates: dict[typing.Any, typing.Any]
    ) -> _AnyAttr:
        # attrs exposes metadata as a read-only mapping proxy, so merge it into
        # a fresh dict and let `evolve` produce a modified copy of the attribute.
        return attribute.evolve(metadata={**attribute.metadata, **updates})

    def _field_transformer(cls: type, attributes: list[_AnyAttr]) -> list[_AnyAttr]:
        # Ensure every attribute has a field type, fill in missing parsers, and
        # register each parser-backed attribute with the factory builder.
        transformed: list[_AnyAttr] = []

        for attribute in attributes:
            metadata = attribute.metadata

            if fields.FieldMetadata.FIELDTYPE not in metadata:
                # No field definition whatsoever: treat it as a custom id field
                # and let the builder supply its parser.
                parser = factory_builder.add_field(attribute)
                transformed.append(
                    _evolve_metadata(
                        attribute,
                        {
                            fields.FieldMetadata.FIELDTYPE: fields.FieldType.CUSTOM_ID,
                            fields.FieldMetadata.PARSER: parser,
                        },
                    )
                )
                continue

            if fields.FieldMetadata.PARSER not in metadata:
                # Typed field without a parser slot; nothing to register.
                transformed.append(attribute)
                continue

            if fields.get_parser(attribute):
                # An explicit parser was provided: register the field as-is.
                factory_builder.add_field(attribute)
                transformed.append(attribute)
                continue

            # Parser slot present but left empty (the default): let the builder
            # choose a parser and record it on the attribute.
            parser = factory_builder.add_field(attribute)
            transformed.append(
                _evolve_metadata(attribute, {fields.FieldMetadata.PARSER: parser})
            )

        return transformed

    return _field_transformer


def _field_transformer(_: type, attributes: list[_AnyAttr]) -> list[_AnyAttr]:
    """Tag untyped attrs attributes as parser-less custom id fields.

    This transformer is used for protocol classes, which are never
    instantiated, so no parsers are built here. Attributes that already
    declare a field type are passed through unchanged. All others get their
    metadata extended with the custom id field type and an explicit ``None``
    parser.
    """

    def _tag(attribute: _AnyAttr) -> _AnyAttr:
        if fields.FieldMetadata.FIELDTYPE in attribute.metadata:
            return attribute

        # attrs exposes metadata as a read-only mapping proxy, so build a fresh
        # dict and let `evolve` produce a modified copy of the attribute.
        return attribute.evolve(
            metadata={
                **attribute.metadata,
                fields.FieldMetadata.FIELDTYPE: fields.FieldType.CUSTOM_ID,
                fields.FieldMetadata.PARSER: None,
            },
        )

    return [_tag(attribute) for attribute in attributes]


@typing_extensions.dataclass_transform(
    kw_only_default=True, field_specifiers=(fields.field, fields.internal)
)
class ComponentMeta(typing._ProtocolMeta):  # pyright: ignore[reportPrivateUsage]
    """Metaclass for all disnake-ext-components component types.

    It is **highly** recommended to use this metaclass for any class that
    should interface with the components api exposed by
    disnake-ext-components.

    This metaclass handles :mod:`attr` class generation, custom id completion,
    interfacing with component managers, parser and factory generation, and
    automatic slotting.
    """

    custom_id: custom_id_impl.CustomID

    # HACK: Pyright doesn't like this but it does seem to work with typechecking
    #       down the line. I might change this later (e.g. define it on
    #       BaseComponent instead, but that comes with its own challenges).
    factory: factory_api.ComponentFactory[typing_extensions.Self]  # pyright: ignore
    _parent: typing.Optional[type[typing.Any]]
    __module_id__: int

    def __new__(  # noqa: D102
        mcls,  # pyright: ignore[reportSelfClsParameterName]
        name: str,
        bases: tuple[type, ...],
        namespace: dict[str, typing.Any],
    ) -> ComponentMeta:
        # NOTE: This is run twice for each new class; once for the actual class
        #       definition, and once more by attr.define(). We ensure we only
        #       run the full class creation logic once.

        # Set slots if attrs hasn't already done so.
        namespace.setdefault("__slots__", ())

        cls = typing.cast(
            "type[ComponentBase]",
            super().__new__(mcls, name, bases, namespace),
        )

        # If this is attrs' pass, return immediately after it has worked its magic.
        if _is_attrs_pass(namespace):
            return cls

        # Record the identity of the defining module, so that `is_active` can
        # later check whether that exact module object is still the one in
        # sys.modules. Only the id is stored, so the component itself does
        # not keep the module alive. Once the module is unloaded or replaced,
        # the component is considered inactive and will (hopefully) soon be
        # GC'ed.
        cls.__module_id__ = id(sys.modules[cls.__module__])

        # Before we pass the class off to attrs, check if any fields were
        # overwritten. If so, check them for validity and update them to proper
        # attrs fields. This adds support for redefining internal fields as
        # `label = "foo"` instead of `label = fields.internal("foo")`
        _apply_overrides(cls, namespace)

        if _is_protocol(cls):
            # Protocols are never instantiated, so no parsers are needed.
            cls = attr.define(
                cls,
                slots=True,
                kw_only=True,
                field_transformer=_field_transformer,
            )

            cls.factory = factory_impl.NoopFactory.from_component(cls)
            return cls

        builder = factory_impl.ComponentFactoryBuilder()
        cls = attr.define(
            cls,
            slots=True,
            kw_only=True,
            field_transformer=_build_field_transformer_with_parsers(builder),
        )

        cls.factory = builder.build(cls)  # pyright: ignore

        # Finalise the custom id *before* subscribing. A component whose custom
        # id is invalid raises here and is then never left registered with a
        # manager. The manager also only ever sees the component's final
        # custom id.
        _finalise_custom_id(cls)

        # Subscribe the new component to its manager if it inherited one.
        # Compare against None explicitly: the attribute is Optional, and
        # truthiness could be overridden by the manager type (e.g. `__len__`).
        if cls.manager is not None:
            cls.manager.subscribe(cls)

        return cls

    # NOTE: This is relevant because classes are removed by gc instead of
    #       reference-counting. This means that, even though a module has been
    #       unloaded or a class has been `del`'d, it will still stick around
    #       until gc picks it up. Since we do not want to activate components
    #       that have gone out-of-scope in this sense, we need to explicitly
    #       account for this.
    @property
    def is_active(self) -> bool:
        """Determine whether this component is currently in an active module."""
        return (
            self.__module__ in sys.modules
            and self.__module_id__ == id(sys.modules[self.__module__])
        )  # fmt: skip


@typing.runtime_checkable
class ComponentBase(
    component_api.RichComponent, typing.Protocol, metaclass=ComponentMeta
):
    """Overarching base class for any kind of component.

    Concrete components inherit from this protocol. ``ComponentMeta`` then
    turns them into slotted attrs classes, with a factory and a finalised
    custom id.
    """

    _parent: typing.ClassVar[typing.Optional[type[typing.Any]]] = None

    # The manager this component type is subscribed to, if any.
    manager: typing.ClassVar[typing.Optional[component_api.ComponentManager]] = None

    @classmethod
    def set_manager(  # noqa: D102
        cls, manager: typing.Optional[component_api.ComponentManager], /
    ) -> None:
        # <<docstring inherited from component_api.RichComponent>>
        if cls.manager is manager:
            return

        # Detach from the current manager first (non-recursively). The
        # attribute is Optional, so compare against None rather than relying
        # on the manager's truthiness.
        if cls.manager is not None:
            cls.manager.unsubscribe(cls, recursive=False)

        cls.manager = manager

    @classmethod
    def should_invoke_for(  # noqa: D102
        cls, interaction: disnake.Interaction, /
    ) -> bool:
        # <<Docstring inherited from component_api.RichComponent>>
        # ComponentMeta finalises `custom_id` into a CustomID at class creation.
        custom_id = typing.cast(custom_id_impl.CustomID, cls.custom_id)
        return custom_id.check_name(_extract_custom_id(interaction))

    async def dumps(self) -> str:  # noqa: D102
        # <<Docstring inherited from component_api.RichComponent>>
        factory = type(self).factory
        return await factory.dumps(self)

    async def as_ui_component(self) -> disnake.ui.WrappedComponent:  # noqa: D102
        # <<Docstring inherited from component_api.RichComponent>>
        # Protocol stub; concrete component types provide the implementation.
        ...