diff --git a/src/pinnwand/defensive.py b/src/pinnwand/defensive.py index 717b254..cdc4a11 100644 --- a/src/pinnwand/defensive.py +++ b/src/pinnwand/defensive.py @@ -1,6 +1,5 @@ import ipaddress import re -from collections import ChainMap from typing import Dict, Union from functools import wraps @@ -21,7 +20,9 @@ ] = {} -def ratelimit(request: HTTPServerRequest, area: str = "global") -> bool: +def should_be_ratelimited( + request: HTTPServerRequest, area: str = "global" +) -> bool: """Test if a requesting IP is ratelimited for a certain area. Areas are different functionalities of the website, for example 'view' or 'input' to differentiate between creating new pastes (low volume) or high volume @@ -52,13 +53,13 @@ def ratelimit(request: HTTPServerRequest, area: str = "global") -> bool: return False -def ratelimit_endpoint(area: str): +def ratelimit(area: str): """A ratelimiting decorator for tornado's request handlers.""" def wrapper(func): @wraps(func) def inner(request_handler: RequestHandler, *args, **kwargs): - if ratelimit(request_handler.request, area): + if should_be_ratelimited(request_handler.request, area): raise error.RatelimitError() return func(request_handler, *args, **kwargs) diff --git a/src/pinnwand/handler/api_curl.py b/src/pinnwand/handler/api_curl.py index 882322d..9a0abea 100644 --- a/src/pinnwand/handler/api_curl.py +++ b/src/pinnwand/handler/api_curl.py @@ -27,7 +27,7 @@ def write_error(self, status_code: int, **kwargs: Any) -> None: else: super().write_error(status_code, **kwargs) - @defensive.ratelimit_endpoint(area="create") + @defensive.ratelimit(area="create") def post(self) -> None: lexer = self.get_body_argument("lexer", "text") raw = self.get_body_argument("raw", "", strip=False) diff --git a/src/pinnwand/handler/api_deprecated.py b/src/pinnwand/handler/api_deprecated.py index 139ec86..4b32b43 100644 --- a/src/pinnwand/handler/api_deprecated.py +++ b/src/pinnwand/handler/api_deprecated.py @@ -73,7 +73,7 @@ async def post(self) -> None: class Show(Base): """Show a paste on the deprecated API.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, slug: str) -> None: # type: ignore with database.session() as session: paste = ( @@ -117,7 +117,7 @@ def check_xsrf_cookie(self) -> None: async def get(self) -> None: raise tornado.web.HTTPError(405) - @defensive.ratelimit_endpoint(area="create") + @defensive.ratelimit(area="create") async def post(self) -> None: lexer = self.get_body_argument("lexer") raw = self.get_body_argument("code", strip=False) @@ -170,7 +170,7 @@ def check_xsrf_cookie(self) -> None: """No XSRF cookies on the API.""" return - @defensive.ratelimit_endpoint(area="delete") + @defensive.ratelimit(area="delete") async def post(self) -> None: with database.session() as session: paste = ( @@ -202,7 +202,7 @@ async def post(self) -> None: class Lexer(Base): """List lexers through the deprecated API.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self) -> None: self.write(utility.list_languages()) @@ -210,7 +210,7 @@ async def get(self) -> None: class Expiry(Base): """List expiries through the deprecated API.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self) -> None: self.write( { diff --git a/src/pinnwand/handler/api_v1.py b/src/pinnwand/handler/api_v1.py index 2a0ae30..033f068 100644 --- a/src/pinnwand/handler/api_v1.py +++ b/src/pinnwand/handler/api_v1.py @@ -17,13 +17,13 @@ def write_error(self, status_code: int, **kwargs: Any) -> None: class Lexer(Base): - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self) -> None: self.write(utility.list_languages()) class Expiry(Base): - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self) -> None: self.write( { @@ -40,7 +40,7 @@ def check_xsrf_cookie(self) -> None: async def get(self) -> None: raise tornado.web.HTTPError(405) - @defensive.ratelimit_endpoint(area="create") + @defensive.ratelimit(area="create") async def post(self) -> None: try: data = tornado.escape.json_decode(self.request.body) @@ -118,7 +118,7 @@ async def post(self) -> None: class PasteDetail(Base): - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, slug: str) -> None: with database.session() as session: paste = ( diff --git a/src/pinnwand/handler/website.py b/src/pinnwand/handler/website.py index 08a71f2..c294fc1 100644 --- a/src/pinnwand/handler/website.py +++ b/src/pinnwand/handler/website.py @@ -82,7 +82,7 @@ class Create(Base): """The index page shows the new paste page with a list of all available lexers from Pygments.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, lexers: str = "") -> None: """Render the new paste form, optionally have a lexer preselected from the URL.""" @@ -110,7 +110,7 @@ async def get(self, lexers: str = "") -> None: paste=None, ) - @defensive.ratelimit_endpoint(area="create") + @defensive.ratelimit(area="create") async def post(self) -> None: """This is a historical endpoint to create pastes, pastes are marked as old-web and will get a warning on top of them to remove any access to @@ -171,7 +171,7 @@ class CreateAction(Base): """The create action is the 'new' way to create pastes and supports multi file pastes.""" - @defensive.ratelimit_endpoint(area="create") + @defensive.ratelimit(area="create") def post(self) -> None: # type: ignore """POST handler for the 'web' side of things.""" @@ -256,7 +256,7 @@ class Repaste(Base): """Repaste is a specific case of the paste page. It only works for pre- existing pastes and will prefill the textarea and lexer.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, slug: str) -> None: # type: ignore """Render the new paste form, optionally have a lexer preselected from the URL.""" @@ -287,7 +287,7 @@ async def get(self, slug: str) -> None: # type: ignore class Show(Base): """Show a paste.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, slug: str) -> None: # type: ignore """Fetch paste from database by slug and render the paste.""" @@ -354,7 +354,7 @@ async def get(self, slug: str) -> None: # type: ignore class FileRaw(Base): """Show a file as plaintext.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, file_id: str) -> None: # type: ignore """Get a file from the database and show it in the plain.""" @@ -385,7 +385,7 @@ async def get(self, file_id: str) -> None: # type: ignore class FileHex(Base): """Show a file as hexadecimal.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, file_id: str) -> None: # type: ignore """Get a file from the database and show it in hex.""" @@ -416,7 +416,7 @@ async def get(self, file_id: str) -> None: # type: ignore class PasteDownload(Base): """Download an entire paste.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, paste_id: str) -> None: # type: ignore """Get all files from the database and download them as a zipfile.""" @@ -463,7 +463,7 @@ async def get(self, paste_id: str) -> None: # type: ignore class FileDownload(Base): """Download a file.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, file_id: str) -> None: # type: ignore """Get a file from the database and download it in the plain.""" @@ -505,7 +505,7 @@ async def get(self, file_id: str) -> None: # type: ignore class Remove(Base): """Remove a paste.""" - @defensive.ratelimit_endpoint(area="delete") + @defensive.ratelimit(area="delete") async def get(self, removal: str) -> None: # type: ignore """Look up if the user visiting this page has the removal id for a certain paste. If they do they're authorized to remove the paste.""" @@ -543,7 +543,7 @@ class RestructuredTextPage(Base): def initialize(self, file: str) -> None: self.file = file - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self) -> None: try: with open(path.page / self.file) as f: