diff --git a/src/pinnwand/handler/api_v1.py b/src/pinnwand/handler/api_v1.py index 03228b7..034d8ae 100644 --- a/src/pinnwand/handler/api_v1.py +++ b/src/pinnwand/handler/api_v1.py @@ -1,6 +1,6 @@ import json from datetime import datetime, timedelta -from typing import Any +from typing import Any, Awaitable, List, Optional from urllib.parse import urljoin import tornado.web @@ -10,25 +10,43 @@ log = logger.get_logger(__name__) +class RateLimiterMixin: + """ + A mixin that provides behavior for ratelimiting. + + The subclass must define a list of areas where the ratelimit should be applied. + + Due to MRO, the mixin needs to be the first to be inherited in the child classes. + + Raises `RatelimitError` when a ratelimit is hit for any of the areas defined in the subclass. + """ + + areas: List[str] = [] + + def prepare(self) -> Optional[Awaitable[None]]: + for area in self.areas: + if defensive.ratelimit(self.request, area=area): + raise error.RatelimitError() + return None + + class Base(tornado.web.RequestHandler): def write_error(self, status_code: int, **kwargs: Any) -> None: _, exc, _ = kwargs["exc_info"] self.write({"state": "error", "code": status_code, "message": str(exc)}) -class Lexer(Base): - async def get(self) -> None: - if defensive.ratelimit(self.request, area="read"): - raise error.RatelimitError() +class Lexer(RateLimiterMixin, Base): + areas = ["read"] + async def get(self) -> None: self.write(utility.list_languages()) -class Expiry(Base): - async def get(self) -> None: - if defensive.ratelimit(self.request, area="read"): - raise error.RatelimitError() +class Expiry(RateLimiterMixin, Base): + areas = ["read"] + async def get(self) -> None: self.write( { name: str(timedelta(seconds=delta)) @@ -37,7 +55,9 @@ async def get(self) -> None: ) -class Paste(Base): +class Paste(RateLimiterMixin, Base): + areas = ["create"] + def check_xsrf_cookie(self) -> None: return @@ -45,9 +65,6 @@ async def get(self) -> None: raise tornado.web.HTTPError(405) async def post(self) -> None: - if defensive.ratelimit(self.request, area="create"): - raise error.RatelimitError() - try: data = tornado.escape.json_decode(self.request.body) except json.decoder.JSONDecodeError: @@ -123,11 +140,10 @@ async def post(self) -> None: self.write({"link": url_paste, "removal": url_removal}) -class PasteDetail(Base): - async def get(self, slug: str) -> None: - if defensive.ratelimit(self.request, area="read"): - raise error.RatelimitError() +class PasteDetail(RateLimiterMixin, Base): + areas = ["read"] + async def get(self, slug: str) -> None: with database.session() as session: paste = ( session.query(database.Paste)