Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add concurrency for webhook worker #875

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
20 changes: 19 additions & 1 deletion bot/kodiak/app_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import base64
from pathlib import Path
from typing import Any, Optional, Type, TypeVar, overload
from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, overload

import databases
from starlette.config import Config, undefined
Expand Down Expand Up @@ -54,6 +54,24 @@ def __call__(
default=["pull_request", "pull_request_review", "pull_request_comment"],
)
)


def parse_worker_concurrency(items: Sequence[str]) -> Mapping[str, int]:
    """Parse per-installation webhook worker concurrency settings.

    Each item has the form ``<installation_id>=<concurrency>``; e.g. the
    items from env value ``12312309=4,1290301293=1`` yield
    ``{"12312309": 4, "1290301293": 1}``.

    Raises:
        ValueError: if an item lacks an ``=`` separator or its
            concurrency portion is not an integer.
    """
    concurrency_by_install = {}
    for item in items:
        # maxsplit=1: split only at the first "=" so the result always
        # unpacks into exactly two parts.  The previous maxsplit=2 made
        # malformed input like "a=b=c" fail with an obscure
        # too-many-values unpack error.
        install, concurrency = item.split("=", maxsplit=1)
        concurrency_by_install[install] = int(concurrency)
    return concurrency_by_install


# Per-installation webhook worker concurrency overrides.
# Comma-separated "<installation_id>=<concurrency>" pairs,
# e.g. 12312309=4,1290301293=1
WEBHOOK_WORKER_CONCURRENCY = parse_worker_concurrency(
    config(
        "WEBHOOK_WORKER_CONCURRENCY",
        cast=CommaSeparatedStrings,
        # no overrides by default; parse_worker_concurrency([]) == {}.
        default=[],
    )
)
# Cap on queue length used for usage reporting.
USAGE_REPORTING_QUEUE_LENGTH = config(
    "USAGE_REPORTING_QUEUE_LENGTH", cast=int, default=10_000
)
Expand Down
2 changes: 1 addition & 1 deletion bot/kodiak/entrypoints/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def main() -> NoReturn:
if task_meta.kind == "repo":
queue.start_repo_worker(queue_name=task_meta.queue_name)
elif task_meta.kind == "webhook":
queue.start_webhook_worker(queue_name=task_meta.queue_name)
await queue.start_webhook_worker(queue_name=task_meta.queue_name)
else:
assert_never(task_meta.kind)
if ingest_queue_watcher.done():
Expand Down
106 changes: 81 additions & 25 deletions bot/kodiak/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,20 @@
import time
import typing
import urllib
import uuid
from asyncio.tasks import Task
from dataclasses import dataclass
from datetime import timedelta
from typing import Iterator, MutableMapping, NoReturn, Optional, Tuple
from typing import (
Any,
Callable,
Iterator,
MutableMapping,
NoReturn,
Optional,
Sequence,
Tuple,
)

