diff --git a/src/pinnwand/defensive.py b/src/pinnwand/defensive.py index df412f3..4a5ebd5 100644 --- a/src/pinnwand/defensive.py +++ b/src/pinnwand/defensive.py @@ -15,9 +15,7 @@ ratelimit_area: Dict[str, token_bucket.Limiter] = {} -def should_be_ratelimited( - request: HTTPServerRequest, area: str = "global" -) -> bool: +def should_be_ratelimited(ip_address: str, area: str = "global") -> bool: """Test if a requesting IP is ratelimited for a certain area. Areas are different functionalities of the website, for example 'view' or 'input' to differentiate between creating new pastes (low volume) or high volume @@ -38,10 +36,10 @@ def should_be_ratelimited( ) if not ratelimit_area[area].consume( - request.remote_ip.encode("utf-8"), + ip_address.encode("utf-8"), configuration.ratelimit[area]["consume"], ): - log.warning("%s hit rate limit for %r", request.remote_ip, area) + log.warning("%s hit rate limit for %r", ip_address, area) return True return False @@ -53,7 +51,7 @@ def ratelimit(area: str): def wrapper(func): @wraps(func) def inner(request_handler: RequestHandler, *args, **kwargs): - if should_be_ratelimited(request_handler.request, area): + if should_be_ratelimited(request_handler.request.remote_ip, area): raise error.RatelimitError() return func(request_handler, *args, **kwargs)