mirror of
https://github.com/eduardogsilva/wireguard_webadmin.git
synced 2026-03-17 14:26:18 +00:00
implement rate limiting for authentication routes and add custom error handling page
This commit is contained in:
14
containers/auth-gateway/auth_gateway/limiter.py
Normal file
14
containers/auth-gateway/auth_gateway/limiter.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from fastapi import Request
|
||||||
|
from slowapi import Limiter
|
||||||
|
|
||||||
|
AUTH_RATE_LIMIT = "5/minute"
|
||||||
|
|
||||||
|
|
||||||
|
def get_real_client_ip(request: Request) -> str:
|
||||||
|
forwarded_for = request.headers.get("x-forwarded-for", "")
|
||||||
|
if forwarded_for:
|
||||||
|
return forwarded_for.split(",")[0].strip()
|
||||||
|
return request.client.host if request.client else "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
limiter = Limiter(key_func=get_real_client_ip)
|
||||||
@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from auth_gateway.config_loader import RuntimeConfigStore
|
from auth_gateway.config_loader import RuntimeConfigStore
|
||||||
|
from auth_gateway.limiter import get_real_client_ip, limiter
|
||||||
from auth_gateway.services.oidc_service import OIDCService
|
from auth_gateway.services.oidc_service import OIDCService
|
||||||
from auth_gateway.services.session_service import SessionService
|
from auth_gateway.services.session_service import SessionService
|
||||||
from auth_gateway.settings import settings
|
from auth_gateway.settings import settings
|
||||||
@@ -11,8 +12,10 @@ from auth_gateway.storage.sqlite import SQLiteStorage
|
|||||||
from auth_gateway.web.auth_routes import router as auth_router
|
from auth_gateway.web.auth_routes import router as auth_router
|
||||||
from auth_gateway.web.login_routes import router as login_router
|
from auth_gateway.web.login_routes import router as login_router
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
|
||||||
BASE_DIR = Path(__file__).resolve().parent
|
BASE_DIR = Path(__file__).resolve().parent
|
||||||
_access_logger = logging.getLogger("uvicorn.error")
|
_access_logger = logging.getLogger("uvicorn.error")
|
||||||
@@ -30,6 +33,32 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Auth Gateway", lifespan=lifespan)
|
app = FastAPI(title="Auth Gateway", lifespan=lifespan)
|
||||||
|
app.state.limiter = limiter
|
||||||
|
|
||||||
|
|
||||||
|
async def _rate_limit_handler(request: Request, exc: RateLimitExceeded) -> HTMLResponse:
|
||||||
|
username = None
|
||||||
|
try:
|
||||||
|
form = await request.form()
|
||||||
|
username = form.get("username")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
client = get_real_client_ip(request)
|
||||||
|
if username:
|
||||||
|
_access_logger.warning("AUTH rate limit exceeded for '%s' on %s from %s", username, request.url.path, client)
|
||||||
|
else:
|
||||||
|
_access_logger.warning("AUTH rate limit exceeded on %s from %s", request.url.path, client)
|
||||||
|
|
||||||
|
templates = request.app.state.templates
|
||||||
|
external_path = request.app.state.settings.external_path
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
"ratelimit.html",
|
||||||
|
{"request": request, "external_path": external_path, "back_url": str(request.url.path)},
|
||||||
|
status_code=429,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
app.add_exception_handler(RateLimitExceeded, _rate_limit_handler)
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
@@ -39,7 +68,7 @@ async def access_log(request: Request, call_next):
|
|||||||
if request.url.path == "/auth/check" and response.status_code == 200:
|
if request.url.path == "/auth/check" and response.status_code == 200:
|
||||||
return response
|
return response
|
||||||
ms = (time.monotonic() - start) * 1000
|
ms = (time.monotonic() - start) * 1000
|
||||||
client = request.client.host if request.client else "-"
|
client = get_real_client_ip(request)
|
||||||
_access_logger.info('%s - "%s %s" %d (%.0fms)', client, request.method, request.url.path, response.status_code, ms)
|
_access_logger.info('%s - "%s %s" %d (%.0fms)', client, request.method, request.url.path, response.status_code, ms)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
{% block title %}Too many attempts — Gatekeeper{% endblock %}
|
||||||
|
{% block content %}
|
||||||
|
<h1 class="card-title">Too many attempts</h1>
|
||||||
|
<p class="card-subtitle">You have made too many requests in a short period. Please wait a moment before trying again.</p>
|
||||||
|
<a class="btn btn-secondary" href="{{ back_url }}">Try again</a>
|
||||||
|
{% endblock %}
|
||||||
@@ -19,6 +19,7 @@ from auth_gateway.web.dependencies import (
|
|||||||
resolve_context_from_request,
|
resolve_context_from_request,
|
||||||
session_is_allowed,
|
session_is_allowed,
|
||||||
)
|
)
|
||||||
|
from auth_gateway.limiter import AUTH_RATE_LIMIT, limiter
|
||||||
from fastapi import APIRouter, Form, Request
|
from fastapi import APIRouter, Form, Request
|
||||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||||
|
|
||||||
@@ -118,6 +119,7 @@ async def login_password_page(request: Request, next: str = "/"):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/login/password")
|
@router.post("/login/password")
|
||||||
|
@limiter.limit(AUTH_RATE_LIMIT)
|
||||||
async def login_password_submit(request: Request, next: str = Form("/"), username: str = Form(...), password: str = Form(...)):
|
async def login_password_submit(request: Request, next: str = Form("/"), username: str = Form(...), password: str = Form(...)):
|
||||||
runtime_config = get_runtime_config(request)
|
runtime_config = get_runtime_config(request)
|
||||||
context = resolve_context_from_request(request, runtime_config, next)
|
context = resolve_context_from_request(request, runtime_config, next)
|
||||||
@@ -168,6 +170,7 @@ async def login_totp_page(request: Request, next: str = "/"):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/login/totp")
|
@router.post("/login/totp")
|
||||||
|
@limiter.limit(AUTH_RATE_LIMIT)
|
||||||
async def login_totp_submit(request: Request, next: str = Form("/"), token: str = Form(...)):
|
async def login_totp_submit(request: Request, next: str = Form("/"), token: str = Form(...)):
|
||||||
runtime_config = get_runtime_config(request)
|
runtime_config = get_runtime_config(request)
|
||||||
context = resolve_context_from_request(request, runtime_config, next)
|
context = resolve_context_from_request(request, runtime_config, next)
|
||||||
@@ -204,6 +207,7 @@ async def login_totp_submit(request: Request, next: str = Form("/"), token: str
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/login/oidc/start")
|
@router.get("/login/oidc/start")
|
||||||
|
@limiter.limit(AUTH_RATE_LIMIT)
|
||||||
async def login_oidc_start(request: Request, next: str = "/"):
|
async def login_oidc_start(request: Request, next: str = "/"):
|
||||||
runtime_config = get_runtime_config(request)
|
runtime_config = get_runtime_config(request)
|
||||||
context = resolve_context_from_request(request, runtime_config, next)
|
context = resolve_context_from_request(request, runtime_config, next)
|
||||||
|
|||||||
@@ -2,4 +2,4 @@
|
|||||||
|
|
||||||
set -eu
|
set -eu
|
||||||
|
|
||||||
exec uvicorn auth_gateway.main:app --host 0.0.0.0 --port 9091 --no-access-log
|
exec uvicorn auth_gateway.main:app --host 0.0.0.0 --port 9091 --no-access-log --proxy-headers --forwarded-allow-ips="*"
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ certifi==2026.2.25
|
|||||||
cffi==2.0.0
|
cffi==2.0.0
|
||||||
click==8.3.1
|
click==8.3.1
|
||||||
cryptography==46.0.5
|
cryptography==46.0.5
|
||||||
|
Deprecated==1.3.1
|
||||||
fastapi==0.135.1
|
fastapi==0.135.1
|
||||||
h11==0.16.0
|
h11==0.16.0
|
||||||
httpcore==1.0.9
|
httpcore==1.0.9
|
||||||
@@ -15,7 +16,9 @@ httptools==0.7.1
|
|||||||
httpx==0.28.1
|
httpx==0.28.1
|
||||||
idna==3.11
|
idna==3.11
|
||||||
Jinja2==3.1.6
|
Jinja2==3.1.6
|
||||||
|
limits==5.8.0
|
||||||
MarkupSafe==3.0.3
|
MarkupSafe==3.0.3
|
||||||
|
packaging==26.0
|
||||||
pycparser==3.0
|
pycparser==3.0
|
||||||
pydantic==2.12.5
|
pydantic==2.12.5
|
||||||
pydantic-settings==2.13.1
|
pydantic-settings==2.13.1
|
||||||
@@ -24,6 +27,7 @@ pyotp==2.9.0
|
|||||||
python-dotenv==1.2.2
|
python-dotenv==1.2.2
|
||||||
python-multipart==0.0.22
|
python-multipart==0.0.22
|
||||||
PyYAML==6.0.3
|
PyYAML==6.0.3
|
||||||
|
slowapi==0.1.9
|
||||||
starlette==0.52.1
|
starlette==0.52.1
|
||||||
typing-inspection==0.4.2
|
typing-inspection==0.4.2
|
||||||
typing_extensions==4.15.0
|
typing_extensions==4.15.0
|
||||||
@@ -31,3 +35,4 @@ uvicorn==0.42.0
|
|||||||
uvloop==0.22.1
|
uvloop==0.22.1
|
||||||
watchfiles==1.1.1
|
watchfiles==1.1.1
|
||||||
websockets==16.0
|
websockets==16.0
|
||||||
|
wrapt==2.1.2
|
||||||
|
|||||||
Reference in New Issue
Block a user