From f6468b70b071ca53c4f235883048549d6efcb825 Mon Sep 17 00:00:00 2001 From: Amethyst Reese Date: Tue, 23 Apr 2024 00:16:01 -0700 Subject: [PATCH] type fixes --- aiomultiprocess/core.py | 32 +++++++++++++++++++------------- aiomultiprocess/pool.py | 10 +++++----- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/aiomultiprocess/core.py b/aiomultiprocess/core.py index 75587a6..01b321a 100644 --- a/aiomultiprocess/core.py +++ b/aiomultiprocess/core.py @@ -4,10 +4,11 @@ import asyncio import logging import multiprocessing +import multiprocessing.context import multiprocessing.managers import os import sys -from typing import Any, Callable, Dict, Optional, Sequence +from typing import Any, Callable, cast, Dict, Optional, Sequence, Union from .types import Context, R, Unit @@ -16,7 +17,12 @@ # shared context for all multiprocessing primitives, for platform compatibility # "spawn" is default/required on windows and mac, but can't execute non-global functions # see https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods -context = multiprocessing.get_context(DEFAULT_START_METHOD) +ContextTypes = Union[ + multiprocessing.context.ForkContext, + multiprocessing.context.ForkServerContext, + multiprocessing.context.SpawnContext, +] +context = cast(ContextTypes, multiprocessing.get_context(DEFAULT_START_METHOD)) _manager = None log = logging.getLogger(__name__) @@ -50,7 +56,7 @@ def set_start_method(method: Optional[str] = DEFAULT_START_METHOD) -> None: https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods """ global context - context = multiprocessing.get_context(method) + context = cast(ContextTypes, multiprocessing.get_context(method)) def get_context() -> Context: @@ -80,12 +86,12 @@ class Process: def __init__( self, group: None = None, - target: Callable = None, - name: str = None, - args: Sequence[Any] = None, - kwargs: Dict[str, Any] = None, + target: Optional[Callable] = None, + name: Optional[str] = None, + args: Optional[Sequence[Any]] = None, + kwargs: Optional[Dict[str, Any]] = None, *, - daemon: bool = None, + daemon: Optional[bool] = None, initializer: Optional[Callable] = None, initargs: Sequence[Any] = (), loop_initializer: Optional[Callable] = None, @@ -127,7 +133,7 @@ def __await__(self) -> Any: return self.join().__await__() @staticmethod - def run_async(unit: Unit) -> R: + def run_async(unit: Unit) -> R: # type: ignore[type-var] """Initialize the child process and event loop, then execute the coroutine.""" try: if unit.loop_initializer is None: @@ -152,7 +158,7 @@ def start(self) -> None: """Start the child process.""" return self.aio_process.start() - async def join(self, timeout: int = None) -> None: + async def join(self, timeout: Optional[int] = None) -> None: """Wait for the process to finish execution without blocking the main thread.""" if not self.is_alive() and self.exitcode is None: raise ValueError("must start process before joining it") @@ -216,7 +222,7 @@ def __init__(self, *args, **kwargs) -> None: self.unit.namespace.result = None @staticmethod - def run_async(unit: Unit) -> R: + def run_async(unit: Unit) -> R: # type: ignore[type-var] """Initialize the child process and event loop, then execute the coroutine.""" try: result: R = Process.run_async(unit) @@ -227,13 +233,13 @@ def run_async(unit: Unit) -> R: unit.namespace.result = e raise - async def join(self, timeout: int = None) -> Any: + async def join(self, timeout: Optional[int] = None) -> Any: """Wait for the worker to finish, and return the final result.""" await super().join(timeout) return self.result @property - def result(self) -> R: + def result(self) -> R: # type: ignore[type-var] """Easy access to the resulting value from the coroutine.""" if self.exitcode is None: raise ValueError("coroutine not completed") diff --git a/aiomultiprocess/pool.py b/aiomultiprocess/pool.py index a3e0821..93e7603 100644 --- a/aiomultiprocess/pool.py +++ b/aiomultiprocess/pool.py @@ -150,13 +150,13 @@ class Pool: def __init__( self, - processes: int = None, - initializer: Callable[..., None] = None, + processes: Optional[int] = None, + initializer: Optional[Callable[..., None]] = None, initargs: Sequence[Any] = (), maxtasksperchild: int = MAX_TASKS_PER_CHILD, childconcurrency: int = CHILD_CONCURRENCY, queuecount: Optional[int] = None, - scheduler: Scheduler = None, + scheduler: Optional[Scheduler] = None, loop_initializer: Optional[LoopInitializer] = None, exception_handler: Optional[Callable[[BaseException], None]] = None, ) -> None: @@ -316,8 +316,8 @@ async def results(self, tids: Sequence[TaskID]) -> Sequence[R]: async def apply( self, func: Callable[..., Awaitable[R]], - args: Sequence[Any] = None, - kwds: Dict[str, Any] = None, + args: Optional[Sequence[Any]] = None, + kwds: Optional[Dict[str, Any]] = None, ) -> R: """Run a single coroutine on the pool.""" if not self.running: