Skip to content

Commit

Permalink
pass remote_ip instead of entire request object to should_be_ratelimited
Browse files Browse the repository at this point in the history
This is done mainly because it'll make unit testing the function alot easier

Also, passing the entire object is beneficial when multiple attributes are needed, which is not needed here.
  • Loading branch information
shtlrs authored and supakeen committed Mar 20, 2024
1 parent b732e87 commit 5e0376f
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions src/pinnwand/defensive.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
ratelimit_area: Dict[str, token_bucket.Limiter] = {}


def should_be_ratelimited(
request: HTTPServerRequest, area: str = "global"
) -> bool:
def should_be_ratelimited(ip_address: str, area: str = "global") -> bool:
"""Test if a requesting IP is ratelimited for a certain area. Areas are
different functionalities of the website, for example 'view' or 'input' to
differentiate between creating new pastes (low volume) or high volume
Expand All @@ -38,10 +36,10 @@ def should_be_ratelimited(
)

if not ratelimit_area[area].consume(
request.remote_ip.encode("utf-8"),
ip_address.encode("utf-8"),
configuration.ratelimit[area]["consume"],
):
log.warning("%s hit rate limit for %r", request.remote_ip, area)
log.warning("%s hit rate limit for %r", ip_address, area)
return True

return False
Expand All @@ -53,7 +51,7 @@ def ratelimit(area: str):
def wrapper(func):
@wraps(func)
def inner(request_handler: RequestHandler, *args, **kwargs):
if should_be_ratelimited(request_handler.request, area):
if should_be_ratelimited(request_handler.request.remote_ip, area):
raise error.RatelimitError()
return func(request_handler, *args, **kwargs)

Expand Down

0 comments on commit 5e0376f

Please sign in to comment.