Skip to content

Commit

Permalink
add more comprehensive typing and check with mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
terencehonles committed Nov 15, 2021
1 parent 98a58de commit 7e51bb0
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 48 deletions.
1 change: 0 additions & 1 deletion MANIFEST.in

This file was deleted.

1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ lint: test_deps

test: test_deps lint
coverage run --source=$$(python setup.py --name) ./test/test.py -v
mypy watchtower

docs:
sphinx-build docs docs/html
Expand Down
17 changes: 17 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
[flake8]
max-line-length=120
ignore: E301, E401

[mypy]
pretty = true
show_error_codes = true
show_error_context = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_unreachable = true

[mypy-boto3.*]
ignore_missing_imports = true

[mypy-botocore.*]
ignore_missing_imports = true

[mypy-django.*]
ignore_missing_imports = true
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
]
},
packages=find_packages(exclude=["test"]),
package_data={
"watchtower": ["py.typed"],
},
platforms=["MacOS X", "Posix"],
include_package_data=True,
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
Expand Down
110 changes: 64 additions & 46 deletions watchtower/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from collections.abc import Mapping
from datetime import date, datetime
from collections.abc import Mapping, MutableMapping
from datetime import date, datetime, timedelta
from logging import LogRecord
from operator import itemgetter
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
import sys, json, logging, time, threading, warnings, functools, platform
import queue

Expand All @@ -9,15 +11,15 @@
from botocore.exceptions import ClientError


def _json_serialize_default(o):
def _json_serialize_default(o: Any) -> Any:
"""
A standard 'default' json serializer function that will serialize datetime objects as ISO format.
"""
if isinstance(o, (date, datetime)):
return o.isoformat()


def _boto_debug_filter(record):
def _boto_debug_filter(record: LogRecord) -> bool:
# Filter debug log messages from botocore and its dependency, urllib3.
# This is required to avoid message storms any time we send logs.
if record.name.startswith("botocore") and record.levelname == "DEBUG":
Expand All @@ -27,7 +29,7 @@ def _boto_debug_filter(record):
return True


