Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add more comprehensive typing and check with mypy #144

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion MANIFEST.in

This file was deleted.

1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ lint: test_deps

test: test_deps lint
coverage run --source=$$(python setup.py --name) ./test/test.py -v
mypy watchtower

docs:
sphinx-build docs docs/html
Expand Down
17 changes: 17 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
[flake8]
max-line-length=120
ignore: E301, E401

[mypy]
pretty = true
show_error_codes = true
show_error_context = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_unreachable = true

[mypy-boto3.*]
ignore_missing_imports = true

[mypy-botocore.*]
ignore_missing_imports = true

[mypy-django.*]
ignore_missing_imports = true
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
]
},
packages=find_packages(exclude=["test"]),
package_data={
"watchtower": ["py.typed"],
},
platforms=["MacOS X", "Posix"],
include_package_data=True,
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
Expand Down
107 changes: 62 additions & 45 deletions watchtower/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from collections.abc import Mapping
from datetime import date, datetime
from collections.abc import Mapping, MutableMapping
from datetime import date, datetime, timedelta
from logging import LogRecord
from operator import itemgetter
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
import sys, json, logging, time, threading, warnings, functools, platform
import queue

Expand All @@ -9,15 +11,15 @@
from botocore.exceptions import ClientError


def _json_serialize_default(o):
def _json_serialize_default(o: Any) -> Any:
"""
A standard 'default' json serializer function that will serialize datetime objects as ISO format.
"""
if isinstance(o, (date, datetime)):
return o.isoformat()


def _boto_debug_filter(record):
def _boto_debug_filter(record: LogRecord) -> bool:
# Filter debug log messages from botocore and its dependency, urllib3.
# This is required to avoid message storms any time we send logs.
if record.name.startswith("botocore") and record.levelname == "DEBUG":
Expand All @@ -27,7 +29,7 @@ def _boto_debug_filter(record):
return True


