mirror of
https://github.com/eduardogsilva/wireguard_webadmin.git
synced 2026-03-17 22:36:17 +00:00
add initial implementation of auth gateway with models, routes, and session management
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Auth gateway service layer."""
|
||||
@@ -0,0 +1,59 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from auth_gateway.models.auth import OIDCMethodModel
|
||||
from authlib.integrations.starlette_client import OAuth
|
||||
|
||||
|
||||
@dataclass
|
||||
class OIDCIdentity:
|
||||
subject: str | None
|
||||
email: str | None
|
||||
claims: dict
|
||||
|
||||
|
||||
class OIDCService:
|
||||
def __init__(self):
|
||||
self.oauth = OAuth()
|
||||
self._clients = {}
|
||||
|
||||
def _client(self, method_name: str, method: OIDCMethodModel):
|
||||
if method_name in self._clients:
|
||||
return self._clients[method_name]
|
||||
metadata_url = f"{method.provider.rstrip('/')}/.well-known/openid-configuration"
|
||||
client = self.oauth.register(
|
||||
name=f"oidc_{method_name}",
|
||||
client_id=method.client_id,
|
||||
client_secret=method.client_secret,
|
||||
server_metadata_url=metadata_url,
|
||||
client_kwargs={"scope": "openid email profile"},
|
||||
)
|
||||
self._clients[method_name] = client
|
||||
return client
|
||||
|
||||
async def build_authorization_redirect(self, request, method_name: str, method: OIDCMethodModel, redirect_uri: str, state: str, nonce: str):
|
||||
client = self._client(method_name, method)
|
||||
return await client.authorize_redirect(request, redirect_uri, state=state, nonce=nonce)
|
||||
|
||||
async def finish_callback(self, request, method_name: str, method: OIDCMethodModel, nonce: str) -> OIDCIdentity:
|
||||
client = self._client(method_name, method)
|
||||
token = await client.authorize_access_token(request)
|
||||
claims = {}
|
||||
if "userinfo" in token and isinstance(token["userinfo"], dict):
|
||||
claims = token["userinfo"]
|
||||
elif "id_token" in token:
|
||||
claims = await client.parse_id_token(request, token, nonce=nonce)
|
||||
email = claims.get("email")
|
||||
subject = claims.get("sub")
|
||||
return OIDCIdentity(subject=subject, email=email, claims=dict(claims))
|
||||
|
||||
|
||||
def is_oidc_identity_allowed(method: OIDCMethodModel, email: str | None) -> bool:
|
||||
if not email:
|
||||
return not method.allowed_domains and not method.allowed_emails
|
||||
normalized_email = email.lower()
|
||||
normalized_domain = normalized_email.split("@", 1)[1] if "@" in normalized_email else ""
|
||||
allowed_emails = {item.lower() for item in method.allowed_emails}
|
||||
allowed_domains = {item.lower() for item in method.allowed_domains}
|
||||
if not allowed_emails and not allowed_domains:
|
||||
return True
|
||||
return normalized_email in allowed_emails or normalized_domain in allowed_domains
|
||||
@@ -0,0 +1,18 @@
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.exceptions import VerifyMismatchError
|
||||
|
||||
from auth_gateway.models.auth import UserModel
|
||||
|
||||
|
||||
password_hasher = PasswordHasher()
|
||||
|
||||
|
||||
def verify_user_password(username: str, password: str, users: dict[str, UserModel]) -> UserModel | None:
|
||||
user = users.get(username)
|
||||
if not user or not user.password_hash:
|
||||
return None
|
||||
try:
|
||||
password_hasher.verify(user.password_hash, password)
|
||||
except VerifyMismatchError:
|
||||
return None
|
||||
return user
|
||||
106
containers/auth-gateway/auth_gateway/services/policy_engine.py
Normal file
106
containers/auth-gateway/auth_gateway/services/policy_engine.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import ipaddress
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from auth_gateway.models.auth import (
|
||||
IPAddressMethodModel,
|
||||
IPRuleModel,
|
||||
LocalPasswordMethodModel,
|
||||
OIDCMethodModel,
|
||||
PolicyModel,
|
||||
TotpMethodModel,
|
||||
)
|
||||
from auth_gateway.models.runtime import RuntimeConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class EffectivePolicy:
|
||||
name: str
|
||||
mode: str
|
||||
required_factors: list[str] = field(default_factory=list)
|
||||
allowed_users: set[str] = field(default_factory=set)
|
||||
allowed_groups: set[str] = field(default_factory=set)
|
||||
ip_method_names: list[str] = field(default_factory=list)
|
||||
totp_method_names: list[str] = field(default_factory=list)
|
||||
password_method_names: list[str] = field(default_factory=list)
|
||||
oidc_method_names: list[str] = field(default_factory=list)
|
||||
factor_expirations: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
def expand_policy_users(runtime_config: RuntimeConfig, policy: PolicyModel) -> set[str]:
|
||||
usernames: set[str] = set()
|
||||
for group_name in policy.groups:
|
||||
group = runtime_config.groups.get(group_name)
|
||||
if group:
|
||||
usernames.update(group.users)
|
||||
return usernames
|
||||
|
||||
|
||||
def build_effective_policy(runtime_config: RuntimeConfig, policy_name: str) -> EffectivePolicy | None:
|
||||
policy = runtime_config.policies.get(policy_name)
|
||||
if not policy:
|
||||
return None
|
||||
|
||||
effective = EffectivePolicy(
|
||||
name=policy_name,
|
||||
mode=policy.policy_type,
|
||||
allowed_users=expand_policy_users(runtime_config, policy),
|
||||
allowed_groups=set(policy.groups),
|
||||
)
|
||||
|
||||
if policy.policy_type != "protected":
|
||||
return effective
|
||||
|
||||
for method_name in policy.methods:
|
||||
method = runtime_config.auth_methods[method_name]
|
||||
if isinstance(method, IPAddressMethodModel):
|
||||
effective.ip_method_names.append(method_name)
|
||||
if "ip" not in effective.required_factors:
|
||||
effective.required_factors.append("ip")
|
||||
elif isinstance(method, LocalPasswordMethodModel):
|
||||
effective.password_method_names.append(method_name)
|
||||
effective.required_factors.append("password")
|
||||
effective.factor_expirations["password"] = method.session_expiration_minutes
|
||||
elif isinstance(method, TotpMethodModel):
|
||||
effective.totp_method_names.append(method_name)
|
||||
effective.required_factors.append("totp")
|
||||
effective.factor_expirations["totp"] = method.session_expiration_minutes or 720
|
||||
elif isinstance(method, OIDCMethodModel):
|
||||
effective.oidc_method_names.append(method_name)
|
||||
effective.required_factors.append("oidc")
|
||||
effective.factor_expirations["oidc"] = method.session_expiration_minutes
|
||||
|
||||
return effective
|
||||
|
||||
|
||||
def extract_client_ip(forwarded_for: str) -> str | None:
|
||||
if not forwarded_for:
|
||||
return None
|
||||
candidate = forwarded_for.split(",")[0].strip()
|
||||
try:
|
||||
return str(ipaddress.ip_address(candidate))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def evaluate_ip_rules(client_ip: str | None, rules: list[IPRuleModel]) -> bool:
|
||||
if not client_ip:
|
||||
return False
|
||||
ip_value = ipaddress.ip_address(client_ip)
|
||||
for rule in rules:
|
||||
if rule.prefix_length is None:
|
||||
network = ipaddress.ip_network(f"{rule.address}/{'32' if ip_value.version == 4 else '128'}", strict=False)
|
||||
else:
|
||||
network = ipaddress.ip_network(f"{rule.address}/{rule.prefix_length}", strict=False)
|
||||
if ip_value in network:
|
||||
return rule.action == "allow"
|
||||
return False
|
||||
|
||||
|
||||
def evaluate_ip_access(runtime_config: RuntimeConfig, effective_policy: EffectivePolicy, client_ip: str | None) -> bool:
|
||||
if not effective_policy.ip_method_names:
|
||||
return True
|
||||
for method_name in effective_policy.ip_method_names:
|
||||
method = runtime_config.auth_methods[method_name]
|
||||
if isinstance(method, IPAddressMethodModel) and evaluate_ip_rules(client_ip, method.rules):
|
||||
return True
|
||||
return False
|
||||
63
containers/auth-gateway/auth_gateway/services/resolver.py
Normal file
63
containers/auth-gateway/auth_gateway/services/resolver.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from dataclasses import dataclass
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from auth_gateway.models.applications import ApplicationModel
|
||||
from auth_gateway.models.routes import AppRoutesModel, RoutePolicyBindingModel
|
||||
from auth_gateway.models.runtime import RuntimeConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext:
|
||||
host: str
|
||||
path: str
|
||||
application: ApplicationModel
|
||||
route: RoutePolicyBindingModel | None
|
||||
policy_name: str
|
||||
|
||||
|
||||
def normalize_host(raw_host: str) -> str:
|
||||
return raw_host.split(":", 1)[0].strip().lower()
|
||||
|
||||
|
||||
def normalize_path(raw_uri: str) -> str:
|
||||
parsed = urlsplit(raw_uri or "/")
|
||||
path = parsed.path or "/"
|
||||
return path if path.startswith("/") else f"/{path}"
|
||||
|
||||
|
||||
def resolve_application(runtime_config: RuntimeConfig, host: str) -> ApplicationModel | None:
|
||||
normalized_host = normalize_host(host)
|
||||
for application in runtime_config.applications.values():
|
||||
if normalized_host in {candidate.lower() for candidate in application.hosts}:
|
||||
return application
|
||||
return None
|
||||
|
||||
|
||||
def resolve_route(runtime_config: RuntimeConfig, application_id: str, path: str) -> tuple[RoutePolicyBindingModel | None, str | None]:
|
||||
app_routes: AppRoutesModel | None = runtime_config.routes_by_app.get(application_id)
|
||||
if not app_routes:
|
||||
return None, None
|
||||
|
||||
normalized_path = normalize_path(path)
|
||||
sorted_routes = sorted(app_routes.routes, key=lambda route: len(route.path_prefix), reverse=True)
|
||||
for route in sorted_routes:
|
||||
route_prefix = normalize_path(route.path_prefix)
|
||||
if normalized_path.startswith(route_prefix):
|
||||
return route, route.policy
|
||||
return None, app_routes.default_policy
|
||||
|
||||
|
||||
def resolve_request_context(runtime_config: RuntimeConfig, host: str, path: str) -> RequestContext | None:
|
||||
application = resolve_application(runtime_config, host)
|
||||
if not application:
|
||||
return None
|
||||
route, policy_name = resolve_route(runtime_config, application.id, path)
|
||||
if not policy_name:
|
||||
return None
|
||||
return RequestContext(
|
||||
host=normalize_host(host),
|
||||
path=normalize_path(path),
|
||||
application=application,
|
||||
route=route,
|
||||
policy_name=policy_name,
|
||||
)
|
||||
@@ -0,0 +1,89 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from secrets import token_urlsafe
|
||||
|
||||
from auth_gateway.models.session import OIDCStateRecord, SessionRecord
|
||||
from auth_gateway.storage.sqlite import SQLiteStorage
|
||||
|
||||
|
||||
class SessionService:
|
||||
def __init__(self, storage: SQLiteStorage, default_session_minutes: int, oidc_state_ttl_minutes: int):
|
||||
self.storage = storage
|
||||
self.default_session_minutes = default_session_minutes
|
||||
self.oidc_state_ttl_minutes = oidc_state_ttl_minutes
|
||||
|
||||
def get_session(self, session_id: str | None) -> SessionRecord | None:
|
||||
if not session_id:
|
||||
return None
|
||||
session = self.storage.get_session(session_id)
|
||||
if not session:
|
||||
return None
|
||||
if session.expires_at <= datetime.now(UTC):
|
||||
self.storage.delete_session(session_id)
|
||||
return None
|
||||
return session
|
||||
|
||||
def issue_session(
|
||||
self,
|
||||
existing_session: SessionRecord | None = None,
|
||||
*,
|
||||
username: str | None = None,
|
||||
email: str | None = None,
|
||||
subject: str | None = None,
|
||||
groups: list[str] | None = None,
|
||||
add_factors: list[str] | None = None,
|
||||
metadata: dict | None = None,
|
||||
expires_in_minutes: int | None = None,
|
||||
) -> SessionRecord:
|
||||
now = datetime.now(UTC)
|
||||
session = existing_session or SessionRecord(
|
||||
session_id=token_urlsafe(32),
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
expires_at=now + timedelta(minutes=expires_in_minutes or self.default_session_minutes),
|
||||
)
|
||||
if username is not None:
|
||||
session.username = username
|
||||
if email is not None:
|
||||
session.email = email
|
||||
if subject is not None:
|
||||
session.subject = subject
|
||||
if groups is not None:
|
||||
session.groups = groups
|
||||
if metadata:
|
||||
session.metadata.update(metadata)
|
||||
if add_factors:
|
||||
merged_factors = set(session.auth_factors)
|
||||
merged_factors.update(add_factors)
|
||||
session.auth_factors = sorted(merged_factors)
|
||||
requested_expiry = now + timedelta(minutes=expires_in_minutes or self.default_session_minutes)
|
||||
session.expires_at = min(session.expires_at, requested_expiry) if existing_session else requested_expiry
|
||||
session.updated_at = now
|
||||
self.storage.save_session(session)
|
||||
return session
|
||||
|
||||
def delete_session(self, session_id: str | None) -> None:
|
||||
if session_id:
|
||||
self.storage.delete_session(session_id)
|
||||
|
||||
def create_oidc_state(self, method_name: str, host: str, next_url: str) -> OIDCStateRecord:
|
||||
now = datetime.now(UTC)
|
||||
state = OIDCStateRecord(
|
||||
state=token_urlsafe(24),
|
||||
nonce=token_urlsafe(24),
|
||||
method_name=method_name,
|
||||
host=host,
|
||||
next_url=next_url,
|
||||
created_at=now,
|
||||
expires_at=now + timedelta(minutes=self.oidc_state_ttl_minutes),
|
||||
)
|
||||
self.storage.save_oidc_state(state)
|
||||
return state
|
||||
|
||||
def consume_oidc_state(self, state_value: str) -> OIDCStateRecord | None:
|
||||
oidc_state = self.storage.get_oidc_state(state_value)
|
||||
if not oidc_state:
|
||||
return None
|
||||
self.storage.delete_oidc_state(state_value)
|
||||
if oidc_state.expires_at <= datetime.now(UTC):
|
||||
return None
|
||||
return oidc_state
|
||||
@@ -0,0 +1,9 @@
|
||||
import pyotp
|
||||
|
||||
|
||||
def verify_totp(secret: str, token: str) -> bool:
|
||||
normalized_secret = secret.strip()
|
||||
normalized_token = token.strip().replace(" ", "")
|
||||
if not normalized_secret or not normalized_token:
|
||||
return False
|
||||
return pyotp.TOTP(normalized_secret).verify(normalized_token, valid_window=1)
|
||||
Reference in New Issue
Block a user