# Source code for debusine.tasks.inputs.inputs

# Copyright © The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.
"""Input fields for tasks."""

import abc
import enum
from collections.abc import Callable, Generator
from collections.abc import Collection as ABCCollection
from operator import attrgetter
from types import GenericAlias
from typing import (
    Any,
    ClassVar,
    NotRequired,
    Self,
    TypedDict,
    Unpack,
    cast,
    override,
)

from debusine.artifacts.models import ArtifactCategory
from debusine.client.models import RelationType
from debusine.tasks.executors import ExecutorImageCategory
from debusine.tasks.inputs import serializers
from debusine.tasks.inputs.resolver import FieldResolver
from debusine.tasks.models import (
    ArtifactInfo,
    BaseDynamicTaskData,
    BaseTaskInputArtifact,
    ExtraExternalRepository,
    ExtraRepository,
    InputArtifactMultiple,
    InputCollectionSingle,
    LookupMultiple,
    LookupSingle,
)
from debusine.utils import extract_generic_type_arguments


class Stage(enum.IntEnum):
    """Ordered enumeration of the stages of preparing a task to schedule."""

    #: Inputs that must be resolved before configuration lookup
    PENDING = 1
    #: Inputs that must be resolved before the task can run
    RUNNING = 2


class TaskFieldContainer:
    """Base class for objects that declare task input fields."""

    #: Inputs defined in this task class
    inputs: dict[str, "TaskInput[Any]"] = {}

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Merge input field definitions from multiple parents."""
        super().__init_subclass__()
        # Walk the MRO from most generic to most specific, so that
        # definitions in more derived classes take precedence
        merged: dict[str, "TaskInput[Any]"] = {}
        for ancestor in reversed(cls.__mro__):
            if issubclass(ancestor, TaskFieldContainer):
                merged.update(ancestor.inputs)
        merged.update(cls.inputs)
        cls.inputs = merged

    def __init__(self) -> None:
        """Initialize a TaskFieldContainer."""
        super().__init__()
        #: If set, when fields of this container are resolved, fields of the
        #: child container are also resolved, using the resolver instantiated
        #: for the parent container
        self.child_task_field_container: TaskFieldContainer | None = None

    def iter_all_fields(
        self, stage: Stage | None = None
    ) -> Generator[tuple["TaskFieldContainer", "TaskInput[Any]"]]:
        """Enumerate the fields to resolve, and their container object."""
        effective_stage = Stage.RUNNING if stage is None else stage

        yield from (
            (self, inp)
            for inp in self.inputs.values()
            if inp.stage <= effective_stage
        )
        child = self.child_task_field_container
        if child is not None:
            yield from child.iter_all_fields(effective_stage)

    def field_is_resolved(self, name: str) -> bool:
        """Return True if the field has been resolved."""
        # Resolved values live in the instance dict, shadowing the descriptor
        return name in self.__dict__

    def save_field_values(self, dynamic_data: BaseDynamicTaskData) -> None:
        """
        Save field values in a BaseDynamicTaskData.

        Fields not defined in this container that are already saved are
        preserved: this allows multiple field containers to serialize to the
        same BaseDynamicTaskData, as long as their names do not conflict.
        """
        fields = dynamic_data.task_input_fields
        for container, inp in self.iter_all_fields():
            current_value = getattr(container, inp.name)
            fields[inp.name] = inp.serialize_value(current_value)

        # We need to perform an assignment to task_input_fields and not just
        # modify it in place, otherwise the pydantic model will consider the
        # field unset and not serialize it
        dynamic_data.task_input_fields = fields

    def restore_field_values(self, dynamic_data: BaseDynamicTaskData) -> None:
        """
        Restore field values serialized in a BaseDynamicTaskData.

        Stored fields with names not defined in this container are ignored:
        this allows multiple field containers to serialize to the same
        BaseDynamicTaskData, as long as their names do not conflict.
        """
        stored = dynamic_data.task_input_fields
        for container, inp in self.iter_all_fields():
            if inp.name not in stored:
                continue
            setattr(
                container, inp.name, inp.deserialize_value(stored[inp.name])
            )


class TaskInputKwargs(TypedDict):
    """Argument typing for TaskInput constructor."""

    #: stage of task preparation when the field is needed
    stage: NotRequired[Stage]
    #: label to use instead of the autogenerated one
    verbose_name: NotRequired[str]
    #: description of the field
    help_text: NotRequired[str]
    #: set to True to hide the field from the UI
    hidden: NotRequired[bool]
    #: show this widget as an internal field in the UI
    internal: NotRequired[bool]


class TaskInput[ValueType](abc.ABC):
    """Representation for one or more input artifacts for a task."""

    name: str
    stage: Stage
    serializer: ClassVar[serializers.TaskInputSerializer[ValueType]]

    def get_resolver_method(
        self, resolver: FieldResolver
    ) -> Callable[[Self], ValueType]:
        """Override this to forward to a FieldResolver method."""
        raise NotImplementedError(
            "This input field does not forward resolution to the FieldResolver."
        )

    def resolve(self, resolver: FieldResolver) -> ValueType:
        """Resolve the input from task data."""
        method = self.get_resolver_method(resolver)
        return method(self)

    def get_artifact_ids(self, value: ValueType) -> list[int]:  # noqa: ARG002, U100
        """Return artifact IDs."""
        return []

    def __init__(
        self,
        stage: Stage = Stage.RUNNING,
        verbose_name: str | None = None,
        help_text: str | None = None,
        hidden: bool = False,
        internal: bool = False,
    ) -> None:
        """
        Declare an input field.

        :param stage: stage of task preparation when the field is needed
        :param verbose_name: label to use instead of the autogenerated one
        :param help_text: description of the field
        :param hidden: set to True to hide the field from the UI
        :param internal: show this widget as an internal field in the UI
        """
        self.stage = stage
        self.verbose_name = verbose_name
        self.help_text = help_text
        self.hidden = hidden
        self.internal = internal

    def __set_name__(self, owner: type[TaskFieldContainer], name: str) -> None:
        """Set the field name at class construction."""
        self.name = name
        # Avoid sharing the inputs definition across all subclasses of
        # BaseExternalTask
        if "inputs" not in owner.__dict__:
            owner.inputs = owner.inputs.copy()
        owner.inputs[name] = self

    def __get__(
        self,
        obj: TaskFieldContainer,
        _: type[TaskFieldContainer] | None = None,
    ) -> ValueType:
        """Look up this input in the task object."""
        from debusine.tasks import TaskConfigError

        # If __get__ is called, it means that self.name is not in obj.__dict__,
        # which only happens if the field value has not yet been resolved
        raise TaskConfigError(f"Cannot access unresolved input {self.name!r}")

    def get_value(self, field_container: TaskFieldContainer) -> ValueType:
        """Return the current value of this field in the container."""
        return cast(ValueType, getattr(field_container, self.name))

    def serialize_value(self, value: ValueType) -> Any:
        """Serialize a value for storing in dynamic task data."""
        return self.serializer.serialize(value)

    def deserialize_value(self, value: Any) -> ValueType:
        """Serialize a value for storing in dynamic task data."""
        return self.serializer.deserialize(value)


class DataFieldInputKwargs(TaskInputKwargs):
    """Argument typing for DataFieldInput constructor."""

    #: dot-separated list of lookup names in task data; defaults to the
    #: input member name
    field: NotRequired[str]


class DataFieldInput[LookupType, ValueType](TaskInput[ValueType]):
    """Input field that resolves values based on a task data field."""

    field: str
    # Actual lookup type, filled by __init_subclass__
    lookup_type: type[LookupType]

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Fill lookup_type from typing information."""
        super().__init_subclass__(**kwargs)
        [cls.lookup_type, _] = extract_generic_type_arguments(
            cls, DataFieldInput
        )
        # Given that we use lookup_type for isinstance checks, if it is a
        # generic alias we can strip the parameterization and only keep the
        # original type. This allows to type LookupType as ABCCollection[foo]
        # and have isinstance checks against ABCCollection
        if isinstance(cls.lookup_type, GenericAlias):
            cls.lookup_type = cls.lookup_type.__origin__

    def __init__(
        self, field: str | None = None, **kwargs: Unpack[TaskInputKwargs]
    ) -> None:
        """
        Declare an input defined by a task data field.

        :param field: dot-separated list of lookup names in task data. Defaults
          to the input member name
        """
        super().__init__(**kwargs)
        if field is not None:
            self._set_field_name(field)

    def __set_name__(self, owner: type[TaskFieldContainer], name: str) -> None:
        """Set the field name at class construction."""
        super().__set_name__(owner, name)
        if not hasattr(self, "field"):
            self._set_field_name(name)

    def _set_field_name(self, name: str) -> None:
        """Set the attribute name used to look up the value in task input."""
        self.field = name
        self.getter = attrgetter(name)


