From 5e0376f6d3922b832ee1f371d9af88a2d6a517b1 Mon Sep 17 00:00:00 2001 From: shtlrs Date: Wed, 20 Mar 2024 13:42:40 +0100 Subject: [PATCH] pass remote_ip instead of entire request object to should_be_ratelimited This is done mainly because it'll make unit testing the function alot easier Also, passing the entire object is beneficial when multiple attributes are needed, which is not needed here. --- src/pinnwand/defensive.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/pinnwand/defensive.py b/src/pinnwand/defensive.py index df412f3..4a5ebd5 100644 --- a/src/pinnwand/defensive.py +++ b/src/pinnwand/defensive.py @@ -15,9 +15,7 @@ ratelimit_area: Dict[str, token_bucket.Limiter] = {} -def should_be_ratelimited( - request: HTTPServerRequest, area: str = "global" -) -> bool: +def should_be_ratelimited(ip_address: str, area: str = "global") -> bool: """Test if a requesting IP is ratelimited for a certain area. Areas are different functionalities of the website, for example 'view' or 'input' to differentiate between creating new pastes (low volume) or high volume @@ -38,10 +36,10 @@ def should_be_ratelimited( ) if not ratelimit_area[area].consume( - request.remote_ip.encode("utf-8"), + ip_address.encode("utf-8"), configuration.ratelimit[area]["consume"], ): - log.warning("%s hit rate limit for %r", request.remote_ip, area) + log.warning("%s hit rate limit for %r", ip_address, area) return True return False @@ -53,7 +51,7 @@ def ratelimit(area: str): def wrapper(func): @wraps(func) def inner(request_handler: RequestHandler, *args, **kwargs): - if should_be_ratelimited(request_handler.request, area): + if should_be_ratelimited(request_handler.request.remote_ip, area): raise error.RatelimitError() return func(request_handler, *args, **kwargs)