# Source code for debusine.worker._worker

# Copyright © The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""
Worker client: connects to debusine server.

Overview
--------
* Registration (needed only once per worker): If the worker doesn't have a
  token: it will generate it and register with the server
  (HTTP POST to ``/api/1.0/worker/register``)
* The client will use this token to connect to the server (WebSocket to
  ``/api/ws/1.0/worker/connect``)

Flow
----

  #. The worker is executed and chooses ``~/.config/debusine/worker``
     (if it exists) or ``/etc/debusine/worker``. It reads the file
     ``config.ini`` from the directory and if it already exists the file
     ``token``.
  #. If there isn't a token the worker generates one (using
     :py:func:`secrets.token_hex`) and registers it to the Debusine server
     via HTTP POST to `/api/1.0/worker/register` sending the generated token
     and the worker's FQDN. The token is saved to the ``token`` file in the
     chosen config directory.
  #. The server will create a new Token and Worker in the DB via the models.
     They will not be used until they are manually validated.
  #. The client can then connect using WebSockets to
     ``/api/ws/1.0/worker/connect`` and wait for commands to execute.

Objects documentation
---------------------

"""

import abc
import asyncio
import functools
import logging
import secrets
import shutil
import signal
import socket
import time as tm
import types
from pathlib import Path
from textwrap import dedent
from threading import Lock
from typing import Any, Literal, NoReturn, Self, override

import aiohttp
import tenacity
from aiohttp.web_exceptions import HTTPCreated, HTTPNoContent, HTTPOk
from pydantic import ValidationError

from debusine.artifacts.models import WorkRequestResults
from debusine.client.debusine import Debusine
from debusine.client.models import WorkRequestResponse
from debusine.tasks import (
    analyze_external_worker_tasks,
    get_provided_external_worker_tags,
)
from debusine.tasks.executors import (
    analyze_worker_all_executors,
    executor_provided_tags,
)
from debusine.tasks.models import OutputData, OutputDataError, WorkerType
from debusine.utils import debusine_version
from debusine.worker.config import ConfigHandler
from debusine.worker.debusine_async_http_client import DebusineAsyncHttpClient
from debusine.worker.system_information import (
    system_metadata,
    system_provided_tags,
)
from debusine.worker.task_runner import TaskRunner, TaskRunnerError

logger = logging.getLogger("debusine.worker")


class Event:
    """Common base for every event handled by the worker's main loop."""

    def __init__(self, name: str) -> None:
        """Record the event name and a monotonic creation timestamp."""
        self.name = name
        # Monotonic nanosecond timestamp taken at creation; useful to
        # measure how long the event sat in the queue.
        self.perf_ns = tm.perf_counter_ns()

    def __str__(self) -> str:
        """Render the event as its bare name."""
        return self.name

class ShutdownEvent(Event, abc.ABC):
    """Abstract event asking the worker to shut down."""

    def __init__(self, message: str) -> None:
        """Store a human-readable reason for the shutdown."""
        super().__init__("shutdown")
        self.message = message

    def __str__(self) -> str:
        """Render as ``shutdown: <reason>``."""
        return self.name + ": " + self.message

    @abc.abstractmethod
    def get_exit_code(self) -> int:
        """Return the exit code to use."""

class ServerDisconnectShutdownEvent(ShutdownEvent):
    """Shutdown triggered by the server closing the notification stream."""

    def __init__(self) -> None:
        """Create the event with its fixed disconnect message."""
        super().__init__("server notification connection closed")

    @override
    def get_exit_code(self) -> int:
        """Exit status 1: the server connection was lost."""
        return 1

class ReportTaskResultFailedShutdownEvent(ShutdownEvent):
    """Shutdown triggered by a failure to report a task's result."""

    def __init__(self) -> None:
        """Create the event with its fixed failure message."""
        super().__init__("cannot send results to server")

    @override
    def get_exit_code(self) -> int:
        """Exit status 2: results could not be delivered."""
        return 2

class SignalReceivedShutdownEvent(ShutdownEvent):
    """Shutdown triggered by receipt of a termination signal."""

    def __init__(self, sig: signal.Signals) -> None:
        """Remember which signal arrived."""
        super().__init__(f"signal {sig} received")
        self.sig = sig

    @override
    def get_exit_code(self) -> int:
        """Follow the shell convention: 128 plus the signal number."""
        return self.sig + 128

