From cf4e86ad9382002716644377d4f93646c191c4bd Mon Sep 17 00:00:00 2001 From: Eduardo Silva Date: Tue, 27 Jan 2026 14:44:20 -0300 Subject: [PATCH] Add route policy restriction logic to Peer model --- wireguard/models.py | 75 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 69 insertions(+), 6 deletions(-) diff --git a/wireguard/models.py b/wireguard/models.py index fd4a6e0..e46ff29 100644 --- a/wireguard/models.py +++ b/wireguard/models.py @@ -90,7 +90,6 @@ class WireGuardInstance(models.Model): priority__gte=1 ) .values_list('allowed_ip', 'netmask') - .distinct() ) return normalize_cidr_pairs(rows) @@ -104,7 +103,6 @@ class WireGuardInstance(models.Model): priority=0 ) .values_list('allowed_ip', 'netmask') - .distinct() ) return normalize_cidr_pairs(rows) @@ -133,11 +131,19 @@ class Peer(models.Model): @property def announced_networks(self): + prefetched = getattr(self, "_prefetched_objects_cache", {}) + if "peerallowedip_set" in prefetched: + rows = [ + (aip.allowed_ip, aip.netmask) + for aip in prefetched["peerallowedip_set"] + if aip.config_file == "server" and aip.priority >= 1 + ] + return normalize_cidr_pairs(rows) + rows = ( self.peerallowedip_set .filter(config_file='server', priority__gte=1) .values_list('allowed_ip', 'netmask') - .distinct() ) return normalize_cidr_pairs(rows) @@ -145,11 +151,31 @@ class Peer(models.Model): def client_routes(self): routes = [] + prefetched = getattr(self, "_prefetched_objects_cache", {}) + if "peerallowedip_set" in prefetched: + allowedips = prefetched["peerallowedip_set"] + + rows_client = [(aip.allowed_ip, aip.netmask) for aip in allowedips if aip.config_file == "client"] + routes.extend(normalize_cidr_pairs(rows_client)) + + if self.routing_template: + routes.extend(self.routing_template.template_routes) + + normalized = normalize_cidr_list(routes) + + rows_announced = [(aip.allowed_ip, aip.netmask) for aip in allowedips if aip.config_file == "server"] + exclude = set(normalize_cidr_pairs(rows_announced)) + + final_routes = [cidr for cidr in normalized if cidr not in exclude] + if not final_routes or "0.0.0.0/0" in final_routes: + return ["0.0.0.0/0"] + return final_routes + + # Fallback (no prefetch): your original DB-based implementation rows_client = ( self.peerallowedip_set .filter(config_file='client') .values_list('allowed_ip', 'netmask') - .distinct() ) routes.extend(normalize_cidr_pairs(rows_client)) @@ -162,7 +188,6 @@ class Peer(models.Model): self.peerallowedip_set .filter(config_file='server') .values_list('allowed_ip', 'netmask') - .distinct() ) exclude = set(normalize_cidr_pairs(rows_announced)) @@ -174,14 +199,52 @@ class Peer(models.Model): @property def main_addresses(self): + prefetched = getattr(self, "_prefetched_objects_cache", {}) + if "peerallowedip_set" in prefetched: + rows = [ + (aip.allowed_ip, aip.netmask) + for aip in prefetched["peerallowedip_set"] + if aip.config_file == "server" and aip.priority == 0 + ] + return normalize_cidr_pairs(rows) + rows = ( self.peerallowedip_set .filter(config_file='server', priority=0) .values_list('allowed_ip', 'netmask') - .distinct() ) return normalize_cidr_pairs(rows) + @property + def is_route_policy_restricted(self) -> bool: + # 1) Enforcement must be enabled somewhere (template OR instance). + template_enforced = bool(self.routing_template and self.routing_template.enforce_route_policy) + instance_enforced = bool(self.wireguard_instance and self.wireguard_instance.enforce_route_policy) + if not (template_enforced or instance_enforced): + return False + + # 2) If there is a routing template assigned, and its type is "default", the peer is not restricted. + if self.routing_template: + if self.routing_template.route_type == "default": + return False + else: + return True + + # 3) If there is any client-side allowed IP entry, the peer has explicit (non-default) registered routes. + # - If peerallowedip_set was prefetched, we scan in-memory. + # - Otherwise, we issue a single EXISTS() query. + prefetched = getattr(self, "_prefetched_objects_cache", {}) + if "peerallowedip_set" in prefetched: + has_client_routes = any(aip.config_file == "client" for aip in prefetched["peerallowedip_set"]) + else: + has_client_routes = self.peerallowedip_set.filter(config_file="client").exists() + + if has_client_routes: + return True + + # 5) Otherwise, the peer effectively falls back to default route (0.0.0.0/0), so it is not restricted. + return False + class PeerStatus(models.Model): peer = models.OneToOneField(Peer, on_delete=models.CASCADE)