class DataListFieldInput[LookupType, ValueType](
    DataFieldInput[ABCCollection[LookupType], ValueType]
):
    """DataFieldInput that resolves to a list of type checked elements."""

    # Type of each list item, filled by __init_subclass__
    item_type: type[LookupType]

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Fill lookup_type from typing information."""
        super().__init_subclass__(**kwargs)
        [cls.item_type, _] = extract_generic_type_arguments(
            cls, DataListFieldInput
        )


class BaseArtifactInputKwargs(DataFieldInputKwargs):
    """Argument typing for BaseArtifactInput constructor."""

    #: list of acceptable artifact categories
    categories: NotRequired[ABCCollection[ArtifactCategory]]


class BaseArtifactInput[LookupType, ValueType](
    DataFieldInput[LookupType, ValueType]
):
    """Look up artifacts enforcing their types."""

    categories: tuple[ArtifactCategory, ...] | None

    def __init__(
        self,
        *,
        categories: ABCCollection[ArtifactCategory] | None = None,
        **kwargs: Unpack[DataFieldInputKwargs],
    ) -> None:
        """
        Look up artifacts enforcing their type.

        :param field: dot-separated list of lookup names in task data
        :param categories: list of acceptable artifact categories
        """
        super().__init__(**kwargs)
        self.categories = tuple(categories) if categories is not None else None

    def check_category(
        self, info: BaseTaskInputArtifact, field_name: str | None = None
    ) -> None:
        """Check the category of an ArtifactInfo."""
        # Prevent import loop
        from debusine.tasks import ensure_artifact_categories

        if self.categories is None:
            return

        # TODO: this should not happen, as we are checking the results of
        # database lookups which always have a category. However, info.category
        # can be None for compatibility during a transition time of merging
        # with InputArtifactMultiple, which needs to be able to deserialize
        # from only a list of artifact IDs
        assert info.category is not None

        ensure_artifact_categories(
            configuration_key=field_name or self.field,
            category=info.category,
            expected=self.categories,
        )


# Combines category enforcement from BaseArtifactInput with the per-item
# typing information filled in by DataListFieldInput.__init_subclass__
class ArtifactListInput[LookupType, ValueType](
    BaseArtifactInput[ABCCollection[LookupType], ValueType],
    DataListFieldInput[LookupType, ValueType],
):
    """Input taking a list of artifact lookups, enforcing their types."""


class SingleInput(BaseArtifactInput[LookupSingle, ArtifactInfo]):
    """One single input artifact specified in a task data field."""

    serializer = serializers.PydanticSerializer(ArtifactInfo)

    @override
    def resolve(self, resolver: FieldResolver) -> ArtifactInfo:
        """Resolve the lookup, checking the artifact category."""
        lookup = resolver.get_field(self)
        # The field is mandatory, so a lookup is always present
        assert lookup is not None
        artifact = resolver.lookup_single_artifact(lookup, label=self.field)
        self.check_category(artifact)
        return artifact

    @override
    def get_artifact_ids(self, value: ArtifactInfo) -> list[int]:
        """Return the ID of the resolved artifact."""
        return [value.artifact_id]
class OptionalSingleInput(BaseArtifactInput[LookupSingle, ArtifactInfo | None]):
    """One optional single input artifact specified in a task data field."""

    serializer = serializers.OptionalPydanticSerializer(ArtifactInfo)

    @override
    def resolve(self, resolver: FieldResolver) -> ArtifactInfo | None:
        """Resolve the lookup, or return None if the field is unset."""
        lookup = resolver.get_field(self)
        if lookup is None:
            return None
        artifact = resolver.lookup_single_artifact(lookup, label=self.field)
        self.check_category(artifact)
        return artifact

    @override
    def get_artifact_ids(self, value: ArtifactInfo | None) -> list[int]:
        """Return the ID of the resolved artifact, if any."""
        return [] if value is None else [value.artifact_id]
class SingleInputList(ArtifactListInput[LookupSingle, list[ArtifactInfo]]):
    """A list of single input artifact lookups in a task data field."""

    serializer = serializers.PydanticListSerializer(ArtifactInfo)

    @override
    def resolve(self, resolver: FieldResolver) -> list[ArtifactInfo]:
        """Resolve each lookup in the list, checking artifact categories."""
        resolved: list[ArtifactInfo] = []
        for lookup in resolver.get_field_collection(self):
            assert isinstance(lookup, LookupSingle)
            artifact = resolver.lookup_single_artifact(
                lookup, label=self.field
            )
            self.check_category(artifact)
            resolved.append(artifact)
        return resolved

    @override
    def get_artifact_ids(self, value: list[ArtifactInfo]) -> list[int]:
        """Return the IDs of all resolved artifacts."""
        return [info.artifact_id for info in value]
class MultiInput(BaseArtifactInput[LookupMultiple, InputArtifactMultiple]):
    """A list of multiple artifacts lookups specified in a task data field."""

    serializer = serializers.PydanticSerializer(InputArtifactMultiple)

    @override
    def resolve(self, resolver: FieldResolver) -> InputArtifactMultiple:
        """Resolve the multiple lookup, checking each artifact's category."""
        lookup = resolver.get_field(self)
        # The field is mandatory, so a lookup is always present
        assert lookup is not None
        result = resolver.lookup_multiple_artifacts(lookup, label=self.field)
        if self.categories is not None:
            for index, artifact in enumerate(result.artifacts):
                self.check_category(artifact, f"{self.field}[{index}]")
        return result

    @override
    def get_artifact_ids(self, value: InputArtifactMultiple) -> list[int]:
        """Return the IDs of resolved artifacts that have one."""
        return [
            artifact.artifact_id
            for artifact in value.artifacts
            if artifact.artifact_id is not None
        ]
