# Copyright © The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.
"""Data models for pools of workers."""
import logging
from functools import cached_property
from typing import TYPE_CHECKING
import pydantic
from django.core.exceptions import ValidationError
from django.db import models, transaction
from django.db.models import JSONField, QuerySet
from django.urls import reverse
from django.utils import timezone
from debusine.db.models.scopes import Scope
from debusine.server.worker_pools import (
ScopeWorkerPoolLimits,
WorkerPoolInterface,
WorkerPoolLimits,
WorkerPoolSpecifications,
provider_interface,
worker_pool_specifications_model,
)
if TYPE_CHECKING:
    # These names are only needed by annotations (which are written as
    # strings below), so they are imported for the type checker only.
    from django.http import HttpRequest
    from django_stubs_ext.db.models import TypedModelMeta

    from debusine.db.models.workers import Worker
    from debusine.web.views.ui.worker_pools import WorkerPoolUI
else:
    # At runtime the ``Meta`` classes below just inherit from ``object``.
    TypedModelMeta = object

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class WorkerPoolManager(models.Manager["WorkerPool"]):
    """Manager for WorkerPool model."""

    def enabled(self) -> QuerySet["WorkerPool"]:
        """
        Return the worker pools that are enabled.

        :return: a queryset of WorkerPool rows with ``enabled=True``.
        """
        # Query through this manager rather than ``WorkerPool.objects`` so
        # that the queryset (and database) of the manager the method was
        # called on is respected.
        return self.filter(enabled=True)