def _boto_filter(record):
def _boto_filter(record: LogRecord) -> bool:
# Filter log messages from botocore and its dependency, urllib3.
# This is required to avoid an infinite loop when shutting down.
if record.name.startswith("botocore"):
Expand Down Expand Up @@ -90,25 +92,28 @@ class CloudWatchLogFormatter(logging.Formatter):
for more details about the 'default' parameter. By default, watchtower uses a serializer that formats datetime
objects into strings using the `datetime.isoformat()` method, with no other customizations.
"""
add_log_record_attrs = tuple()
add_log_record_attrs: Tuple[str, ...] = ()

def __init__(self,
             *args,
             json_serialize_default: Optional[Callable[[Any], Any]] = None,
             add_log_record_attrs: Optional[Tuple[str, ...]] = None,
             **kwargs):
    """Configure the formatter.

    :param json_serialize_default: serializer passed as ``default`` to
        ``json.dumps``; when not given, the module-level ISO-date serializer
        is used instead.
    :param add_log_record_attrs: names of LogRecord attributes to merge into
        the emitted message; when omitted, the class-level default is kept.
    """
    super().__init__(*args, **kwargs)
    # Truthiness check (not an `is None` test) matches the short-circuit
    # fallback semantics of `json_serialize_default or _json_serialize_default`.
    if not json_serialize_default:
        json_serialize_default = _json_serialize_default
    self.json_serialize_default = json_serialize_default
    if add_log_record_attrs is not None:
        self.add_log_record_attrs = add_log_record_attrs

def format(self, message: LogRecord) -> str:
    """Render a LogRecord for CloudWatch.

    Mapping messages (and, once ``add_log_record_attrs`` is set, any message)
    are serialized as JSON using ``self.json_serialize_default``; plain string
    messages fall through to ``logging.Formatter.format``.
    """
    msg: Union[str, MutableMapping] = message.msg
    if self.add_log_record_attrs:
        # Copy before inserting record attributes: the incoming mapping may be
        # read-only, and mutating it in place would also leak the added fields
        # into every other handler/formatter that sees this shared record.
        msg = dict(msg) if isinstance(msg, Mapping) else {"msg": msg}
        for field in self.add_log_record_attrs:
            if field != "msg":
                msg[field] = getattr(message, field)
        message.msg = msg
    if isinstance(msg, Mapping):
        return json.dumps(msg, default=self.json_serialize_default)
    return super().format(message)


Expand Down Expand Up @@ -172,29 +177,36 @@ class CloudWatchLogHandler(logging.Handler):
# extra size of meta information with each messages
EXTRA_MSG_PAYLOAD_SIZE = 26

queues: Dict[str, queue.Queue]
sequence_tokens: Dict[str, Optional[str]]
threads: List[threading.Thread]

def __init__(self,
log_group_name: str = __name__,
log_stream_name: str = "{machine_name}/{program_name}/{logger_name}",
use_queues: bool = True,
send_interval: int = 60,
send_interval: Union[int, timedelta] = 60,
max_batch_size: int = 1024 * 1024,
max_batch_count: int = 10000,
boto3_client: botocore.client.BaseClient = None,
boto3_profile_name: str = None,
boto3_client: Optional[botocore.client.BaseClient] = None,
boto3_profile_name: Optional[str] = None,
create_log_group: bool = True,
json_serialize_default: callable = None,
log_group_retention_days: int = None,
json_serialize_default: Optional[Callable[[Any], Any]] = None,
log_group_retention_days: Optional[int] = None,
create_log_stream: bool = True,
max_message_size: int = 256 * 1024,
log_group=None,
stream_name=None,
log_group: Optional[str] = None,
stream_name: Optional[str] = None,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.log_group_name = log_group_name
self.log_stream_name = log_stream_name
self.use_queues = use_queues
self.send_interval = send_interval
if isinstance(send_interval, timedelta):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was updated to match the text, but the :type: was previously just int

self.send_interval = send_interval.total_seconds()
else:
self.send_interval = send_interval
self.json_serialize_default = json_serialize_default or _json_serialize_default
self.max_batch_size = max_batch_size
self.max_batch_count = max_batch_count
Expand Down Expand Up @@ -236,14 +248,14 @@ def __init__(self,
logGroupName=self.log_group_name,
retentionInDays=self.log_group_retention_days)

def _at_fork_reinit(self) -> None:
    """Reinitialize handler state in the child process after a fork.

    logging.Handler gained _at_fork_reinit in Python 3.9; older interpreters
    fall back to calling createLock instead, so this hook only runs on 3.9+.
    """
    super()._at_fork_reinit()  # type: ignore
    self._init_state()

def _init_state(self):
def _init_state(self) -> None:
    # Fresh containers: one message queue and one CWL sequence token per log
    # stream name, plus the worker threads that drain those queues.
    self.queues, self.sequence_tokens = {}, {}
    self.threads = []
    # Reentrancy/shutdown flags: creating_log_stream suppresses our own log
    # side effects while a stream is being created (see emit); shutting_down
    # presumably gates further processing during close — confirm in close().
    self.creating_log_stream, self.shutting_down = False, False
Expand All @@ -254,7 +266,7 @@ def _paginate(self, boto3_paginator, *args, **kwargs):
for value in page.get(result_key.parsed.get("value"), []):
yield value

def _ensure_log_group(self):
def _ensure_log_group(self) -> None:
try:
paginator = self.cwl_client.get_paginator("describe_log_groups")
for log_group in self._paginate(paginator, logGroupNamePrefix=self.log_group_name):
Expand All @@ -264,7 +276,7 @@ def _ensure_log_group(self):
pass
self._idempotent_call("create_log_group", logGroupName=self.log_group_name)

def _idempotent_call(self, method, *args, **kwargs):
def _idempotent_call(self, method: str, *args, **kwargs) -> None:
method_callable = getattr(self.cwl_client, method)
try:
method_callable(*args, **kwargs)
Expand All @@ -273,10 +285,10 @@ def _idempotent_call(self, method, *args, **kwargs):
pass

@functools.lru_cache(maxsize=0)
def _get_machine_name(self):
def _get_machine_name(self) -> str:
return platform.node()

def _get_stream_name(self, message):
def _get_stream_name(self, message: LogRecord) -> str:
return self.log_stream_name.format(
machine_name=self._get_machine_name(),
program_name=sys.argv[0],
Expand All @@ -285,7 +297,7 @@ def _get_stream_name(self, message):
strftime=datetime.utcnow()
)

def _submit_batch(self, batch, log_stream_name, max_retries=5):
def _submit_batch(self, batch: Sequence[dict], log_stream_name: str, max_retries: int = 5) -> None:
if len(batch) < 1:
return
sorted_batch = sorted(batch, key=itemgetter('timestamp'), reverse=False)
Expand Down Expand Up @@ -325,23 +337,23 @@ def _submit_batch(self, batch, log_stream_name, max_retries=5):
finally:
self.creating_log_stream = False
else:
warnings.warn("Failed to deliver logs: {}".format(e), WatchtowerWarning)
warnings.warn(f"Failed to deliver logs: {e}", WatchtowerWarning)
except Exception as e:
warnings.warn("Failed to deliver logs: {}".format(e), WatchtowerWarning)
warnings.warn(f"Failed to deliver logs: {e}", WatchtowerWarning)

# response can be None only when all retries have been exhausted
if response is None or "rejectedLogEventsInfo" in response:
warnings.warn("Failed to deliver logs: {}".format(response), WatchtowerWarning)
warnings.warn(f"Failed to deliver logs: {response}", WatchtowerWarning)
elif "nextSequenceToken" in response:
# According to https://github.com/kislyuk/watchtower/issues/134, nextSequenceToken may sometimes be absent
# from the response
self.sequence_tokens[log_stream_name] = response["nextSequenceToken"]

def createLock(self):
def createLock(self) -> None:
    """Create the handler's I/O lock and reset per-instance queue/thread state."""
    super().createLock()
    self._init_state()