class MultiInputList(
    ArtifactListInput[LookupMultiple, list[InputArtifactMultiple]]
):
    """A multiple artifacts lookup specified in a task data field."""

    serializer = serializers.MultiInputListSerializer()

    @override
    def resolve(self, resolver: FieldResolver) -> list[InputArtifactMultiple]:
        """Resolve each multiple lookup in the list, checking categories."""
        resolved: list[InputArtifactMultiple] = []
        for outer, lookup in enumerate(resolver.get_field_collection(self)):
            result = resolver.lookup_multiple_artifacts(
                lookup, label=self.field
            )
            if self.categories is not None:
                for inner, artifact in enumerate(result.artifacts):
                    self.check_category(
                        artifact, f"{self.field}[{outer}][{inner}]"
                    )
            resolved.append(result)
        return resolved

    @override
    def get_artifact_ids(
        self, value: list[InputArtifactMultiple]
    ) -> list[int]:
        """Return the IDs of all resolved artifacts that have one."""
        return [
            artifact.artifact_id
            for multiple in value
            for artifact in multiple.artifacts
            if artifact.artifact_id is not None
        ]
class UploadArtifactsInput(
    BaseArtifactInput[LookupMultiple, InputArtifactMultiple]
):
    """A multiple artifact lookup finding all related binaries in an upload."""

    serializer = serializers.PydanticSerializer(InputArtifactMultiple)

    @override
    def __init__(self, **kwargs: Unpack[BaseArtifactInputKwargs]) -> None:
        """Accept ArtifactCategory.UPLOAD on top of requested categories."""
        extra_categories = kwargs.get("categories") or ()
        kwargs["categories"] = [ArtifactCategory.UPLOAD, *extra_categories]
        super().__init__(**kwargs)

    @override
    def resolve(self, resolver: FieldResolver) -> InputArtifactMultiple:
        """Resolve the lookup, expanding uploads into related binaries."""
        lookup = resolver.get_field(self)
        # The field is mandatory, so a lookup is always present
        assert lookup is not None

        # Incrementally build the result set
        artifacts: list[BaseTaskInputArtifact] = []
        # Uploads found in the initial lookup, to be expanded later
        upload_ids: list[int] = []

        initial = resolver.lookup_multiple_artifacts(
            lookup, label=self.field
        ).artifacts
        for index, artifact in enumerate(initial):
            if artifact.category == ArtifactCategory.UPLOAD:
                assert artifact.artifact_id is not None
                upload_ids.append(artifact.artifact_id)
            else:
                self.check_category(artifact, f"{self.field}[{index}]")
                artifacts.append(artifact)

        # Expand with binaries related to uploads
        if upload_ids:
            related = resolver.find_related_artifacts(
                upload_ids,
                target_category=ArtifactCategory.BINARY_PACKAGE,
                relation_type=RelationType.EXTENDS,
                label=self.field,
            )
            artifacts.extend(related.artifacts)

        return InputArtifactMultiple(
            lookup=lookup, label=self.field, artifacts=artifacts
        )
