Skip to content

Commit

Permalink
rename the ratelimiter functions
Browse files Browse the repository at this point in the history
This is to improve the decorator's api better.
  • Loading branch information
shtlrs committed Jan 20, 2024
1 parent 4f73294 commit 989960c
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 25 deletions.
9 changes: 5 additions & 4 deletions src/pinnwand/defensive.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import ipaddress
import re
from collections import ChainMap
from typing import Dict, Union
from functools import wraps

Expand All @@ -21,7 +20,9 @@
] = {}


def ratelimit(request: HTTPServerRequest, area: str = "global") -> bool:
def should_be_ratelimited(
request: HTTPServerRequest, area: str = "global"
) -> bool:
"""Test if a requesting IP is ratelimited for a certain area. Areas are
different functionalities of the website, for example 'view' or 'input' to
differentiate between creating new pastes (low volume) or high volume
Expand Down Expand Up @@ -52,13 +53,13 @@ def ratelimit(request: HTTPServerRequest, area: str = "global") -> bool:
return False


def ratelimit_endpoint(area: str):
def ratelimit(area: str):
"""A ratelimiting decorator for tornado's request handlers."""

def wrapper(func):
@wraps(func)
def inner(request_handler: RequestHandler, *args, **kwargs):
if ratelimit(request_handler.request, area):
if should_be_ratelimited(request_handler.request, area):
raise error.RatelimitError()
return func(request_handler, *args, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion src/pinnwand/handler/api_curl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def write_error(self, status_code: int, **kwargs: Any) -> None:
else:
super().write_error(status_code, **kwargs)

@defensive.ratelimit_endpoint(area="create")
@defensive.ratelimit(area="create")
def post(self) -> None:
lexer = self.get_body_argument("lexer", "text")
raw = self.get_body_argument("raw", "", strip=False)
Expand Down
10 changes: 5 additions & 5 deletions src/pinnwand/handler/api_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def post(self) -> None:
class Show(Base):
"""Show a paste on the deprecated API."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, slug: str) -> None: # type: ignore
with database.session() as session:
paste = (
Expand Down Expand Up @@ -117,7 +117,7 @@ def check_xsrf_cookie(self) -> None:
async def get(self) -> None:
raise tornado.web.HTTPError(405)

@defensive.ratelimit_endpoint(area="create")
@defensive.ratelimit(area="create")
async def post(self) -> None:
lexer = self.get_body_argument("lexer")
raw = self.get_body_argument("code", strip=False)
Expand Down Expand Up @@ -170,7 +170,7 @@ def check_xsrf_cookie(self) -> None:
"""No XSRF cookies on the API."""
return

@defensive.ratelimit_endpoint(area="delete")
@defensive.ratelimit(area="delete")
async def post(self) -> None:
with database.session() as session:
paste = (
Expand Down Expand Up @@ -202,15 +202,15 @@ async def post(self) -> None:
class Lexer(Base):
"""List lexers through the deprecated API."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self) -> None:
self.write(utility.list_languages())


class Expiry(Base):
"""List expiries through the deprecated API."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self) -> None:
self.write(
{
Expand Down
8 changes: 4 additions & 4 deletions src/pinnwand/handler/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ def write_error(self, status_code: int, **kwargs: Any) -> None:


class Lexer(Base):
@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self) -> None:
self.write(utility.list_languages())


class Expiry(Base):
@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self) -> None:
self.write(
{
Expand All @@ -40,7 +40,7 @@ def check_xsrf_cookie(self) -> None:
async def get(self) -> None:
raise tornado.web.HTTPError(405)

@defensive.ratelimit_endpoint(area="create")
@defensive.ratelimit(area="create")
async def post(self) -> None:
try:
data = tornado.escape.json_decode(self.request.body)
Expand Down Expand Up @@ -118,7 +118,7 @@ async def post(self) -> None:


class PasteDetail(Base):
@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, slug: str) -> None:
with database.session() as session:
paste = (
Expand Down
22 changes: 11 additions & 11 deletions src/pinnwand/handler/website.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class Create(Base):
"""The index page shows the new paste page with a list of all available
lexers from Pygments."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, lexers: str = "") -> None:
"""Render the new paste form, optionally have a lexer preselected from
the URL."""
Expand Down Expand Up @@ -110,7 +110,7 @@ async def get(self, lexers: str = "") -> None:
paste=None,
)

@defensive.ratelimit_endpoint(area="create")
@defensive.ratelimit(area="create")
async def post(self) -> None:
"""This is a historical endpoint to create pastes, pastes are marked as
old-web and will get a warning on top of them to remove any access to
Expand Down Expand Up @@ -171,7 +171,7 @@ class CreateAction(Base):
"""The create action is the 'new' way to create pastes and supports multi
file pastes."""

@defensive.ratelimit_endpoint(area="create")
@defensive.ratelimit(area="create")
def post(self) -> None: # type: ignore
"""POST handler for the 'web' side of things."""

Expand Down Expand Up @@ -256,7 +256,7 @@ class Repaste(Base):
"""Repaste is a specific case of the paste page. It only works for pre-
existing pastes and will prefill the textarea and lexer."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, slug: str) -> None: # type: ignore
"""Render the new paste form, optionally have a lexer preselected from
the URL."""
Expand Down Expand Up @@ -287,7 +287,7 @@ async def get(self, slug: str) -> None: # type: ignore
class Show(Base):
"""Show a paste."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, slug: str) -> None: # type: ignore
"""Fetch paste from database by slug and render the paste."""

Expand Down Expand Up @@ -354,7 +354,7 @@ async def get(self, slug: str) -> None: # type: ignore
class FileRaw(Base):
"""Show a file as plaintext."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, file_id: str) -> None: # type: ignore
"""Get a file from the database and show it in the plain."""

Expand Down Expand Up @@ -385,7 +385,7 @@ async def get(self, file_id: str) -> None: # type: ignore
class FileHex(Base):
"""Show a file as hexadecimal."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, file_id: str) -> None: # type: ignore
"""Get a file from the database and show it in hex."""

Expand Down Expand Up @@ -416,7 +416,7 @@ async def get(self, file_id: str) -> None: # type: ignore
class PasteDownload(Base):
"""Download an entire paste."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, paste_id: str) -> None: # type: ignore
"""Get all files from the database and download them as a zipfile."""

Expand Down Expand Up @@ -463,7 +463,7 @@ async def get(self, paste_id: str) -> None: # type: ignore
class FileDownload(Base):
"""Download a file."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, file_id: str) -> None: # type: ignore
"""Get a file from the database and download it in the plain."""

Expand Down Expand Up @@ -505,7 +505,7 @@ async def get(self, file_id: str) -> None: # type: ignore
class Remove(Base):
"""Remove a paste."""

@defensive.ratelimit_endpoint(area="delete")
@defensive.ratelimit(area="delete")
async def get(self, removal: str) -> None: # type: ignore
"""Look up if the user visiting this page has the removal id for a
certain paste. If they do they're authorized to remove the paste."""
Expand Down Expand Up @@ -543,7 +543,7 @@ class RestructuredTextPage(Base):
def initialize(self, file: str) -> None:
self.file = file

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self) -> None:
try:
with open(path.page / self.file) as f:
Expand Down

0 comments on commit 989960c

Please sign in to comment.