class ServerEvent(Event):
    """An event generated from a message sent by the server."""

    def __init__(self, name: str) -> None:
        """Delegate to Event with the server-provided event name."""
        super().__init__(name)

    @classmethod
    def from_message(cls, payload: dict[str, Any]) -> "ServerEvent":
        """Instantiate a server event from a server message."""
        text = payload.get("text")
        if text is None:
            raise ValueError(f"{payload!r}: message payload without 'text'")
        # Map the message's "text" field to the concrete event class.
        factories: dict[str, type[ServerEvent]] = {
            "connected": ConnectedEvent,
            "request_dynamic_metadata": RequestDynamicMetadataEvent,
            "work_request_available": WorkRequestAvailableEvent,
        }
        factory = factories.get(text)
        if factory is None:
            raise NotImplementedError(f"{payload!r}: unsupported server event")
        return factory()


class ConnectedEvent(ServerEvent):
    """Emitted once the connection to the server is established."""

    def __init__(self) -> None:
        """Create the event with its fixed name."""
        super().__init__("connected")

class RequestDynamicMetadataEvent(ServerEvent):
    """Emitted when the server asks for the worker's dynamic metadata."""

    def __init__(self) -> None:
        """Create the event with its fixed name."""
        super().__init__("request_dynamic_metadata")

class WorkRequestAvailableEvent(ServerEvent):
    """Emitted when the server announces a work request to fetch."""

    def __init__(self) -> None:
        """Create the event with its fixed name."""
        super().__init__("work_request_available")

class QuitMainLoopEvent(Event):
    """
    Event that makes the main loop return.

    Primarily a test hook: it lets tests leave the loop once the
    pending messages have been handled.
    """

    def __init__(self) -> None:
        """Create the event with its fixed name."""
        super().__init__("quit_main_loop")

