Skip to content

Commit

Permalink
port ratelimiting logic to its own class
Browse files Browse the repository at this point in the history
  • Loading branch information
shtlrs committed Jan 8, 2024
1 parent 3c866b4 commit 41f966d
Showing 1 changed file with 33 additions and 17 deletions.
50 changes: 33 additions & 17 deletions src/pinnwand/handler/api_v1.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from datetime import datetime, timedelta
from typing import Any
from typing import Any, Awaitable, List, Optional
from urllib.parse import urljoin

import tornado.web
Expand All @@ -10,25 +10,43 @@
log = logger.get_logger(__name__)


class RateLimiterMixin:
"""
A mixin that provides behavior for ratelimiting.
The subclass must define a list of areas where the ratelimit should be applied.
Due to MRO, the mixin needs to be the first to be inherited in the child classes.
Raises `RatelimitError` when a ratelimit is hit for any of the areas defined in the subclass.
"""

areas: List[str] = []

def prepare(self) -> Optional[Awaitable[None]]:
for area in self.areas:
if defensive.ratelimit(self.request, area=area):
raise error.RatelimitError()
return None


class Base(tornado.web.RequestHandler):
def write_error(self, status_code: int, **kwargs: Any) -> None:
_, exc, _ = kwargs["exc_info"]
self.write({"state": "error", "code": status_code, "message": str(exc)})


class Lexer(Base):
async def get(self) -> None:
if defensive.ratelimit(self.request, area="read"):
raise error.RatelimitError()
class Lexer(RateLimiterMixin, Base):
areas = ["read"]

async def get(self) -> None:
self.write(utility.list_languages())


class Expiry(Base):
async def get(self) -> None:
if defensive.ratelimit(self.request, area="read"):
raise error.RatelimitError()
class Expiry(RateLimiterMixin, Base):
areas = ["read"]

async def get(self) -> None:
self.write(
{
name: str(timedelta(seconds=delta))
Expand All @@ -37,17 +55,16 @@ async def get(self) -> None:
)


class Paste(Base):
class Paste(RateLimiterMixin, Base):
areas = ["create"]

def check_xsrf_cookie(self) -> None:
return

async def get(self) -> None:
raise tornado.web.HTTPError(405)

async def post(self) -> None:
if defensive.ratelimit(self.request, area="create"):
raise error.RatelimitError()

try:
data = tornado.escape.json_decode(self.request.body)
except json.decoder.JSONDecodeError:
Expand Down Expand Up @@ -123,11 +140,10 @@ async def post(self) -> None:
self.write({"link": url_paste, "removal": url_removal})


class PasteDetail(Base):
async def get(self, slug: str) -> None:
if defensive.ratelimit(self.request, area="read"):
raise error.RatelimitError()
class PasteDetail(RateLimiterMixin, Base):
areas = ["read"]

async def get(self, slug: str) -> None:
with database.session() as session:
paste = (
session.query(database.Paste)
Expand Down

0 comments on commit 41f966d

Please sign in to comment.