class EnvironmentInputKwargs(BaseArtifactInputKwargs):
    """Argument typing for EnvironmentInput constructor."""

    #: try to use an environment with this image category
    image_category: NotRequired[ExecutorImageCategory | None]
    #: if True, try to use an environment matching `self.backend`
    set_backend: NotRequired[bool]
    #: if True, try a variant-specific environment first, falling back to
    #: one without a variant
    try_variant: NotRequired[bool]
class EnvironmentInput(BaseArtifactInput[LookupSingle, ArtifactInfo]):
    """One environment artifact specified in a task data field."""

    serializer = serializers.PydanticSerializer(ArtifactInfo)

    def __init__(
        self,
        *,
        image_category: ExecutorImageCategory | None = None,
        set_backend: bool = True,
        try_variant: bool = True,
        **kwargs: Unpack[BaseArtifactInputKwargs],
    ) -> None:
        """
        Look up artifacts enforcing their type.

        :param field: dot-separated list of lookup names in task data
        :param categories: list of acceptable artifact categories
        :param image_category: try to use an environment with this image
          category; defaults to the image category needed by the executor
          for `self.backend`
        :param set_backend: if True (default), try to use an environment
          matching `self.backend`
        :param try_variant: if True (default), try to use an environment
          whose variant is `self.name`, but fall back to looking up an
          environment without a variant if the first lookup fails
        """
        super().__init__(**kwargs)
        self.try_variant = try_variant
        self.set_backend = set_backend
        self.image_category = image_category

    @override
    def get_resolver_method(
        self, resolver: FieldResolver
    ) -> Callable[[Self], ArtifactInfo]:
        """Delegate resolution to the FieldResolver."""
        return resolver.resolve_environmentinput

    @override
    def get_artifact_ids(self, value: ArtifactInfo) -> list[int]:
        """Return the ID of the resolved environment artifact."""
        return [value.artifact_id]
