diff --git a/MANIFEST.in b/MANIFEST.in
deleted file mode 100644
index b8c3574..0000000
--- a/MANIFEST.in
+++ /dev/null
@@ -1 +0,0 @@
-include test-requirements.txt
diff --git a/Makefile b/Makefile
index 88bb990..9352009 100644
--- a/Makefile
+++ b/Makefile
@@ -6,6 +6,7 @@ lint: test_deps

 test: test_deps lint
 	coverage run --source=$$(python setup.py --name) ./test/test.py -v
+	mypy watchtower

 docs:
 	sphinx-build docs docs/html
diff --git a/setup.cfg b/setup.cfg
index 1a6ca11..874e8d6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,3 +1,20 @@
 [flake8]
 max-line-length=120
 ignore: E301, E401
+
+[mypy]
+pretty = true
+show_error_codes = true
+show_error_context = true
+warn_redundant_casts = true
+warn_unused_ignores = true
+warn_unreachable = true
+
+[mypy-boto3.*]
+ignore_missing_imports = true
+
+[mypy-botocore.*]
+ignore_missing_imports = true
+
+[mypy-django.*]
+ignore_missing_imports = true
diff --git a/setup.py b/setup.py
index 01e707b..69a1b21 100755
--- a/setup.py
+++ b/setup.py
@@ -26,8 +26,10 @@
         ]
     },
     packages=find_packages(exclude=["test"]),
+    package_data={
+        "watchtower": ["py.typed"],
+    },
     platforms=["MacOS X", "Posix"],
-    include_package_data=True,
     classifiers=[
         "Intended Audience :: Developers",
         "License :: OSI Approved :: Apache Software License",
diff --git a/watchtower/__init__.py b/watchtower/__init__.py
index baa03a4..b562664 100644
--- a/watchtower/__init__.py
+++ b/watchtower/__init__.py
@@ -1,6 +1,8 @@
-from collections.abc import Mapping
-from datetime import date, datetime
+from collections.abc import Mapping, MutableMapping
+from datetime import date, datetime, timedelta
+from logging import LogRecord
 from operator import itemgetter
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union

 import sys, json, logging, time, threading, warnings, functools, platform
 import queue
@@ -9,7 +11,7 @@
 from botocore.exceptions import ClientError


-def _json_serialize_default(o):
+def _json_serialize_default(o: Any) -> Any:
     """
     A standard 'default' json serializer function that will serialize datetime objects as ISO format.
     """
@@ -17,7 +19,7 @@ def _json_serialize_default(o):
         return o.isoformat()


-def _boto_debug_filter(record):
+def _boto_debug_filter(record: LogRecord) -> bool:
     # Filter debug log messages from botocore and its dependency, urllib3.
     # This is required to avoid message storms any time we send logs.
     if record.name.startswith("botocore") and record.levelname == "DEBUG":
@@ -27,7 +29,7 @@ def _boto_debug_filter(record):
     return True


-def _boto_filter(record):
+def _boto_filter(record: LogRecord) -> bool:
     # Filter log messages from botocore and its dependency, urllib3.
     # This is required to avoid an infinite loop when shutting down.
     if record.name.startswith("botocore"):
@@ -90,25 +92,29 @@ class CloudWatchLogFormatter(logging.Formatter):
         for more details about the 'default' parameter. By default, watchtower uses a serializer that formats datetime
         objects into strings using the `datetime.isoformat()` method, with no other customizations.
     """
-    add_log_record_attrs = tuple()
-
-    def __init__(self, *args, json_serialize_default: callable = None, add_log_record_attrs: tuple = None, **kwargs):
+    add_log_record_attrs: Tuple[str, ...] = ()
+
+    def __init__(
+            self,
+            *args,
+            json_serialize_default: Optional[Callable[[Any], Any]] = None,
+            add_log_record_attrs: Optional[Tuple[str, ...]] = None,
+            **kwargs):
         super().__init__(*args, **kwargs)
-        self.json_serialize_default = _json_serialize_default
-        if json_serialize_default is not None:
-            self.json_serialize_default = json_serialize_default
+        self.json_serialize_default = json_serialize_default or _json_serialize_default
         if add_log_record_attrs is not None:
             self.add_log_record_attrs = add_log_record_attrs

-    def format(self, message):
+    def format(self, message: LogRecord) -> str:
+        msg: Union[str, MutableMapping] = message.msg
         if self.add_log_record_attrs:
-            msg = message.msg if isinstance(message.msg, Mapping) else {"msg": message.msg}
+            if not isinstance(msg, Mapping):
+                msg = {"msg": msg}
             for field in self.add_log_record_attrs:
                 if field != "msg":
                     msg[field] = getattr(message, field)
-            message.msg = msg
-        if isinstance(message.msg, Mapping):
-            return json.dumps(message.msg, default=self.json_serialize_default)
+        if isinstance(msg, Mapping):
+            return json.dumps(msg, default=self.json_serialize_default)
         return super().format(message)


@@ -172,29 +178,36 @@ class CloudWatchLogHandler(logging.Handler):
     # extra size of meta information with each messages
     EXTRA_MSG_PAYLOAD_SIZE = 26

+    queues: Dict[str, queue.Queue]
+    sequence_tokens: Dict[str, Optional[str]]
+    threads: List[threading.Thread]
+
     def __init__(self,
                  log_group_name: str = __name__,
                  log_stream_name: str = "{machine_name}/{program_name}/{logger_name}",
                  use_queues: bool = True,
-                 send_interval: int = 60,
+                 send_interval: Union[int, timedelta] = 60,
                  max_batch_size: int = 1024 * 1024,
                  max_batch_count: int = 10000,
-                 boto3_client: botocore.client.BaseClient = None,
-                 boto3_profile_name: str = None,
+                 boto3_client: Optional[botocore.client.BaseClient] = None,
+                 boto3_profile_name: Optional[str] = None,
                  create_log_group: bool = True,
-                 json_serialize_default: callable = None,
-                 log_group_retention_days: int = None,
+                 json_serialize_default: Optional[Callable[[Any], Any]] = None,
+                 log_group_retention_days: Optional[int] = None,
                  create_log_stream: bool = True,
                  max_message_size: int = 256 * 1024,
-                 log_group=None,
-                 stream_name=None,
+                 log_group: Optional[str] = None,
+                 stream_name: Optional[str] = None,
                  *args,
                  **kwargs):
         super().__init__(*args, **kwargs)
         self.log_group_name = log_group_name
         self.log_stream_name = log_stream_name
         self.use_queues = use_queues
-        self.send_interval = send_interval
+        if isinstance(send_interval, timedelta):
+            self.send_interval = send_interval.total_seconds()
+        else:
+            self.send_interval = send_interval
         self.json_serialize_default = json_serialize_default or _json_serialize_default
         self.max_batch_size = max_batch_size
         self.max_batch_count = max_batch_count
@@ -236,14 +249,14 @@ def __init__(self,
                                   logGroupName=self.log_group_name,
                                   retentionInDays=self.log_group_retention_days)

-    def _at_fork_reinit(self):
+    def _at_fork_reinit(self) -> None:
         # This was added in Python 3.9 and should only be called with a recent
         # version of Python. An older version will attempt to call createLock
         # instead.
-        super()._at_fork_reinit()
+        super()._at_fork_reinit()  # type: ignore
         self._init_state()

-    def _init_state(self):
+    def _init_state(self) -> None:
         self.queues, self.sequence_tokens = {}, {}
         self.threads = []
         self.creating_log_stream, self.shutting_down = False, False
@@ -254,7 +267,7 @@ def _paginate(self, boto3_paginator, *args, **kwargs):
             for value in page.get(result_key.parsed.get("value"), []):
                 yield value

-    def _ensure_log_group(self):
+    def _ensure_log_group(self) -> None:
         try:
             paginator = self.cwl_client.get_paginator("describe_log_groups")
             for log_group in self._paginate(paginator, logGroupNamePrefix=self.log_group_name):
@@ -264,7 +277,7 @@ def _ensure_log_group(self):
             pass
         self._idempotent_call("create_log_group", logGroupName=self.log_group_name)

-    def _idempotent_call(self, method, *args, **kwargs):
+    def _idempotent_call(self, method: str, *args, **kwargs) -> None:
         method_callable = getattr(self.cwl_client, method)
         try:
             method_callable(*args, **kwargs)
@@ -273,10 +286,10 @@ def _idempotent_call(self, method, *args, **kwargs):
             pass

     @functools.lru_cache(maxsize=0)
-    def _get_machine_name(self):
+    def _get_machine_name(self) -> str:
         return platform.node()

-    def _get_stream_name(self, message):
+    def _get_stream_name(self, message: LogRecord) -> str:
         return self.log_stream_name.format(
             machine_name=self._get_machine_name(),
             program_name=sys.argv[0],
@@ -285,7 +298,7 @@ def _get_stream_name(self, message):
             strftime=datetime.utcnow()
         )

-    def _submit_batch(self, batch, log_stream_name, max_retries=5):
+    def _submit_batch(self, batch: Sequence[dict], log_stream_name: str, max_retries: int = 5) -> None:
         if len(batch) < 1:
             return
         sorted_batch = sorted(batch, key=itemgetter('timestamp'), reverse=False)
@@ -325,23 +338,23 @@ def _submit_batch(self, batch, log_stream_name, max_retries=5):
                 finally:
                     self.creating_log_stream = False
             else:
-                warnings.warn("Failed to deliver logs: {}".format(e), WatchtowerWarning)
+                warnings.warn(f"Failed to deliver logs: {e}", WatchtowerWarning)
         except Exception as e:
-            warnings.warn("Failed to deliver logs: {}".format(e), WatchtowerWarning)
+            warnings.warn(f"Failed to deliver logs: {e}", WatchtowerWarning)

         # response can be None only when all retries have been exhausted
         if response is None or "rejectedLogEventsInfo" in response:
-            warnings.warn("Failed to deliver logs: {}".format(response), WatchtowerWarning)
+            warnings.warn(f"Failed to deliver logs: {response}", WatchtowerWarning)
         elif "nextSequenceToken" in response:
             # According to https://github.com/kislyuk/watchtower/issues/134, nextSequenceToken may sometimes be absent
             # from the response
             self.sequence_tokens[log_stream_name] = response["nextSequenceToken"]

-    def createLock(self):
+    def createLock(self) -> None:
         super().createLock()
         self._init_state()

-    def emit(self, message):
+    def emit(self, message: LogRecord) -> None:
         if self.creating_log_stream:
             return  # Avoid infinite recursion when asked to log a message as our own side effect

@@ -375,28 +388,32 @@ def emit(self, message):
         except Exception:
             self.handleError(message)

-    def _dequeue_batch(self, my_queue, stream_name, send_interval, max_batch_size, max_batch_count, max_message_size):
-        msg = None
+    def _dequeue_batch(self, my_queue: queue.Queue, stream_name: str, send_interval: int, max_batch_size: int,
+                       max_batch_count: int, max_message_size: int) -> None:
+        msg: Union[dict, int, None] = None
         max_message_body_size = max_message_size - CloudWatchLogHandler.EXTRA_MSG_PAYLOAD_SIZE
+        assert max_message_body_size > 0

-        def size(_msg):
+        def size(_msg: Union[dict, int]) -> int:
             return (len(_msg["message"]) if isinstance(_msg, dict) else 1) + CloudWatchLogHandler.EXTRA_MSG_PAYLOAD_SIZE

-        def truncate(_msg2):
+        def truncate(_msg2: dict) -> dict:
             warnings.warn("Log message size exceeds CWL max payload size, truncated", WatchtowerWarning)
             _msg2["message"] = _msg2["message"][:max_message_body_size]
             return _msg2

         # See https://boto3.readthedocs.io/en/latest/reference/services/logs.html#CloudWatchLogs.Client.put_log_events
         while msg != self.END:
-            cur_batch = [] if msg is None or msg == self.FLUSH else [msg]
+            cur_batch: List[dict] = [] if msg is None or msg == self.FLUSH else [cast(dict, msg)]
             cur_batch_size = sum(map(size, cur_batch))
             cur_batch_msg_count = len(cur_batch)
             cur_batch_deadline = time.time() + send_interval
             while True:
                 try:
-                    msg = my_queue.get(block=True, timeout=max(0, cur_batch_deadline - time.time()))
+                    msg = cast(Union[dict, int],
+                               my_queue.get(block=True, timeout=max(0, cur_batch_deadline - time.time())))
                     if size(msg) > max_message_body_size:
+                        assert isinstance(msg, dict)  # size always < max_message_body_size if not `dict`
                         msg = truncate(msg)
                 except queue.Empty:
                     # If the queue is empty, we don't want to reprocess the previous message
@@ -413,12 +430,13 @@ def truncate(_msg2):
                 my_queue.task_done()
                 break
             elif msg:
+                assert isinstance(msg, dict)  # mypy can't handle all the or expressions filtering out sentinels
                 cur_batch_size += size(msg)
                 cur_batch_msg_count += 1
                 cur_batch.append(msg)
                 my_queue.task_done()

-    def flush(self):
+    def flush(self) -> None:
         """
         Send any queued messages to CloudWatch. This method does nothing if ``use_queues`` is set to False.
         """
@@ -431,7 +449,7 @@ def flush(self):
         for q in self.queues.values():
             q.join()

-    def close(self):
+    def close(self) -> None:
         """
         Send any queued messages to CloudWatch and prevent further processing of messages.
         This method does nothing if ``use_queues`` is set to False.
@@ -450,6 +468,6 @@ def close(self):
             q.join()
         super().close()

-    def __repr__(self):
+    def __repr__(self) -> str:
         name = self.__class__.__name__
         return f"{name}(log_group_name='{self.log_group_name}', log_stream_name='{self.log_stream_name}')"
diff --git a/watchtower/py.typed b/watchtower/py.typed
new file mode 100644
index 0000000..e69de29
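Notes on the change above (not part of the diff): the empty watchtower/py.typed file, shipped via the new package_data entry, is the PEP 561 marker telling type checkers that the package's inline annotations may be consumed by downstream code, and the Makefile now runs mypy watchtower as part of make test. The sketch below, assuming only what the diff shows (the log group name is hypothetical, and actually running it requires AWS credentials, since the handler talks to CloudWatch Logs at construction time), exercises two of the newly typed surfaces: send_interval accepting a datetime.timedelta, and CloudWatchLogFormatter's add_log_record_attrs merging LogRecord attributes into a JSON payload.

    import logging
    from datetime import timedelta

    from watchtower import CloudWatchLogFormatter, CloudWatchLogHandler

    # send_interval may now be an int or a timedelta; __init__ normalizes the
    # latter to seconds via total_seconds().
    handler = CloudWatchLogHandler(log_group_name="example-group",  # hypothetical name
                                   send_interval=timedelta(seconds=15))

    # Each attribute named in add_log_record_attrs (typed Tuple[str, ...]) is
    # copied off the LogRecord into a dict, which format() serializes as JSON.
    handler.setFormatter(CloudWatchLogFormatter(add_log_record_attrs=("levelname", "funcName")))

    logger = logging.getLogger("example")
    logger.addHandler(handler)
    logger.error("hello")  # sent as {"msg": "hello", "levelname": "ERROR", "funcName": ...}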