Skip to content

Commit

Permalink
type fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
amyreese committed Apr 23, 2024
1 parent bba6e65 commit da4efe3
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
32 changes: 19 additions & 13 deletions aiomultiprocess/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import asyncio
import logging
import multiprocessing
import multiprocessing.context
import multiprocessing.managers
import os
import sys
from typing import Any, Callable, Dict, Optional, Sequence
from typing import Any, Callable, cast, Dict, Optional, Sequence, Union

from .types import Context, R, Unit

Expand All @@ -16,7 +17,12 @@
# shared context for all multiprocessing primitives, for platform compatibility
# "spawn" is default/required on windows and mac, but can't execute non-global functions
# see https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
context = multiprocessing.get_context(DEFAULT_START_METHOD)
ContextTypes = Union[
multiprocessing.context.ForkContext,
multiprocessing.context.ForkServerContext,
multiprocessing.context.SpawnContext,
]
context = cast(ContextTypes, multiprocessing.get_context(DEFAULT_START_METHOD))
_manager = None

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -50,7 +56,7 @@ def set_start_method(method: Optional[str] = DEFAULT_START_METHOD) -> None:
https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
"""
global context
context = multiprocessing.get_context(method)
context = cast(ContextTypes, multiprocessing.get_context(method))


def get_context() -> Context:
Expand Down Expand Up @@ -80,12 +86,12 @@ class Process:
def __init__(
self,
group: None = None,
target: Callable = None,
name: str = None,
args: Sequence[Any] = None,
kwargs: Dict[str, Any] = None,
target: Optional[Callable] = None,
name: Optional[str] = None,
args: Optional[Sequence[Any]] = None,
kwargs: Optional[Dict[str, Any]] = None,
*,
daemon: bool = None,
daemon: Optional[bool] = None,
initializer: Optional[Callable] = None,
initargs: Sequence[Any] = (),
loop_initializer: Optional[Callable] = None,
Expand Down Expand Up @@ -127,7 +133,7 @@ def __await__(self) -> Any:
return self.join().__await__()

@staticmethod
def run_async(unit: Unit) -> R:
def run_async(unit: Unit) -> R: # type: ignore[type-var]
"""Initialize the child process and event loop, then execute the coroutine."""
try:
if unit.loop_initializer is None:
Expand All @@ -152,7 +158,7 @@ def start(self) -> None:
"""Start the child process."""
return self.aio_process.start()

async def join(self, timeout: int = None) -> None:
async def join(self, timeout: Optional[int] = None) -> None:
"""Wait for the process to finish execution without blocking the main thread."""
if not self.is_alive() and self.exitcode is None:
raise ValueError("must start process before joining it")
Expand Down Expand Up @@ -216,7 +222,7 @@ def __init__(self, *args, **kwargs) -> None:
self.unit.namespace.result = None

@staticmethod
def run_async(unit: Unit) -> R:
def run_async(unit: Unit) -> R: # type: ignore[type-var]
"""Initialize the child process and event loop, then execute the coroutine."""
try:
result: R = Process.run_async(unit)
Expand All @@ -227,13 +233,13 @@ def run_async(unit: Unit) -> R:
unit.namespace.result = e
raise

async def join(self, timeout: int = None) -> Any:
async def join(self, timeout: Optional[int] = None) -> Any:
"""Wait for the worker to finish, and return the final result."""
await super().join(timeout)
return self.result

@property
def result(self) -> R:
def result(self) -> R: # type: ignore[type-var]
"""Easy access to the resulting value from the coroutine."""
if self.exitcode is None:
raise ValueError("coroutine not completed")
Expand Down
10 changes: 5 additions & 5 deletions aiomultiprocess/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,13 @@ class Pool:

def __init__(
self,
processes: int = None,
initializer: Callable[..., None] = None,
processes: Optional[int] = None,
initializer: Optional[Callable[..., None]] = None,
initargs: Sequence[Any] = (),
maxtasksperchild: int = MAX_TASKS_PER_CHILD,
childconcurrency: int = CHILD_CONCURRENCY,
queuecount: Optional[int] = None,
scheduler: Scheduler = None,
scheduler: Optional[Scheduler] = None,
loop_initializer: Optional[LoopInitializer] = None,
exception_handler: Optional[Callable[[BaseException], None]] = None,
) -> None:
Expand Down Expand Up @@ -316,8 +316,8 @@ async def results(self, tids: Sequence[TaskID]) -> Sequence[R]:
async def apply(
self,
func: Callable[..., Awaitable[R]],
args: Sequence[Any] = None,
kwds: Dict[str, Any] = None,
args: Optional[Sequence[Any]] = None,
kwds: Optional[Dict[str, Any]] = None,
) -> R:
"""Run a single coroutine on the pool."""
if not self.running:
Expand Down

0 comments on commit da4efe3

Please sign in to comment.