class OptionalEnvironmentInput(
    BaseArtifactInput[LookupSingle, ArtifactInfo | None]
):
    """One optional environment artifact specified in a task data field."""

    serializer = serializers.OptionalPydanticSerializer(ArtifactInfo)

    @override
    def get_resolver_method(
        self, resolver: FieldResolver
    ) -> Callable[[Self], ArtifactInfo | None]:
        """Delegate resolution to the FieldResolver."""
        return resolver.resolve_optionalenvironmentinput
class SuiteArchiveInput(
    DataFieldInput[
        LookupSingle,
        tuple[InputCollectionSingle, InputCollectionSingle | None]
    ]
):
    """Look up a suite and an archive based on a task data field."""

    serializer = serializers.SuiteArchiveInputSerializer()

    @override
    def get_resolver_method(
        self, resolver: FieldResolver
    ) -> Callable[
        [Self], tuple[InputCollectionSingle, InputCollectionSingle | None]
    ]:
        """Delegate resolution to the FieldResolver."""
        return resolver.resolve_suitearchiveinput
class ExtraRepositoriesInput(
    DataFieldInput[list[ExtraRepository], list[ExtraExternalRepository] | None]
):
    """Look up extra repositories."""

    serializer = serializers.ExtraRepositoriesSerializer()

    @override
    def get_resolver_method(
        self, resolver: FieldResolver
    ) -> Callable[[Self], list[ExtraExternalRepository] | None]:
        """Delegate resolution to the FieldResolver."""
        return resolver.resolve_extrarepositoriesinput
