Files
wireguard_webadmin/containers/auth-gateway/auth_gateway/services/policy_engine.py

116 lines
4.2 KiB
Python

import ipaddress
from dataclasses import dataclass, field
from auth_gateway.models.auth import (
IPAddressMethodModel,
IPRuleModel,
LocalPasswordMethodModel,
OIDCMethodModel,
PolicyModel,
TotpMethodModel,
)
from auth_gateway.models.runtime import RuntimeConfig
@dataclass
class EffectivePolicy:
name: str
mode: str
error_message: str | None = None
required_factors: list[str] = field(default_factory=list)
allowed_users: set[str] = field(default_factory=set)
allowed_groups: set[str] = field(default_factory=set)
ip_method_names: list[str] = field(default_factory=list)
totp_method_names: list[str] = field(default_factory=list)
password_method_names: list[str] = field(default_factory=list)
oidc_method_names: list[str] = field(default_factory=list)
factor_expirations: dict[str, int] = field(default_factory=dict)
def expand_policy_users(runtime_config: RuntimeConfig, policy: PolicyModel) -> set[str]:
usernames: set[str] = set()
for group_name in policy.groups:
group = runtime_config.groups.get(group_name)
if group:
usernames.update(group.users)
return usernames
def build_effective_policy(runtime_config: RuntimeConfig, policy_name: str) -> EffectivePolicy | None:
policy = runtime_config.policies.get(policy_name)
if not policy:
return None
effective = EffectivePolicy(
name=policy_name,
mode=policy.policy_type,
error_message=policy.error_message,
allowed_users=expand_policy_users(runtime_config, policy),
allowed_groups=set(policy.groups),
)
if policy.policy_type != "protected":
return effective
if not policy.methods:
return EffectivePolicy(
name=policy_name,
mode="error",
error_message="Policy configuration error: protected policy has no authentication methods.",
)
for method_name in policy.methods:
method = runtime_config.auth_methods[method_name]
if isinstance(method, IPAddressMethodModel):
effective.ip_method_names.append(method_name)
if "ip" not in effective.required_factors:
effective.required_factors.append("ip")
elif isinstance(method, LocalPasswordMethodModel):
effective.password_method_names.append(method_name)
effective.required_factors.append("password")
effective.factor_expirations["password"] = method.session_expiration_minutes
elif isinstance(method, TotpMethodModel):
effective.totp_method_names.append(method_name)
effective.required_factors.append("totp")
effective.factor_expirations["totp"] = method.session_expiration_minutes or 720
elif isinstance(method, OIDCMethodModel):
effective.oidc_method_names.append(method_name)
effective.required_factors.append("oidc")
effective.factor_expirations["oidc"] = method.session_expiration_minutes
return effective
def extract_client_ip(forwarded_for: str) -> str | None:
if not forwarded_for:
return None
candidate = forwarded_for.split(",")[0].strip()
try:
return str(ipaddress.ip_address(candidate))
except ValueError:
return None
def evaluate_ip_rules(client_ip: str | None, rules: list[IPRuleModel]) -> bool:
if not client_ip:
return False
ip_value = ipaddress.ip_address(client_ip)
for rule in rules:
if rule.prefix_length is None:
network = ipaddress.ip_network(f"{rule.address}/{'32' if ip_value.version == 4 else '128'}", strict=False)
else:
network = ipaddress.ip_network(f"{rule.address}/{rule.prefix_length}", strict=False)
if ip_value in network:
return rule.action == "allow"
return False
def evaluate_ip_access(runtime_config: RuntimeConfig, effective_policy: EffectivePolicy, client_ip: str | None) -> bool:
if not effective_policy.ip_method_names:
return True
for method_name in effective_policy.ip_method_names:
method = runtime_config.auth_methods[method_name]
if isinstance(method, IPAddressMethodModel) and evaluate_ip_rules(client_ip, method.rules):
return True
return False