def emit(self, message):
def emit(self, message: LogRecord) -> None:
if self.creating_log_stream:
return # Avoid infinite recursion when asked to log a message as our own side effect

Expand Down Expand Up @@ -375,28 +387,32 @@ def emit(self, message):
except Exception:
self.handleError(message)

def _dequeue_batch(self, my_queue, stream_name, send_interval, max_batch_size, max_batch_count, max_message_size):
msg = None
def _dequeue_batch(self, my_queue: queue.Queue, stream_name: str, send_interval: int, max_batch_size: int,
max_batch_count: int, max_message_size: int) -> None:
msg: Union[dict, int, None] = None
max_message_body_size = max_message_size - CloudWatchLogHandler.EXTRA_MSG_PAYLOAD_SIZE
assert max_message_body_size > 0

def size(_msg):
def size(_msg: Union[dict, int]) -> int:
return (len(_msg["message"]) if isinstance(_msg, dict) else 1) + CloudWatchLogHandler.EXTRA_MSG_PAYLOAD_SIZE

def truncate(_msg2):
def truncate(_msg2: dict) -> dict:
warnings.warn("Log message size exceeds CWL max payload size, truncated", WatchtowerWarning)
_msg2["message"] = _msg2["message"][:max_message_body_size]
return _msg2

# See https://boto3.readthedocs.io/en/latest/reference/services/logs.html#CloudWatchLogs.Client.put_log_events
while msg != self.END:
cur_batch = [] if msg is None or msg == self.FLUSH else [msg]
cur_batch: List[dict] = [] if msg is None or msg == self.FLUSH else [cast(dict, msg)]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks almost all good to go, except let's please remove all of the casts and asserts. The typing information should decorate the function signatures, not be part of the execution. If mypy fails to check this method without calls to cast and asserts, then we should disable mypy checking for those lines, or possibly split up the method into multiple methods.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll look at the PR to see what makes sense. In case you aren't familiar, cast is just a noop at runtime so while it's calling more code it shouldn't be that big of a deal. It does enrich mypy with more information and I'm not sure if mypy treats a type: ignore equally, but I can double check. For the asserts I believe I have generally used them to narrow optional variables, but I will review the PR since it's been awhile since I looked at this code.

cur_batch_size = sum(map(size, cur_batch))
cur_batch_msg_count = len(cur_batch)
cur_batch_deadline = time.time() + send_interval
while True:
try:
msg = my_queue.get(block=True, timeout=max(0, cur_batch_deadline - time.time()))
msg = cast(Union[dict, int],
my_queue.get(block=True, timeout=max(0, cur_batch_deadline - time.time())))
if size(msg) > max_message_body_size:
assert isinstance(msg, dict) # size always < max_message_body_size if not `dict`
msg = truncate(msg)
except queue.Empty:
# If the queue is empty, we don't want to reprocess the previous message
Expand All @@ -413,12 +429,13 @@ def truncate(_msg2):
my_queue.task_done()
break
elif msg:
assert isinstance(msg, dict) # mypy can't handle all the or expressions filtering out sentinels
cur_batch_size += size(msg)
cur_batch_msg_count += 1
cur_batch.append(msg)
my_queue.task_done()

def flush(self):
def flush(self) -> None:
"""
Send any queued messages to CloudWatch. This method does nothing if ``use_queues`` is set to False.
"""
Expand All @@ -431,7 +448,7 @@ def flush(self):
for q in self.queues.values():
q.join()

def close(self):
def close(self) -> None:
"""
Send any queued messages to CloudWatch and prevent further processing of messages.
This method does nothing if ``use_queues`` is set to False.
Expand All @@ -450,6 +467,6 @@ def close(self):
q.join()
super().close()

def __repr__(self):
def __repr__(self) -> str:
    """Summarize the handler by its target log group and stream name pattern."""
    cls_name = type(self).__name__
    return ("{}(log_group_name='{}', log_stream_name='{}')"
            .format(cls_name, self.log_group_name, self.log_stream_name))
Empty file added watchtower/py.typed
Empty file.