[docs]
class WorkerPool(models.Model):
    """
    Database model of a worker pool.

    A pool is tied to a provider account (an ``Asset``) and has a set of
    ``Worker`` rows; worker instances are launched and terminated through
    the :py:class:`WorkerPoolInterface` returned by
    :py:attr:`provider_interface`.
    """

    # Unique slug identifying the pool; also used in its URL.
    name = models.SlugField(
        unique=True,
        help_text='Human readable name of the worker pool',
    )
    # PROTECT: the account cannot be deleted while a pool still uses it.
    provider_account = models.ForeignKey("Asset", on_delete=models.PROTECT)
    # launch_workers() refuses to run on a disabled pool unless overridden.
    enabled = models.BooleanField(default=True)
    architectures = JSONField(default=list)
    tags = JSONField(default=list, blank=True)
    # Raw JSON; parsed and checked by the ``specifications_model`` property.
    specifications = JSONField(default=dict)
    instance_wide = models.BooleanField(default=True)
    ephemeral = models.BooleanField(default=False)
    # Raw JSON; parsed and checked by the ``limits_model`` property.
    limits = JSONField(default=dict, blank=True)
    # No default: the creator of the row has to supply the timestamp.
    registered_at = models.DateTimeField()

    objects = WorkerPoolManager()

    class Meta(TypedModelMeta):
        base_manager_name = "objects"

    def __str__(self) -> str:
        """Return the id and name of the WorkerPool."""
        return f"Id: {self.id} Name: {self.name}"

    def get_absolute_url(self) -> str:
        """Return an absolute URL to view this WorkerPool."""
        return reverse("worker-pools:detail", kwargs={"name": self.name})

    def ui(self, request: "HttpRequest") -> "WorkerPoolUI":
        """
        Return a UI helper for this instance.

        :param request: the request the helper is looked up for.
        """
        # Imported locally, like the other function-level imports in this
        # class (presumably to avoid a circular import -- confirm).
        from debusine.web.views.ui.worker_pools import WorkerPoolUI

        return WorkerPoolUI.get(request, self)

    @property
    def limits_model(self) -> WorkerPoolLimits:
        """
        Return the pydantic model for limits.

        The ``limits`` JSON is validated on every access.
        """
        return WorkerPoolLimits.model_validate(self.limits)

    @property
    def specifications_model(self) -> WorkerPoolSpecifications:
        """
        Return the pydantic model for specifications.

        :raise ValueError: if the provider type of the specifications does
          not match the provider type of ``provider_account``.
        """
        model = worker_pool_specifications_model(self.specifications)
        provider_model = self.provider_account.data_model
        # The account's data model is expected to expose ``provider_type``.
        assert hasattr(provider_model, "provider_type")
        if provider_model.provider_type != model.provider_type:
            raise ValueError(
                f"specifications for worker_pool {self.name} do not have a "
                f"provider_account with a matching provider_type."
            )
        return model

    @property
    def workers_running(self) -> QuerySet["Worker"]:
        """
        Return the Worker instances that are currently running.

        A worker counts as running when ``instance_created_at`` is set.
        """
        return self.worker_set.filter(instance_created_at__isnull=False)

    @property
    def workers_stopped(self) -> QuerySet["Worker"]:
        """
        Return the Worker instances that are currently stopped.

        A worker counts as stopped when ``instance_created_at`` is unset.
        """
        return self.worker_set.filter(instance_created_at__isnull=True)

    @cached_property
    def provider_interface(self) -> WorkerPoolInterface:
        """
        Return a WorkerPoolInterface instance for this pool.

        The interface is built once per model instance and then cached.
        """
        # Inside the method body this name resolves to the module-level
        # ``provider_interface`` function, not to this property.
        return provider_interface(self)

    def launch_workers(
        self, count: int, override_disabled: bool = False
    ) -> None:
        """
        Launch count additional worker instances.

        Stopped workers of the pool are used; if there are fewer than
        ``count`` of them, the missing pool members are created first.

        :param count: number of worker instances to launch.
        :param override_disabled: launch even if the pool is disabled.
        :raise ValueError: if the pool is disabled and
          ``override_disabled`` is not set.
        """
        from debusine.db.models.auth import Token
        from debusine.db.models.workers import Worker

        if not (self.enabled or override_disabled):
            raise ValueError("Pool is disabled, refusing to launch workers.")

        # Make sure there are enough stopped workers to pick from.
        available = self.workers_stopped.count()
        if available < count:
            Worker.objects.create_pool_members(self, count - available)

        launched = 0
        while launched < count:
            with transaction.atomic():
                # Lock one stopped worker, skipping rows that another
                # transaction has already locked.
                worker = (
                    self.workers_stopped.select_for_update(skip_locked=True)
                    .order_by("name")
                    .first()
                )
                if worker is None:  # pragma: no cover
                    # Locked or deleted since we expanded the pool, above
                    return
                # Give the worker a fresh activation token and drop the
                # previous one.
                old_activation_token = worker.activation_token
                worker.activation_token = (
                    Token.objects.create_worker_activation()
                )
                worker.save()
                if old_activation_token is not None:
                    old_activation_token.delete()
                # Disable any token left on the worker before launching.
                if worker.token is not None:
                    worker.token.disable()
                self.provider_interface.launch_worker(worker)
                launched += 1

    def terminate_worker(self, worker: "Worker") -> None:
        """
        Terminate the specified worker instance.

        The worker's tokens are disabled, its running time is recorded in
        :py:class:`WorkerPoolStatistics`, the provider is asked to terminate
        the instance, and its work requests are de-assigned or retried.

        :param worker: the worker to terminate; must belong to this pool.
        :raise ValueError: if ``worker`` belongs to a different pool.
        """
        # Import here to prevent circular imports
        from debusine.db.context import context
        from debusine.db.models.work_requests import (
            CannotRetry,
            WorkRequest,
            WorkRequestRetryReason,
        )

        if worker.worker_pool != self:
            raise ValueError(
                f"pool {self} cannot terminate worker"
                f" for pool {worker.worker_pool}"
            )

        # Commit this early to avoid scheduling any more work on the worker
        with transaction.atomic():
            if worker.activation_token is not None:
                worker.activation_token.disable()
            if worker.token is not None:
                worker.token.disable()

            # Record how long the instance has been up, in whole seconds.
            # This reads instance_created_at before the provider is asked
            # to terminate the worker.
            if worker.instance_created_at is not None:
                WorkerPoolStatistics.objects.create(
                    worker_pool=self,
                    worker=worker,
                    runtime=int(
                        (
                            timezone.now() - worker.instance_created_at
                        ).total_seconds()
                    ),
                )

        # Trigger worker termination
        with transaction.atomic():
            self.provider_interface.terminate_worker(worker)

        with transaction.atomic():
            # De-assign any pending tasks assigned to the worker
            for pending in WorkRequest.objects.reassignable(worker=worker):
                pending.de_assign_worker()

            # Retry any work requests that were previously running on the
            # worker
            for running in WorkRequest.objects.running(worker=worker):
                running.mark_aborted()
                try:
                    with context.disable_permission_checks():
                        running.retry(
                            reason=WorkRequestRetryReason.WORKER_FAILED
                        )
                except CannotRetry as e:
                    # Best effort: a work request that cannot be retried
                    # simply stays aborted.
                    logger.debug(  # noqa: G200
                        "Cannot retry previously-running work request: %s", e
                    )

    def clean(self) -> None:
        """
        Ensure that data is valid for this worker pool.

        :raise ValidationError: for invalid data.
        """
        try:
            # Accessing the properties runs their validation.
            self.limits_model
            self.specifications_model
        except (pydantic.ValidationError, ValueError) as e:
            # specifications_model raises a plain ValueError when the
            # provider types do not match; convert that too, so callers of
            # clean() only ever see Django's ValidationError for bad data.
            raise ValidationError(message=str(e)) from e