class Worker:
    """Worker class: waits for commands from the debusine server."""

    # Log level used when neither the constructor argument nor the
    # configuration file specifies one.
    DEFAULT_LOG_LEVEL = logging.INFO

    # Directory holding sysadmin-provided hooks executed after every
    # work request.
    AFTER_WORK_REQUEST_HOOK_DIR = Path(
        "/etc/debusine/worker/hooks/after-work-request.d"
    )
    def __init__(
        self,
        *,
        log_level: str | None = None,
        worker_type: Literal[
            WorkerType.EXTERNAL, WorkerType.SIGNING
        ] = WorkerType.EXTERNAL,
        config: ConfigHandler | None = None,
        working_directory: str | None = None,
    ) -> None:
        """
        Initialize Worker.

        :param log_level: minimum level of the logs being saved. If None
            uses settings from config.ini or DEFAULT_LOG_LEVEL.
        :param worker_type: kind of worker to run; signing workers use a
            different config directory and require HTTPS.
        :param config: ConfigHandler to use (or creates a default one)
        :param working_directory: directory for task execution; when
            None, the directory from the configuration is used.
        """
        self._original_log_level = log_level
        self._worker_type = worker_type
        if config is None:
            if worker_type == WorkerType.SIGNING:
                # Signing workers read a dedicated configuration and must
                # talk to the server over HTTPS.
                self._config = ConfigHandler(
                    directories=[
                        str(Path.home() / ".config/debusine/signing"),
                        "/etc/debusine/signing",
                    ],
                    require_https=True,
                )
            else:
                self._config = ConfigHandler()
        else:
            self._config = config

        # Logging is configured before validation so that validation
        # errors are reported at the configured level.
        self._setup_logging()
        self._setup_workdir(working_directory)
        self._config.validate_config_or_fail()

        self._aiohttp_client_session: aiohttp.ClientSession | None = None
        self._async_http_client = DebusineAsyncHttpClient(self._config)

        # Main event queue
        self.event_queue: asyncio.Queue[Event] = asyncio.Queue()

        # Used by ThreadPoolExecutor to submit tasks to the main event loop
        self._main_event_loop: asyncio.AbstractEventLoop | None = None

        # Lock to protect from concurrent work request execution
        self._task_lock: Lock = Lock()
    async def __aenter__(self) -> Self:
        """Async context manager support."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        """Release worker resources."""
        await self._async_http_client.close()

    def _setup_workdir(self, working_directory: str | None) -> None:
        """
        Choose and create the worker's working directory.

        :param working_directory: explicit directory to use; when None,
            fall back to the directory from the configuration.
        """
        if working_directory:
            self._workdir = Path(working_directory)
        else:
            self._workdir = self._config.working_directory

        # Create the working directory if needed
        try:
            self._workdir.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            # Without a usable working directory the worker cannot run
            # any task: abort immediately via _fail().
            self._fail(
                f"could not create working directory '{self._workdir}': {e}"
            )
[docs] def clean_workdir(self) -> None: """Clean the working directory from interrupted work requests.""" for entry in self._workdir.iterdir(): if entry.is_dir() and entry.name.startswith("WR-"): try: shutil.rmtree(entry) except OSError as e: logger.warning("Failed to remove '%s': %s", entry, str(e))
    def _setup_logging(self) -> None:
        """
        Set logging configuration.

        Use the parameters passed to Worker.__init__() and self._config
        configuration.
        """
        # Precedence: explicit constructor argument, then config.ini,
        # then the class default.
        log_level = (
            self._original_log_level
            or self._config.log_level
            or logging.getLevelName(self.DEFAULT_LOG_LEVEL)
        )
        logging.basicConfig(
            format='%(message)s',
            level=log_level,
            force=True,
        )

    def _create_task_signal_int_term_handler(
        self, signum: signal.Signals
    ) -> asyncio.Task[None]:
        """
        Schedule the async signal handler as an asyncio task.

        Wrapper needed because the actual handler is a coroutine, while
        the event loop signal handler must be a plain callable.
        """
        return asyncio.create_task(self._signal_int_term_handler(signum))
    async def main(self) -> None:
        """Run the worker."""
        # Remove WR-* leftovers from a previously interrupted run.
        self.clean_workdir()

        logger.info(
            "Debusine worker is starting. Version: %s", debusine_version()
        )

        # SIGINT/SIGTERM enqueue a SignalReceivedShutdownEvent so the
        # main loop can shut down cleanly.
        for signum in [signal.SIGINT, signal.SIGTERM]:
            asyncio.get_running_loop().add_signal_handler(
                signum,
                functools.partial(
                    self._create_task_signal_int_term_handler, signum
                ),
            )

        # Register with the server if we do not have a token yet.
        await self.require_token()

        if event := await self.process_notifications():
            # Propagate the shutdown-event-specific exit code.
            raise SystemExit(event.get_exit_code())
        else:
            # process_notifications() returned without a ShutdownEvent.
            logger.warning("Unknown shutdown cause")
            raise SystemExit(4)
    async def process_notifications(self) -> ShutdownEvent | None:
        """
        Process server notifications.

        :return: the ShutdownEvent that stopped the loop, or None when
            the loop was left via QuitMainLoopEvent.
        """

        class Shutdown(Exception):
            """Shutdown has been requested."""

        shutdown_event: ShutdownEvent | None = None
        try:
            async with asyncio.TaskGroup() as tg:
                tg.create_task(self.handle_server_notifications())
                while True:
                    event = await self.event_queue.get()
                    logger.debug("Worker event: %s", event)
                    match event:
                        case ShutdownEvent():
                            # Raise an exception instead of breaking out of the
                            # while, so that tasks in tg are cancelled
                            shutdown_event = event
                            raise Shutdown(str(event))
                        case ConnectedEvent():
                            logger.debug("Received the 'connected' message")
                        case RequestDynamicMetadataEvent():
                            tg.create_task(self._send_dynamic_metadata())
                        case WorkRequestAvailableEvent():
                            tg.create_task(
                                self._request_and_execute_work_request()
                            )
                        case QuitMainLoopEvent():
                            break
                        case _:
                            logger.warning(
                                "Ignoring unsupported event: %s", event
                            )
        except* Shutdown as exc:
            # TaskGroup wraps the Shutdown in an ExceptionGroup.
            logger.info("Shutdown requested: %s", exc.exceptions[0])
        return shutdown_event
[docs] async def require_token(self) -> None: """Handle registration if needed.""" if not self._config.token: result = await self._register() if result is False: self._fail('Exiting...')
    async def _signal_int_term_handler(self, signum: signal.Signals) -> None:
        """Log the signal and enqueue a shutdown event for the main loop."""
        self.log_forced_exit(signum)
        await self.event_queue.put(SignalReceivedShutdownEvent(signum))

    @staticmethod
    def _log_server_error(status_code: int, body: str) -> None:
        """Log a registration failure reported by the server."""
        logger.error(
            'Could not register. Server HTTP status code: %s Body: %s\n',
            status_code,
            body,
        )

    async def _register(self) -> bool:
        """
        Create a token, registers it to debusine, saves it locally.

        The worker will not receive any tasks until the debusine admin
        has approved the token.

        :return: True if the token was registered and saved; False on
            any error (already logged).
        """
        token = secrets.token_hex(32)
        register_path = '/1.0/worker/register/'
        data = {
            "token": token,
            "fqdn": socket.getfqdn(),
            "worker_type": self._worker_type,
        }
        try:
            response = await self._async_http_client.post(
                register_path, json=data, token=self._config.activation_token
            )
        except aiohttp.client_exceptions.ClientResponseError as err:
            self._log_server_error(err.status, 'Not available')
            return False
        except aiohttp.client_exceptions.ClientConnectorError as err:
            logger.error(  # noqa: G200
                'Could not register. Server unreachable: %s', err
            )
            return False

        status, body = (
            response.status,
            (await response.content.read()).decode('utf-8'),
        )
        if status == HTTPCreated.status_code:
            # Persist the token only after the server accepted it.
            self._config.write_token(token)
            return True
        else:
            self._log_server_error(status, body)
            return False

    @staticmethod
    async def _asyncio_sleep(delay: float) -> None:
        """Sleep asynchronously (mocked in tests)."""
        await asyncio.sleep(delay)
    async def handle_server_notifications(self) -> None:
        """Read incoming server notifications and generate worker events."""
        client = Debusine(
            self._config.api_url,
            self._config.token,
            timeout=self._config.http_timeout,
            logger=logger,
            scope=None,
        )
        async with client.server_notifications(
            endpoint="1.0/worker/connect/"
        ) as server_notifications:
            async for payload in server_notifications.messages():
                try:
                    event = ServerEvent.from_message(payload)
                except Exception as exc:
                    # Malformed messages are logged and skipped; they
                    # must not take the notification stream down.
                    logger.warning(
                        "Ignoring unsupported or malformed server message: %s",
                        exc,
                    )
                    continue
                await self.event_queue.put(event)
        # The stream ended: tell the main loop to shut down.
        await self.event_queue.put(ServerDisconnectShutdownEvent())
    @staticmethod
    def _fail(message: str) -> NoReturn:
        """Log a fatal error and exit the process with status 3."""
        logger.fatal(message)
        raise SystemExit(3)

    async def _request_and_execute_work_request(self) -> None:
        """Fetch the next work request from the server and execute it."""
        # Grab the task lock to make sure that the previous work request
        # completed fully. We can be notified before the end because the
        # server notifies via websocket while we are still wrapping up
        # the task execution.
        logger.info("Requesting a work request to execute")
        retries = 0
        while not self._task_lock.acquire(blocking=False):
            await self._asyncio_sleep(1)
            retries += 1
            if retries >= 60:
                # Give up after ~60 seconds: the previous task is stuck.
                logger.error(
                    "Worker is busy and can't execute a new work request"
                )
                return
        logger.debug(
            "Waited %s seconds for the end of the former task", retries
        )
        work_request = await self._request_work_request()
        if work_request:
            logger.info(
                "Executing work request %d (%s)",
                work_request.id,
                work_request.task_name,
            )
            async with TaskRunner(
                self._config, self._worker_type, work_request, self._workdir
            ) as runner:
                try:
                    await runner.setup_task()
                    result = await runner.execute()
                except TaskRunnerError as exc:
                    logger.error(exc.message, *exc.args)
                    await self.report_task_result(
                        runner,
                        WorkRequestResults.ERROR,
                        OutputData(
                            errors=[
                                OutputDataError(
                                    message=exc.message % exc.args,
                                    code=exc.code,
                                )
                            ]
                        ),
                    )
                else:
                    # Send the task results
                    output_data = OutputData(
                        runtime_statistics=runner.statistics
                    )
                    await self.report_task_result(runner, result, output_data)
            logger.debug("Finished executing work request %d", work_request.id)
        else:
            logger.info('No work request available')
            # report_task_result() was never called on this path, so
            # release the lock here (on the executed path _reset_state()
            # releases it).
            self._task_lock.release()
    async def report_task_result(
        self,
        task_runner: TaskRunner,
        result: WorkRequestResults,
        output_data: OutputData,
    ) -> None:
        """
        Report the work request's result to the server.

        Always runs the after-work-request hooks and resets worker
        state, even if the task was aborted or reporting failed.
        """
        try:
            if task_runner.task and task_runner.task.aborted:
                logger.info("Task: %s has been aborted", task_runner.task.name)
                # No need to notify debusine-server
                return
            try:
                # work_request_completed_update is a blocking client
                # call: run it in a thread to keep the loop responsive.
                await asyncio.to_thread(
                    task_runner.client.work_request_completed_update,
                    task_runner.work_request.id,
                    result,
                    output_data,
                )
                # This is asynchronous on the server side: the server may
                # not actually have marked the work request as completed
                # yet. However, this doesn't matter much, as the worker
                # won't do anything other than cleanup until the server
                # assigns it something else to do.
                logger.info(
                    'Task: %i completed (%s)',
                    task_runner.work_request.id,
                    result,
                )
            except Exception:
                # Log this, but leave the work request running. The server
                # will retry it when this worker next manages to connect and
                # request a new work request to run.
                logger.exception(
                    "Cannot reach server to report work request completed. "
                    "Exiting."
                )
                # Shutdown so that the worker will try to reconnect.
                await self.event_queue.put(
                    ReportTaskResultFailedShutdownEvent()
                )
        finally:
            await self._run_after_work_request_hook()
            self._reset_state()
    async def _request_work_request(self) -> WorkRequestResponse | None:
        """Request a work request and returns it."""

        @tenacity.retry(
            sleep=self._asyncio_sleep,
            wait=tenacity.wait_random_exponential(
                multiplier=5, min=5, max=60 * 5
            ),
            retry=tenacity.retry_if_exception_type(
                aiohttp.client_exceptions.ClientError
            )
            | tenacity.retry_if_exception_type(ConnectionError),
        )
        async def inner() -> WorkRequestResponse | None:
            # Retried with random exponential backoff (up to 5 minutes)
            # on network-level errors.
            work_request = await self._async_http_client.get(
                '/1.0/work-request/get-next-for-worker/'
            )
            if work_request.status == HTTPOk.status_code:
                try:
                    work_request_obj = WorkRequestResponse.model_validate_json(
                        await work_request.text()
                    )
                except ValidationError as exc:
                    logger.warning(  # noqa: G200
                        'Invalid WorkRequest received from'
                        ' /get-next-for-worker/: %s',
                        exc,
                    )
                    return None
                return work_request_obj
            else:
                work_request.raise_for_status()
            return None

        return await inner()

    def _get_provided_tags(self) -> set[str]:
        """Return a set with tags provided by this worker."""
        tags: set[str] = set()
        tags |= system_provided_tags(self._worker_type)
        tags |= executor_provided_tags()
        tags |= get_provided_external_worker_tags()
        return tags

    def _get_worker_metadata(self) -> dict[str, Any]:
        """Compute the worker metadata."""
        return {
            **system_metadata(self._worker_type),
            **analyze_worker_all_executors(),
            **analyze_external_worker_tasks(),
            "provided_tags": sorted(self._get_provided_tags()),
        }

    async def _send_dynamic_metadata(self) -> None:
        """PUT the worker's dynamic metadata to the server."""
        dynamic_metadata_path = '/1.0/worker/dynamic-metadata/'
        metadata = self._get_worker_metadata()
        logger.info('Send dynamic metadata')
        logger.debug('Dynamic metadata: %s', metadata)
        response = await self._async_http_client.put(
            dynamic_metadata_path, json=metadata
        )
        if response.status != HTTPNoContent.status_code:
            logger.error(
                'Failed to send dynamic_metadata: (response: %d): %s',
                response.status,
                await response.text(),
            )

    def _reset_state(self) -> None:
        """Reset worker to a clean state, ready for the next work request."""
        if self._task_lock.locked():
            # Release the task lock to allow the worker to process the next
            # work request
            self._task_lock.release()

    async def _run_after_work_request_hook(self) -> None:
        """Run any sysadmin configured post-work-request hooks."""
        hook_dir = self.AFTER_WORK_REQUEST_HOOK_DIR
        if not hook_dir.exists() or not any(hook_dir.iterdir()):
            logger.debug("No hooks in %r, skipping execution", str(hook_dir))
            return
        # run-parts executes every hook script in the directory.
        p = await asyncio.create_subprocess_exec(
            "run-parts",
            str(hook_dir),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await p.communicate()
        if p.returncode != 0:
            logger.warning(
                dedent(
                    """\
                    Hooks failed to execute in %r .
                    stdout: %s
                    stderr: %s
                    """
                ).rstrip("\n"),
                str(hook_dir),
                stdout.decode("utf-8", errors="replace"),
                stderr.decode("utf-8", errors="replace"),
            )
[docs] @staticmethod def log_forced_exit(signum: signal.Signals) -> None: """Log a forced exit.""" logger.info('Terminated with signal %s', signum.name)