Skip to content

Commit

Permalink
improve lifespan typecheck and debug (#4014)
Browse files Browse the repository at this point in the history
* add lifespan debug statement

* improve some of the logic for lifespan tasks

* fix partial name with update_wrapper
  • Loading branch information
Lendemor authored Sep 27, 2024
1 parent 9ca5d4a commit 1b3422d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
30 changes: 24 additions & 6 deletions reflex/app_mixins/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import contextlib
import functools
import inspect
import sys
from typing import Callable, Coroutine, Set, Union

from fastapi import FastAPI

from reflex.utils import console
from reflex.utils.exceptions import InvalidLifespanTaskType

from .mixin import AppMixin


Expand All @@ -26,6 +28,7 @@ async def _run_lifespan_tasks(self, app: FastAPI):
try:
async with contextlib.AsyncExitStack() as stack:
for task in self.lifespan_tasks:
run_msg = f"Started lifespan task: {task.__name__} as {{type}}" # type: ignore
if isinstance(task, asyncio.Task):
running_tasks.append(task)
else:
Expand All @@ -35,23 +38,38 @@ async def _run_lifespan_tasks(self, app: FastAPI):
_t = task()
if isinstance(_t, contextlib._AsyncGeneratorContextManager):
await stack.enter_async_context(_t)
console.debug(run_msg.format(type="asynccontextmanager"))
elif isinstance(_t, Coroutine):
running_tasks.append(asyncio.create_task(_t))
task_ = asyncio.create_task(_t)
task_.add_done_callback(lambda t: t.result())
running_tasks.append(task_)
console.debug(run_msg.format(type="coroutine"))
else:
console.debug(run_msg.format(type="function"))
yield
finally:
cancel_kwargs = (
{"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
)
for task in running_tasks:
task.cancel(**cancel_kwargs)
console.debug(f"Canceling lifespan task: {task}")
task.cancel(msg="lifespan_cleanup")

def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
"""Register a task to run during the lifespan of the app.
Args:
task: The task to register.
task_kwargs: The kwargs of the task.
Raises:
InvalidLifespanTaskType: If the task is a generator function.
"""
if inspect.isgeneratorfunction(task) or inspect.isasyncgenfunction(task):
raise InvalidLifespanTaskType(
f"Task {task.__name__} of type generator must be decorated with contextlib.asynccontextmanager."
)

if task_kwargs:
original_task = task
task = functools.partial(task, **task_kwargs) # type: ignore
functools.update_wrapper(task, original_task) # type: ignore
self.lifespan_tasks.add(task) # type: ignore
console.debug(f"Registered lifespan task: {task.__name__}") # type: ignore
4 changes: 4 additions & 0 deletions reflex/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,7 @@ class GeneratedCodeHasNoFunctionDefs(ReflexError):

class PrimitiveUnserializableToJSON(ReflexError, ValueError):
"""Raised when a primitive type is unserializable to JSON. Usually with NaN and Infinity."""


class InvalidLifespanTaskType(ReflexError, TypeError):
"""Raised when an invalid task type is registered as a lifespan task."""

0 comments on commit 1b3422d

Please sign in to comment.