[docs]
class ScopeWorkerPool(models.Model):
    """Through table for linking a WorkerPool to a Scope."""

    # Both ends cascade: the link row disappears with either side.
    worker_pool = models.ForeignKey(WorkerPool, on_delete=models.CASCADE)
    scope = models.ForeignKey(Scope, on_delete=models.CASCADE)
    priority = models.IntegerField(default=0)
    # Raw JSON; parsed and checked by the ``limits_model`` property.
    limits = JSONField(blank=True, default=dict)

    def __str__(self) -> str:
        """Return the id of the link and the names of both ends."""
        parts = (
            f"Id: {self.id}",
            f"WorkerPool: {self.worker_pool.name}",
            f"Scope: {self.scope.name}",
        )
        return " ".join(parts)

    @property
    def limits_model(self) -> ScopeWorkerPoolLimits:
        """
        Return the pydantic model for limits.

        The ``limits`` JSON is validated on every access.
        """
        validated = ScopeWorkerPoolLimits.model_validate(self.limits)
        return validated
[docs]
class WorkerPoolTaskExecutionStatistics(models.Model):
    """Time spent executing tasks in a scope, stored at completion."""

    worker_pool = models.ForeignKey(WorkerPool, on_delete=models.CASCADE)
    # SET_NULL: the statistics row outlives the worker it refers to.
    worker = models.ForeignKey(
        "Worker", on_delete=models.SET_NULL, blank=True, null=True
    )
    scope = models.ForeignKey(Scope, on_delete=models.CASCADE)
    # Set automatically when the row is created.
    timestamp = models.DateTimeField(auto_now_add=True)
    # NOTE(review): unit is not visible here; presumably seconds, as for
    # WorkerPoolStatistics.runtime -- confirm with the code writing rows.
    runtime = models.IntegerField()

    class Meta(TypedModelMeta):
        # Index on the timestamp column.
        indexes = [
            models.Index(
                "timestamp", name="%(app_label)s_worker_pool_exec_ts_idx"
            ),
        ]
[docs]
class WorkerPoolStatistics(models.Model):
    """Running time for historical worker instances, stored at shutdown."""

    worker_pool = models.ForeignKey(WorkerPool, on_delete=models.CASCADE)
    # SET_NULL: the statistics row outlives the worker it refers to.
    worker = models.ForeignKey(
        "Worker", on_delete=models.SET_NULL, blank=True, null=True
    )
    # Set automatically when the row is created.
    timestamp = models.DateTimeField(auto_now_add=True)
    # Whole seconds; WorkerPool.terminate_worker stores the time elapsed
    # since the worker's instance_created_at.
    runtime = models.IntegerField()

    class Meta(TypedModelMeta):
        # Index on the timestamp column.
        indexes = [
            models.Index(
                "timestamp", name="%(app_label)s_worker_pool_stat_ts_idx"
            ),
        ]