add more comprehensive typing and check with mypy
terencehonles committed Apr 30, 2021
1 parent adab2ce commit db5c90d
Showing 3 changed files with 74 additions and 20 deletions.
9 changes: 8 additions & 1 deletion Makefile
@@ -1,13 +1,20 @@
 SHELL=/bin/bash

 test_deps:
-	pip install coverage flake8 wheel pyyaml boto3
+	pip install \
+		boto3 \
+		coverage \
+		flake8 \
+		mypy \
+		pyyaml \
+		wheel

 lint: test_deps
 	flake8

 test: test_deps lint
 	coverage run --source=watchtower ./test/test.py
+	mypy watchtower

 docs:
 	sphinx-build docs docs/html
17 changes: 17 additions & 0 deletions setup.cfg
@@ -1,3 +1,20 @@
 [flake8]
 max-line-length=120
 ignore: E301, E401
+
+[mypy]
+pretty = true
+show_error_codes = true
+show_error_context = true
+warn_redundant_casts = true
+warn_unused_ignores = true
+warn_unreachable = true
+
+[mypy-boto3.*]
+ignore_missing_imports = true
+
+[mypy-botocore.*]
+ignore_missing_imports = true
+
+[mypy-django.*]
+ignore_missing_imports = true
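
The three `[mypy-*]` override sections exist because boto3, botocore, and django ship without type stubs, and mypy reports an error for every import of a package it cannot find types for. A minimal sketch (hypothetical module, not part of this commit) of the situation the overrides address:

# hypothetical_module.py -- illustrates the problem the [mypy-boto3.*] override solves
import boto3  # without ignore_missing_imports, mypy errors on this line,
              # since boto3 ships neither inline types nor a stubs package

def make_log_client() -> object:
    # With the override, boto3 is treated as Any, so the call type-checks
    # even though mypy knows nothing about boto3's real API surface.
    return boto3.client("logs")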
68 changes: 49 additions & 19 deletions watchtower/__init__.py
@@ -1,7 +1,8 @@
 from collections.abc import Mapping
 from datetime import date, datetime, timedelta
+from logging import LogRecord
 from operator import itemgetter
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Union, TYPE_CHECKING
 import json, logging, time, threading, warnings
 import queue

@@ -10,23 +11,46 @@
 from botocore.exceptions import ClientError


-def _idempotent_create(client, method, *args, **kwargs):
+if TYPE_CHECKING:
+    from typing_extensions import Protocol
+
+    class LogClient(Protocol):
+        exceptions: Any
+
+        def create_log_group(self, logGroupName: str) -> None:
+            ...
+
+        def create_log_stream(self, logGroupName: str, logStreamName: str) -> None:
+            ...
+
+        def put_log_events(self, **kwargs: Any) -> dict:
+            ...
+
+        def put_retention_policy(self, logGroupName: str, retentionInDays: Optional[int]) -> None:
+            ...
+
+    class Session(Protocol):
+        def client(self, service: str, endpoint_url: Optional[str] = None) -> LogClient:
+            ...
+
+
+def _idempotent_create(client: LogClient, method: str, *args: Any, **kwargs: Any) -> Any:
     method_callable = getattr(client, method)
     try:
         method_callable(*args, **kwargs)
     except (client.exceptions.OperationAbortedException, client.exceptions.ResourceAlreadyExistsException):
         pass


-def _json_serialize_default(o):
+def _json_serialize_default(o: Any) -> Any:
     """
     A standard 'default' json serializer function that will serialize datetime objects as ISO format.
     """
     if isinstance(o, (date, datetime)):
         return o.isoformat()


-def _boto_debug_filter(record):
+def _boto_debug_filter(record: LogRecord) -> bool:
     # Filter debug log messages from botocore and its dependency, urllib3.
     # This is required to avoid message storms any time we send logs.
     if record.name.startswith("botocore") and record.levelname == "DEBUG":
@@ -36,7 +60,7 @@ def _boto_debug_filter(record):
     return True


-def _boto_filter(record):
+def _boto_filter(record: LogRecord) -> bool:
     # Filter log messages from botocore and its dependency, urllib3.
     # This is required to avoid an infinite loop when shutting down.
     if record.name.startswith("botocore"):
@@ -102,7 +126,7 @@ class CloudWatchLogHandler(logging.Handler):
     EXTRA_MSG_PAYLOAD_SIZE = 26

     @staticmethod
-    def _get_session(boto3_session, boto3_profile_name):
+    def _get_session(boto3_session: Optional[Session], boto3_profile_name: Optional[str]) -> Session:
         if boto3_session:
             return boto3_session

@@ -130,8 +154,9 @@ def __init__(self, log_group: str = __name__, stream_name: Optional[str] = None,
         self.max_batch_size = max_batch_size
         self.max_batch_count = max_batch_count
         self.max_message_size = max_message_size
-        self.queues, self.sequence_tokens = {}, {}
-        self.threads = []
+        self.queues: Dict[str, queue.Queue] = {}
+        self.sequence_tokens: Dict[str, Optional[str]] = {}
+        self.threads: List[threading.Thread] = []
         self.creating_log_stream, self.shutting_down = False, False
         self.create_log_stream = create_log_stream
         self.log_group_retention_days = log_group_retention_days
@@ -150,7 +175,7 @@ def __init__(self, log_group: str = __name__, stream_name: Optional[str] = None,

         self.addFilter(_boto_debug_filter)

-    def _submit_batch(self, batch, stream_name, max_retries=5):
+    def _submit_batch(self, batch: Sequence[dict], stream_name: str, max_retries: int = 5) -> None:
         if len(batch) < 1:
             return
         sorted_batch = sorted(batch, key=itemgetter('timestamp'), reverse=False)
@@ -203,7 +228,7 @@ def _submit_batch(self, batch, stream_name, max_retries=5):
             # from the response
             self.sequence_tokens[stream_name] = response["nextSequenceToken"]

-    def emit(self, message):
+    def emit(self, message: LogRecord) -> None:
         if self.creating_log_stream:
             return  # Avoid infinite recursion when asked to log a message as our own side effect
         stream_name = self.stream_name
@@ -214,7 +239,7 @@ def emit(self, message):
         if stream_name not in self.sequence_tokens:
             self.sequence_tokens[stream_name] = None

-        if isinstance(message.msg, Mapping):
+        if isinstance(cast(Union[dict, str], message.msg), Mapping):
             message.msg = json.dumps(message.msg, default=self.json_serialize_default)

         cwl_message = dict(timestamp=int(message.created * 1000), message=self.format(message))
@@ -235,28 +260,32 @@ def emit(self, message):
         else:
             self._submit_batch([cwl_message], stream_name)

-    def batch_sender(self, my_queue, stream_name, send_interval, max_batch_size, max_batch_count, max_message_size):
-        msg = None
+    def batch_sender(self, my_queue: queue.Queue, stream_name: str, send_interval: int, max_batch_size: int,
+                     max_batch_count: int, max_message_size: int) -> None:
+        msg: Union[dict, int, None] = None
         max_message_body_size = max_message_size - CloudWatchLogHandler.EXTRA_MSG_PAYLOAD_SIZE
         assert max_message_body_size > 0

-        def size(_msg):
+        def size(_msg: Union[dict, int]) -> int:
             return (len(_msg["message"]) if isinstance(_msg, dict) else 1) + CloudWatchLogHandler.EXTRA_MSG_PAYLOAD_SIZE

-        def truncate(_msg2):
+        def truncate(_msg2: dict) -> dict:
             warnings.warn("Log message size exceeds CWL max payload size, truncated", WatchtowerWarning)
             _msg2["message"] = _msg2["message"][:max_message_body_size]
             return _msg2

         # See https://boto3.readthedocs.io/en/latest/reference/services/logs.html#CloudWatchLogs.Client.put_log_events
         while msg != self.END:
-            cur_batch = [] if msg is None or msg == self.FLUSH else [msg]
+            cur_batch: List[dict] = [] if msg is None or msg == self.FLUSH else [cast(dict, msg)]
             cur_batch_size = sum(map(size, cur_batch))
             cur_batch_msg_count = len(cur_batch)
             cur_batch_deadline = time.time() + send_interval
             while True:
                 try:
-                    msg = my_queue.get(block=True, timeout=max(0, cur_batch_deadline - time.time()))
+                    msg = cast(Union[dict, int],
+                               my_queue.get(block=True, timeout=max(0, cur_batch_deadline - time.time())))
                     if size(msg) > max_message_body_size:
+                        assert isinstance(msg, dict)  # size always < max_message_body_size if not `dict`
                         msg = truncate(msg)
                 except queue.Empty:
                     # If the queue is empty, we don't want to reprocess the previous message
@@ -273,12 +302,13 @@ def truncate(_msg2):
                     my_queue.task_done()
                     break
                 elif msg:
+                    assert isinstance(msg, dict)  # mypy can't handle all the or expressions filtering out sentinels
                     cur_batch_size += size(msg)
                     cur_batch_msg_count += 1
                     cur_batch.append(msg)
                     my_queue.task_done()

-    def flush(self):
+    def flush(self) -> None:
         """
         Send any queued messages to CloudWatch. This method does nothing if ``use_queues`` is set to False.
         """
@@ -291,7 +321,7 @@ def flush(self):
         for q in self.queues.values():
             q.join()

-    def close(self):
+    def close(self) -> None:
         """
         Send any queued messages to CloudWatch and prevent further processing of messages.
         This method does nothing if ``use_queues`` is set to False.
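
A note on the `LogClient` and `Session` protocols added above: boto3 builds its clients dynamically at runtime, so mypy cannot see their methods; the protocols declare just the surface watchtower calls. A standalone sketch (hypothetical `FakeLogClient`, not from the commit) of how Protocol matching works structurally:

from typing import Any

try:
    from typing import Protocol  # Python 3.8+
except ImportError:
    from typing_extensions import Protocol  # older Pythons, as the commit does under TYPE_CHECKING

class LogClient(Protocol):
    exceptions: Any

    def create_log_group(self, logGroupName: str) -> None:
        ...

def ensure_group(client: LogClient, name: str) -> None:
    # Protocols match on structure, not inheritance: any object with a
    # compatible create_log_group() is accepted here.
    client.create_log_group(logGroupName=name)

class FakeLogClient:
    exceptions: Any = None

    def create_log_group(self, logGroupName: str) -> None:
        print(f"would create log group {logGroupName}")

ensure_group(FakeLogClient(), "my-group")  # type-checks without subclassing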
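
Similarly, `batch_sender` mixes dict log records with the `self.FLUSH`/`self.END` sentinels on a single queue, which is why `msg` is annotated `Union[dict, int, None]` and narrowed with `cast()` and `assert isinstance()`. A condensed sketch of the same pattern (hypothetical names; int sentinels assumed, matching that annotation):

import queue
from typing import Dict, Union, cast

FLUSH = 1  # stand-ins for the handler's sentinel values
END = 2

Item = Union[Dict[str, str], int]

def drain(q: "queue.Queue[Item]") -> None:
    while True:
        item = q.get()
        if item == END:
            break
        if item == FLUSH:
            continue
        # Equality checks against sentinels don't narrow a Union for mypy,
        # so, as in the commit, a cast (or assert isinstance) is needed to
        # tell it only dicts reach this point.
        msg = cast(Dict[str, str], item)
        print(msg["message"])

q: "queue.Queue[Item]" = queue.Queue()
q.put({"message": "hello"})
q.put(END)
drain(q)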