mirror of
https://github.com/eduardogsilva/wireguard_webadmin.git
synced 2026-03-17 22:36:17 +00:00
enhance security by enforcing strict SameSite cookies, validating OIDC callback host, and improving path matching logic
This commit is contained in:
@@ -25,7 +25,7 @@ class OIDCService:
|
|||||||
client_id=method.client_id,
|
client_id=method.client_id,
|
||||||
client_secret=method.client_secret,
|
client_secret=method.client_secret,
|
||||||
server_metadata_url=metadata_url,
|
server_metadata_url=metadata_url,
|
||||||
client_kwargs={"scope": "openid email profile"},
|
client_kwargs={"scope": "openid email profile", "code_challenge_method": "S256"},
|
||||||
)
|
)
|
||||||
self._clients[method_name] = client
|
self._clients[method_name] = client
|
||||||
return client
|
return client
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from urllib.parse import urlsplit
|
from urllib.parse import unquote, urlsplit
|
||||||
|
|
||||||
from auth_gateway.models.applications import ApplicationModel
|
from auth_gateway.models.applications import ApplicationModel
|
||||||
from auth_gateway.models.routes import AppRoutesModel, RoutePolicyBindingModel
|
from auth_gateway.models.routes import AppRoutesModel, RoutePolicyBindingModel
|
||||||
@@ -21,10 +21,16 @@ def normalize_host(raw_host: str) -> str:
|
|||||||
|
|
||||||
def normalize_path(raw_uri: str) -> str:
|
def normalize_path(raw_uri: str) -> str:
|
||||||
parsed = urlsplit(raw_uri or "/")
|
parsed = urlsplit(raw_uri or "/")
|
||||||
path = parsed.path or "/"
|
path = unquote(parsed.path or "/")
|
||||||
return path if path.startswith("/") else f"/{path}"
|
return path if path.startswith("/") else f"/{path}"
|
||||||
|
|
||||||
|
|
||||||
|
def _path_matches(path: str, prefix: str) -> bool:
|
||||||
|
"""Check path boundary correctly — prevents /admin matching /administrator."""
|
||||||
|
prefix = prefix.rstrip("/")
|
||||||
|
return path == prefix or path.startswith(prefix + "/")
|
||||||
|
|
||||||
|
|
||||||
def resolve_application(runtime_config: RuntimeConfig, host: str) -> ApplicationModel | None:
|
def resolve_application(runtime_config: RuntimeConfig, host: str) -> ApplicationModel | None:
|
||||||
normalized_host = normalize_host(host)
|
normalized_host = normalize_host(host)
|
||||||
for application in runtime_config.applications.values():
|
for application in runtime_config.applications.values():
|
||||||
@@ -42,7 +48,7 @@ def resolve_route(runtime_config: RuntimeConfig, application_id: str, path: str)
|
|||||||
sorted_routes = sorted(app_routes.routes, key=lambda route: len(route.path_prefix), reverse=True)
|
sorted_routes = sorted(app_routes.routes, key=lambda route: len(route.path_prefix), reverse=True)
|
||||||
for route in sorted_routes:
|
for route in sorted_routes:
|
||||||
route_prefix = normalize_path(route.path_prefix)
|
route_prefix = normalize_path(route.path_prefix)
|
||||||
if normalized_path.startswith(route_prefix):
|
if _path_matches(normalized_path, route_prefix):
|
||||||
return route, route.policy
|
return route, route.policy
|
||||||
return None, app_routes.default_policy
|
return None, app_routes.default_policy
|
||||||
|
|
||||||
|
|||||||
@@ -52,9 +52,14 @@ class SessionService:
|
|||||||
if metadata:
|
if metadata:
|
||||||
session.metadata.update(metadata)
|
session.metadata.update(metadata)
|
||||||
if add_factors:
|
if add_factors:
|
||||||
|
was_unauthenticated = not session.auth_factors
|
||||||
merged_factors = set(session.auth_factors)
|
merged_factors = set(session.auth_factors)
|
||||||
merged_factors.update(add_factors)
|
merged_factors.update(add_factors)
|
||||||
session.auth_factors = sorted(merged_factors)
|
session.auth_factors = sorted(merged_factors)
|
||||||
|
# Prevent session fixation: regenerate session ID on first authentication
|
||||||
|
if was_unauthenticated and existing_session:
|
||||||
|
self.storage.delete_session(existing_session.session_id)
|
||||||
|
session.session_id = token_urlsafe(32)
|
||||||
requested_expiry = now + timedelta(minutes=expires_in_minutes or self.default_session_minutes)
|
requested_expiry = now + timedelta(minutes=expires_in_minutes or self.default_session_minutes)
|
||||||
session.expires_at = min(session.expires_at, requested_expiry) if existing_session else requested_expiry
|
session.expires_at = min(session.expires_at, requested_expiry) if existing_session else requested_expiry
|
||||||
session.updated_at = now
|
session.updated_at = now
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ def _redirect_with_cookie(request: Request, destination: str, session) -> Redire
|
|||||||
value=session.session_id,
|
value=session.session_id,
|
||||||
httponly=True,
|
httponly=True,
|
||||||
secure=request.app.state.settings.secure_cookies,
|
secure=request.app.state.settings.secure_cookies,
|
||||||
samesite="lax",
|
samesite="strict",
|
||||||
path="/",
|
path="/",
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
@@ -237,6 +237,11 @@ async def login_oidc_callback(request: Request, state: str):
|
|||||||
if not oidc_state:
|
if not oidc_state:
|
||||||
return _render(request, "error.html", status_code=400, title="Invalid OIDC state", message="The OIDC login state is missing or expired.")
|
return _render(request, "error.html", status_code=400, title="Invalid OIDC state", message="The OIDC login state is missing or expired.")
|
||||||
|
|
||||||
|
callback_host = normalize_host(request.headers.get("host", ""))
|
||||||
|
if oidc_state.host != callback_host:
|
||||||
|
logger.warning("OIDC callback host mismatch: expected '%s', got '%s'", oidc_state.host, callback_host)
|
||||||
|
return _render(request, "error.html", status_code=400, title="OIDC callback host mismatch", message="The OIDC callback host does not match the original request host.")
|
||||||
|
|
||||||
context = resolve_context_from_request(request, runtime_config, oidc_state.next_url)
|
context = resolve_context_from_request(request, runtime_config, oidc_state.next_url)
|
||||||
effective_policy = get_effective_policy(runtime_config, context.policy_name)
|
effective_policy = get_effective_policy(runtime_config, context.policy_name)
|
||||||
method = runtime_config.auth_methods.get(oidc_state.method_name)
|
method = runtime_config.auth_methods.get(oidc_state.method_name)
|
||||||
|
|||||||
@@ -109,6 +109,10 @@ def build_caddyfile(apps, auth_policies, routes):
|
|||||||
base, upstream_path = split_upstream(upstream)
|
base, upstream_path = split_upstream(upstream)
|
||||||
|
|
||||||
lines.append(f"{', '.join(hosts)} {{")
|
lines.append(f"{', '.join(hosts)} {{")
|
||||||
|
lines.append(" # Security: overwrite client-supplied forwarding headers with verified values")
|
||||||
|
lines.append(" request_header X-Forwarded-For {remote_host}")
|
||||||
|
lines.append(" request_header -X-Forwarded-Host")
|
||||||
|
lines.append("")
|
||||||
emit_auth_portal()
|
emit_auth_portal()
|
||||||
|
|
||||||
for static_route in static_routes:
|
for static_route in static_routes:
|
||||||
|
|||||||
Reference in New Issue
Block a user