def _boto_filter(record):
def _boto_filter(record: LogRecord) -> bool:
# Filter log messages from botocore and its dependency, urllib3.
# This is required to avoid an infinite loop when shutting down.
if record.name.startswith("botocore"):
Expand Down Expand Up @@ -90,25 +92,29 @@ class CloudWatchLogFormatter(logging.Formatter):
for more details about the 'default' parameter. By default, watchtower uses a serializer that formats datetime
objects into strings using the `datetime.isoformat()` method, with no other customizations.
"""
# LogRecord attribute names to copy into the emitted message payload.
# Empty by default; instances may override via the constructor kwarg below.
add_log_record_attrs: Tuple[str, ...] = ()

def __init__(
        self,
        *args,
        json_serialize_default: Optional[Callable[[Any], Any]] = None,
        add_log_record_attrs: Optional[Tuple[str, ...]] = None,
        **kwargs):
    """
    :param json_serialize_default: ``default`` callable handed to ``json.dumps``
        for objects it cannot serialize natively; falls back to the module-level
        serializer (ISO-formats date/datetime) when not given.
    :param add_log_record_attrs: names of ``LogRecord`` attributes to merge into
        the message dict; ``None`` keeps the class-level default.
    Remaining args/kwargs are forwarded to ``logging.Formatter``.
    """
    super().__init__(*args, **kwargs)
    # `or` fallback: a caller-supplied None selects the module default.
    self.json_serialize_default = json_serialize_default or _json_serialize_default
    if add_log_record_attrs is not None:
        self.add_log_record_attrs = add_log_record_attrs

def format(self, message: LogRecord) -> str:
    """
    Render *message* for delivery to CloudWatch Logs.

    When ``add_log_record_attrs`` is set, a non-mapping payload is first wrapped
    as ``{"msg": payload}`` and the named record attributes are merged in.
    A mapping payload is serialized to JSON via ``self.json_serialize_default``;
    anything else is delegated to ``logging.Formatter.format``.
    """
    msg: Union[str, MutableMapping] = message.msg
    if self.add_log_record_attrs:
        if not isinstance(msg, Mapping):
            # Promote plain payloads to a dict so attributes can be attached.
            msg = {"msg": msg}
        for field in self.add_log_record_attrs:
            if field != "msg":
                msg[field] = getattr(message, field)
        # NOTE: mutates the record so the enriched payload is what gets emitted.
        message.msg = msg
    if isinstance(msg, Mapping):
        return json.dumps(msg, default=self.json_serialize_default)
    return super().format(message)


Expand Down Expand Up @@ -172,29 +178,36 @@ class CloudWatchLogHandler(logging.Handler):
# extra size of meta information with each messages
EXTRA_MSG_PAYLOAD_SIZE = 26

queues: Dict[str, queue.Queue]
sequence_tokens: Dict[str, Optional[str]]
threads: List[threading.Thread]

def __init__(self,
log_group_name: str = __name__,
log_stream_name: str = "{machine_name}/{program_name}/{logger_name}",
use_queues: bool = True,
send_interval: int = 60,
send_interval: Union[int, timedelta] = 60,
max_batch_size: int = 1024 * 1024,
max_batch_count: int = 10000,
boto3_client: botocore.client.BaseClient = None,
boto3_profile_name: str = None,
boto3_client: Optional[botocore.client.BaseClient] = None,
boto3_profile_name: Optional[str] = None,
create_log_group: bool = True,
json_serialize_default: callable = None,
log_group_retention_days: int = None,
json_serialize_default: Optional[Callable[[Any], Any]] = None,
log_group_retention_days: Optional[int] = None,
create_log_stream: bool = True,
max_message_size: int = 256 * 1024,
log_group=None,
stream_name=None,
log_group: Optional[str] = None,
stream_name: Optional[str] = None,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.log_group_name = log_group_name
self.log_stream_name = log_stream_name
self.use_queues = use_queues
self.send_interval = send_interval
if isinstance(send_interval, timedelta):
self.send_interval = send_interval.total_seconds()
else:
self.send_interval = send_interval
self.json_serialize_default = json_serialize_default or _json_serialize_default
self.max_batch_size = max_batch_size
self.max_batch_count = max_batch_count
Expand Down Expand Up @@ -236,14 +249,14 @@ def __init__(self,
logGroupName=self.log_group_name,
retentionInDays=self.log_group_retention_days)

def _at_fork_reinit(self) -> None:
    """Reinitialize handler state in the child process after ``os.fork``."""
    # This was added in Python 3.9 and should only be called with a recent
    # version of Python. An older version will attempt to call createLock
    # instead.
    super()._at_fork_reinit()  # type: ignore
    self._init_state()

def _init_state(self) -> None:
    """(Re)create all mutable per-process handler state.

    Called both at lock creation and after fork, so everything here must be
    safe to rebuild from scratch.
    """
    self.queues, self.sequence_tokens = {}, {}
    self.threads = []
    self.creating_log_stream, self.shutting_down = False, False
Expand All @@ -254,7 +267,7 @@ def _paginate(self, boto3_paginator, *args, **kwargs):
for value in page.get(result_key.parsed.get("value"), []):
yield value

def _ensure_log_group(self):
def _ensure_log_group(self) -> None:
try:
paginator = self.cwl_client.get_paginator("describe_log_groups")
for log_group in self._paginate(paginator, logGroupNamePrefix=self.log_group_name):
Expand All @@ -264,7 +277,7 @@ def _ensure_log_group(self):
pass
self._idempotent_call("create_log_group", logGroupName=self.log_group_name)

def _idempotent_call(self, method, *args, **kwargs):
def _idempotent_call(self, method: str, *args, **kwargs) -> None:
method_callable = getattr(self.cwl_client, method)
try:
method_callable(*args, **kwargs)
Expand All @@ -273,10 +286,10 @@ def _idempotent_call(self, method, *args, **kwargs):
pass

@functools.lru_cache(maxsize=0)
def _get_machine_name(self) -> str:
    """Return this host's network name via ``platform.node()``.

    NOTE(review): ``maxsize=0`` disables caching entirely, so every call hits
    ``platform.node()`` — confirm whether ``maxsize=None``/``1`` was intended.
    """
    return platform.node()

def _get_stream_name(self, message):
def _get_stream_name(self, message: LogRecord) -> str:
return self.log_stream_name.format(
machine_name=self._get_machine_name(),
program_name=sys.argv[0],
Expand All @@ -285,7 +298,7 @@ def _get_stream_name(self, message):
strftime=datetime.utcnow()
)

def _submit_batch(self, batch, log_stream_name, max_retries=5):
def _submit_batch(self, batch: Sequence[dict], log_stream_name: str, max_retries: int = 5) -> None:
if len(batch) < 1:
return
sorted_batch = sorted(batch, key=itemgetter('timestamp'), reverse=False)
Expand Down Expand Up @@ -325,23 +338,23 @@ def _submit_batch(self, batch, log_stream_name, max_retries=5):
finally:
self.creating_log_stream = False
else:
warnings.warn("Failed to deliver logs: {}".format(e), WatchtowerWarning)
warnings.warn(f"Failed to deliver logs: {e}", WatchtowerWarning)
except Exception as e:
warnings.warn("Failed to deliver logs: {}".format(e), WatchtowerWarning)
warnings.warn(f"Failed to deliver logs: {e}", WatchtowerWarning)

# response can be None only when all retries have been exhausted
if response is None or "rejectedLogEventsInfo" in response:
warnings.warn("Failed to deliver logs: {}".format(response), WatchtowerWarning)
warnings.warn(f"Failed to deliver logs: {response}", WatchtowerWarning)
elif "nextSequenceToken" in response:
# According to https://github.com/kislyuk/watchtower/issues/134, nextSequenceToken may sometimes be absent
# from the response
self.sequence_tokens[log_stream_name] = response["nextSequenceToken"]

def createLock(self) -> None:
    """Create the handler I/O lock, then (re)initialize handler state.

    NOTE(review): appears to hook ``createLock`` so ``_init_state`` runs during
    ``logging.Handler`` setup (and on pre-3.9 fork re-init) — confirm.
    """
    super().createLock()
    self._init_state()

def emit(self, message):
def emit(self, message: LogRecord) -> None:
if self.creating_log_stream:
return # Avoid infinite recursion when asked to log a message as our own side effect

Expand Down Expand Up @@ -375,28 +388,32 @@ def emit(self, message):
except Exception:
self.handleError(message)

def _dequeue_batch(self, my_queue, stream_name, send_interval, max_batch_size, max_batch_count, max_message_size):
msg = None
def _dequeue_batch(self, my_queue: queue.Queue, stream_name: str, send_interval: int, max_batch_size: int,
max_batch_count: int, max_message_size: int) -> None:
msg: Union[dict, int, None] = None
max_message_body_size = max_message_size - CloudWatchLogHandler.EXTRA_MSG_PAYLOAD_SIZE
assert max_message_body_size > 0

def size(_msg: Union[dict, int]) -> int:
    # Billable size of one queued item: a dict event costs its message length,
    # while the int FLUSH/END sentinels count as 1. Every event also carries
    # CloudWatch's fixed per-event metadata overhead (EXTRA_MSG_PAYLOAD_SIZE).
    return (len(_msg["message"]) if isinstance(_msg, dict) else 1) + CloudWatchLogHandler.EXTRA_MSG_PAYLOAD_SIZE

def truncate(_msg2: dict) -> dict:
    # CloudWatch Logs rejects events over the max payload size, so warn the
    # caller and clip the message body in place to the allowed maximum.
    warnings.warn("Log message size exceeds CWL max payload size, truncated", WatchtowerWarning)
    _msg2["message"] = _msg2["message"][:max_message_body_size]
    return _msg2

# See https://boto3.readthedocs.io/en/latest/reference/services/logs.html#CloudWatchLogs.Client.put_log_events
while msg != self.END:
cur_batch = [] if msg is None or msg == self.FLUSH else [msg]
cur_batch: List[dict] = [] if msg is None or msg == self.FLUSH else [cast(dict, msg)]
cur_batch_size = sum(map(size, cur_batch))
cur_batch_msg_count = len(cur_batch)
cur_batch_deadline = time.time() + send_interval
while True:
try:
msg = my_queue.get(block=True, timeout=max(0, cur_batch_deadline - time.time()))
msg = cast(Union[dict, int],
my_queue.get(block=True, timeout=max(0, cur_batch_deadline - time.time())))
if size(msg) > max_message_body_size:
assert isinstance(msg, dict) # size always < max_message_body_size if not `dict`
msg = truncate(msg)
except queue.Empty:
# If the queue is empty, we don't want to reprocess the previous message
Expand All @@ -413,12 +430,13 @@ def truncate(_msg2):
my_queue.task_done()
break
elif msg:
assert isinstance(msg, dict) # mypy can't handle all the or expressions filtering out sentinels
cur_batch_size += size(msg)
cur_batch_msg_count += 1
cur_batch.append(msg)
my_queue.task_done()

def flush(self):
def flush(self) -> None:
"""
Send any queued messages to CloudWatch. This method does nothing if ``use_queues`` is set to False.
"""
Expand All @@ -431,7 +449,7 @@ def flush(self):
for q in self.queues.values():
q.join()

def close(self):
def close(self) -> None:
"""
Send any queued messages to CloudWatch and prevent further processing of messages.
This method does nothing if ``use_queues`` is set to False.
Expand All @@ -450,6 +468,6 @@ def close(self):
q.join()
super().close()

def __repr__(self) -> str:
    """Debug representation showing the configured log group and stream names."""
    name = self.__class__.__name__
    return f"{name}(log_group_name='{self.log_group_name}', log_stream_name='{self.log_stream_name}')"
Empty file added watchtower/py.typed
Empty file.

0 comments on commit 7e51bb0

Please sign in to comment.