import sentry_sdk
import structlog
Expand Down Expand Up @@ -367,7 +377,9 @@ async def webhook_event_consumer(
scope.set_tag("queue", queue_name)
scope.set_tag("installation", installation_id_from_queue(queue_name))
log = logger.bind(
queue=queue_name, install=installation_id_from_queue(queue_name)
queue=queue_name,
install=installation_id_from_queue(queue_name),
task_id=uuid.uuid4().hex,
)
log.info("start webhook event consumer")
while True:
Expand Down Expand Up @@ -461,7 +473,7 @@ class TaskMeta:
class RedisWebhookQueue:
def __init__(self) -> None:
    """Initialize with no running workers; workers are started later
    (e.g. via create / start_webhook_worker / start_repo_worker)."""
    # queue name -> list of (task, kind) pairs.  A list (rather than a
    # single task) allows several concurrent workers per webhook queue.
    self.worker_tasks: MutableMapping[
        str, list[tuple[Task[NoReturn], Literal["repo", "webhook"]]]
    ] = {}  # type: ignore [assignment]

async def create(self) -> None:
Expand All @@ -476,20 +488,26 @@ async def create(self) -> None:

for webhook_result in webhook_queues:
queue_name = webhook_result.decode()
self.start_webhook_worker(queue_name=queue_name)

def start_webhook_worker(self, *, queue_name: str) -> None:
await self.start_webhook_worker(queue_name=queue_name)

async def start_webhook_worker(self, *, queue_name: str) -> None:
    """Ensure webhook event consumers are running for *queue_name*.

    The desired worker count is read from the Redis key
    ``queue_concurrency:<queue_name>``; a missing or unparsable value
    falls back to a single worker.
    """
    raw_concurrency = await redis_bot.get("queue_concurrency:" + queue_name)
    try:
        worker_count = int(raw_concurrency or 1)
    except ValueError:
        worker_count = 1

    def make_consumer() -> typing.Coroutine[None, None, NoReturn]:
        return webhook_event_consumer(webhook_queue=self, queue_name=queue_name)

    self._start_worker(
        queue_name,
        "webhook",
        make_consumer,
        concurrency=worker_count,
    )

def start_repo_worker(self, *, queue_name: str) -> None:
    """Ensure a repo queue consumer is running for *queue_name*."""

    def make_consumer() -> typing.Coroutine[None, None, NoReturn]:
        return repo_queue_consumer(queue_name=queue_name)

    self._start_worker(queue_name, "repo", make_consumer)
def _start_worker(
    self,
    key: str,
    kind: Literal["repo", "webhook"],
    fut: Callable[[], typing.Coroutine[None, None, NoReturn]],
    *,
    concurrency: int = 1,
) -> None:
    """Reconcile the worker tasks for *key* to exactly *concurrency*.

    Finished tasks are dropped (failures are reported to Sentry), then
    tasks are created or cancelled so *concurrency* live workers remain.

    Args:
        key: queue name identifying the worker group.
        kind: worker flavor, recorded alongside each task.
        fut: factory producing a fresh consumer coroutine per task.
        concurrency: desired number of live worker tasks.
    """
    log = logger.bind(queue_name=key, kind=kind)
    previous_workers = self.worker_tasks.get(key, [])
    previous_task_count = len(previous_workers)
    failed_task_count = 0
    live_workers: list[tuple[asyncio.Task[Any], Literal["repo", "webhook"]]] = []

    for worker_task, task_kind in previous_workers:
        if not worker_task.done():
            live_workers.append((worker_task, task_kind))
            continue
        # Worker coroutines are NoReturn, so a done task either failed
        # or was cancelled (e.g. by a previous concurrency reduction).
        failed_task_count += 1
        # Task.exception() raises CancelledError for a cancelled task,
        # so check cancelled() before fetching the failure for Sentry.
        if worker_task.cancelled():
            log.info("task cancelled")
            continue
        log.info("task failed")
        exception = worker_task.exception()
        log.info("exception", excep=exception)
        sentry_sdk.capture_exception(exception)

    tasks_created = 0
    tasks_cancelled = 0
    tasks_to_create = concurrency - len(live_workers)
    if tasks_to_create > 0:
        # not enough live workers: top up to the desired concurrency.
        for _ in range(tasks_to_create):
            live_workers.append((asyncio.create_task(fut()), kind))
            tasks_created += 1
    elif tasks_to_create < 0:
        # too many live workers: cancel the surplus beyond `concurrency`.
        live_workers, surplus_workers = (
            live_workers[:concurrency],
            live_workers[concurrency:],
        )
        for task, _kind in surplus_workers:
            task.cancel()
            tasks_cancelled += 1

    self.worker_tasks[key] = live_workers
    log.info(
        "start_workers",
        previous_task_count=previous_task_count,
        failed_task_count=failed_task_count,
        tasks_created=tasks_created,
        tasks_cancelled=tasks_cancelled,
    )

async def enqueue(self, *, event: WebhookEvent) -> None:
"""
Expand All @@ -531,7 +586,7 @@ async def enqueue(self, *, event: WebhookEvent) -> None:
install=event.installation_id,
)
log.info("enqueue webhook event")
self.start_webhook_worker(queue_name=queue_name)
await self.start_webhook_worker(queue_name=queue_name)

async def enqueue_for_repo(
self, *, event: WebhookEvent, first: bool
Expand Down Expand Up @@ -577,8 +632,9 @@ async def enqueue_for_repo(
return find_position((key for key, value in kvs), event.json().encode())

def all_tasks(self) -> Iterator[tuple[TaskMeta, Task[NoReturn]]]:
    """Yield a (metadata, task) pair for every tracked worker task."""
    for queue_name, workers in self.worker_tasks.items():
        for worker_task, worker_kind in workers:
            meta = TaskMeta(kind=worker_kind, queue_name=queue_name)
            yield meta, worker_task


def get_merge_queue_name(event: WebhookEvent) -> str:
Expand Down