class DebusineFQDNInput(TaskInput[str]):
    """Resolve to the Debusine FQDN."""

    serializer = serializers.ScalarSerializer(str)

    @override
    def resolve(self, resolver: FieldResolver) -> str:
        """Return the value of the DEBUSINE_FQDN server setting."""
        return resolver.get_server_setting("DEBUSINE_FQDN")
class BuildArchitectureInput(DataFieldInput[str, str]):
    """
    Build architecture for the task.

    This currently looks up a value in the task data, and uses a fallback
    value if it is not set.
    """

    serializer = serializers.ScalarSerializer(str)

    @override
    def resolve(self, resolver: FieldResolver) -> str:
        """Return the architecture from task data, or the fallback."""
        arch = resolver.get_field(self)
        if arch:
            return arch
        return resolver.get_fallback_architecture()
class UnsignedMultiInput(
    BaseArtifactInput[LookupMultiple, InputArtifactMultiple]
):
    """A multiple artifact lookup finding all related artifacts."""

    serializer = serializers.PydanticSerializer(InputArtifactMultiple)

    @override
    def resolve(self, resolver: FieldResolver) -> InputArtifactMultiple:
        """
        Resolve signed artifacts to the unsigned artifacts they came from.

        Starting from the looked-up (signed) artifacts, walk RELATES_TO
        relations twice: first to the SIGNING_INPUT artifacts, then from
        those to the BINARY_PACKAGE artifacts. Both sets of related
        artifacts are returned, sorted by artifact ID.
        """
        lookup = resolver.get_field(self)
        assert lookup is not None
        result: list[BaseTaskInputArtifact] = []
        # IDs of the signed artifacts matched by the lookup
        signed_artifacts: set[int] = set()
        for signed in resolver.lookup_multiple_artifacts(
            lookup, label=self.field
        ).artifacts:
            assert signed.artifact_id
            signed_artifacts.add(signed.artifact_id)
        # Signing inputs related to the signed artifacts; collected both
        # into the result and as IDs for the next traversal step
        signing_inputs: set[int] = set()
        for signing_input in resolver.find_related_artifacts(
            signed_artifacts,
            target_category=ArtifactCategory.SIGNING_INPUT,
            relation_type=RelationType.RELATES_TO,
            label=self.field,
        ).artifacts:
            assert signing_input.artifact_id
            result.append(signing_input)
            signing_inputs.add(signing_input.artifact_id)
        # Binary packages related to the signing inputs
        for binary_package in resolver.find_related_artifacts(
            signing_inputs,
            target_category=ArtifactCategory.BINARY_PACKAGE,
            relation_type=RelationType.RELATES_TO,
            label=self.field,
        ).artifacts:
            result.append(binary_package)

        def artifact_id(result: BaseTaskInputArtifact) -> int:
            # Sort key; all collected artifacts are known to have an ID
            assert result.artifact_id is not None
            return result.artifact_id

        result.sort(key=artifact_id)
        return InputArtifactMultiple(
            lookup=lookup, label=self.field, artifacts=result
        )