diff --git a/api_v2/tests.py b/api_v2/tests.py index a39b155..03f0071 100644 --- a/api_v2/tests.py +++ b/api_v2/tests.py @@ -1 +1,119 @@ -# Create your tests here. +import json +from unittest.mock import patch + +from django.test import TestCase + +from api_v2.models import ApiKey +from dns.models import DNSSettings, StaticHost + + +class ApiV2ManageDnsRecordTests(TestCase): + def setUp(self): + self.api_key = ApiKey.objects.create(name="dns-test-key", enabled=True) + self.url = "/api/v2/manage_dns_record/" + + @staticmethod + def _fake_export_dns_configuration(): + dns_settings, dns_settings_created = DNSSettings.objects.get_or_create(name="dns_settings") + dns_settings.pending_changes = False + dns_settings.save(update_fields=["pending_changes", "updated"]) + + @patch("api_v2.views_api.export_dns_configuration") + def test_post_creates_record(self, mock_export): + response = self.client.post( + self.url, + data=json.dumps({ + "hostname": "App.Example.com", + "ip_address": "10.20.30.40", + "skip_reload": True, + }), + content_type="application/json", + HTTP_TOKEN=str(self.api_key.token), + ) + + self.assertEqual(response.status_code, 201) + body = response.json() + self.assertEqual(body["hostname"], "app.example.com") + self.assertEqual(body["ip_address"], "10.20.30.40") + mock_export.assert_not_called() + self.assertTrue(StaticHost.objects.filter(hostname="app.example.com").exists()) + + @patch("api_v2.views_api.export_dns_configuration") + def test_post_fails_if_record_exists(self, mock_export): + StaticHost.objects.create(hostname="app.example.com", ip_address="10.20.30.40") + + response = self.client.post( + self.url, + data=json.dumps({ + "hostname": "app.example.com", + "ip_address": "10.20.30.99", + "skip_reload": False, + }), + content_type="application/json", + HTTP_TOKEN=str(self.api_key.token), + ) + + self.assertEqual(response.status_code, 400) + self.assertIn("already exists", response.json()["error_message"]) + self.assertEqual(str(StaticHost.objects.get(hostname="app.example.com").ip_address), "10.20.30.40") + mock_export.assert_not_called() + + @patch("api_v2.views_api.export_dns_configuration") + def test_put_upserts_existing_record(self, mock_export): + mock_export.side_effect = self._fake_export_dns_configuration + StaticHost.objects.create(hostname="app.example.com", ip_address="10.20.30.40") + DNSSettings.objects.create(name="dns_settings", pending_changes=True) + + response = self.client.put( + self.url, + data=json.dumps({ + "hostname": "app.example.com", + "ip_address": "10.20.30.41", + "skip_reload": False, + }), + content_type="application/json", + HTTP_TOKEN=str(self.api_key.token), + ) + + self.assertEqual(response.status_code, 200) + self.assertEqual(str(StaticHost.objects.get(hostname="app.example.com").ip_address), "10.20.30.41") + self.assertEqual(StaticHost.objects.filter(hostname="app.example.com").count(), 1) + self.assertFalse(DNSSettings.objects.get(name="dns_settings").pending_changes) + mock_export.assert_called_once() + + @patch("api_v2.views_api.export_dns_configuration") + def test_put_upserts_missing_record_as_create(self, mock_export): + mock_export.side_effect = self._fake_export_dns_configuration + + response = self.client.put( + self.url, + data=json.dumps({ + "hostname": "new.example.com", + "ip_address": "10.20.30.50", + "skip_reload": False, + }), + content_type="application/json", + HTTP_TOKEN=str(self.api_key.token), + ) + + self.assertEqual(response.status_code, 201) + self.assertTrue(StaticHost.objects.filter(hostname="new.example.com").exists()) + mock_export.assert_called_once() + + @patch("api_v2.views_api.export_dns_configuration") + def test_delete_deletes_record_and_does_not_require_ip(self, mock_export): + StaticHost.objects.create(hostname="del.example.com", ip_address="10.1.2.3") + + response = self.client.delete( + self.url, + data=json.dumps({ + "hostname": "del.example.com", + "skip_reload": True, + }), + content_type="application/json", + HTTP_TOKEN=str(self.api_key.token), + ) + + self.assertEqual(response.status_code, 200) + self.assertFalse(StaticHost.objects.filter(hostname="del.example.com").exists()) + mock_export.assert_not_called() diff --git a/api_v2/urls_api.py b/api_v2/urls_api.py index f711571..6c8ed5a 100644 --- a/api_v2/urls_api.py +++ b/api_v2/urls_api.py @@ -1,10 +1,17 @@ from django.urls import path -from .views_api import api_v2_manage_peer, api_v2_peer_list, api_v2_peer_detail, api_v2_wireguard_status +from .views_api import ( + api_v2_manage_dns_record, + api_v2_manage_peer, + api_v2_peer_detail, + api_v2_peer_list, + api_v2_wireguard_status, +) urlpatterns = [ path('manage_peer/', api_v2_manage_peer, name='api_v2_manage_peer'), + path('manage_dns_record/', api_v2_manage_dns_record, name='api_v2_manage_dns_record'), path('peer_list/', api_v2_peer_list, name='api_v2_peer_list'), path('peer_detail/', api_v2_peer_detail, name='api_v2_peer_detail'), path('wireguard_status/', api_v2_wireguard_status, name='api_v2_wireguard_status'), -] \ No newline at end of file +] diff --git a/api_v2/views_api.py b/api_v2/views_api.py index 0fa979a..26352c1 100644 --- a/api_v2/views_api.py +++ b/api_v2/views_api.py @@ -1,5 +1,6 @@ import ipaddress import json +import re from functools import wraps from typing import List, Optional, Tuple @@ -8,6 +9,8 @@ from django.http import JsonResponse from django.views.decorators.csrf import csrf_exempt from api.views import func_get_wireguard_status +from dns.models import DNSSettings, StaticHost +from dns.views import export_dns_configuration from routing_templates.models import RoutingTemplate from wireguard.models import Peer, PeerAllowedIP, WireGuardInstance from wireguard_peer.functions import func_create_new_peer @@ -164,6 +167,32 @@ def _get_wireguard_instance(instance_name: str) -> Optional[WireGuardInstance]: return + +def _validate_dns_hostname(hostname: str) -> Tuple[Optional[str], Optional[str]]: + if not isinstance(hostname, str): + return None, "Invalid hostname." + + normalized = hostname.strip().lower() + if not normalized: + return None, "Invalid hostname." + if "://" in normalized or "/" in normalized or ":" in normalized: + return None, "Invalid hostname." + + domain = normalized[2:] if normalized.startswith("*.") else normalized + labels = domain.split(".") + if len(labels) < 2: + return None, "Invalid hostname." + + for label in labels: + if not label: + return None, "Invalid hostname." + if not re.match(r"^[a-z0-9-]+$", label): + return None, "Invalid hostname." + if label.startswith("-") or label.endswith("-"): + return None, "Invalid hostname." + + return normalized, None + @csrf_exempt @api_doc( summary="Create / Update / Delete a WireGuard peer (and optionally reload the interface)", @@ -446,6 +475,165 @@ def api_v2_manage_peer(request): ) +@csrf_exempt +@api_doc( + summary="Create / Upsert / Delete a static DNS record identified by hostname", + auth="Header token: ", + methods=["POST", "PUT", "DELETE"], + params=[ + {"name": "hostname", "in": "json", "type": "string", "required": True, + "description": "DNS hostname to manage (supports wildcard like *.example.com)."}, + {"name": "ip_address", "in": "json", "type": "string", "required": False, + "description": "IPv4 address for the hostname record (required for POST/PUT, ignored for DELETE)."}, + {"name": "skip_reload", "in": "json", "type": "boolean", "required": False, "example": True, + "description": "If true, does not apply DNS changes immediately and only sets dns_settings.pending_changes=True."}, + ], + returns=[ + {"status": 200, "body": {"status": "success", "message": "DNS record updated successfully.", "hostname": "example.com", "ip_address": "10.0.0.50", "apply": {"success": True, "message": "..."}}}, + {"status": 200, "body": {"status": "success", "message": "DNS record deleted successfully.", "hostname": "example.com", "apply": {"success": True, "message": "..."}}}, + {"status": 201, "body": {"status": "success", "message": "DNS record created successfully.", "hostname": "example.com", "ip_address": "10.0.0.50", "apply": {"success": True, "message": "..."}}}, + {"status": 400, "body": {"status": "error", "error_message": "Invalid hostname."}}, + {"status": 403, "body": {"status": "error", "error_message": "Invalid API key."}}, + {"status": 404, "body": {"status": "error", "error_message": "DNS record not found for the provided hostname."}}, + {"status": 500, "body": {"status": "error", "error_message": "DNS changes were saved but apply failed: ..."}}, + {"status": 405, "body": {"status": "error", "error_message": "Method not allowed."}}, + ], + examples={ + "create_skip_reload": { + "method": "POST", + "json": { + "hostname": "app.example.com", + "ip_address": "10.20.30.40", + "skip_reload": True + } + }, + "put_upsert_apply": { + "method": "PUT", + "json": { + "hostname": "app.example.com", + "ip_address": "10.20.30.41", + "skip_reload": False + } + }, + "delete_skip_reload": { + "method": "DELETE", + "json": { + "hostname": "app.example.com", + "skip_reload": True + } + } + } +) +def api_v2_manage_dns_record(request): + if request.method not in ("POST", "PUT", "DELETE"): + return JsonResponse({"status": "error", "error_message": "Method not allowed."}, status=405) + + try: + payload = json.loads(request.body.decode("utf-8")) if request.body else {} + except Exception: + return JsonResponse({"status": "error", "error_message": "Invalid JSON body."}, status=400) + + api_key, api_error = validate_api_key(request) + if not api_key: + return JsonResponse({"status": "error", "error_message": api_error}, status=403) + + normalized_hostname, hostname_error = _validate_dns_hostname(payload.get("hostname")) + if hostname_error: + return JsonResponse({"status": "error", "error_message": hostname_error}, status=400) + + skip_reload = bool(payload.get("skip_reload", False)) + normalized_ip = None + if request.method in ("POST", "PUT"): + raw_ip = payload.get("ip_address") + if not isinstance(raw_ip, str) or not raw_ip.strip(): + return JsonResponse({"status": "error", "error_message": "Invalid ip_address."}, status=400) + + try: + ip = ipaddress.ip_address(raw_ip.strip()) + except Exception: + return JsonResponse({"status": "error", "error_message": "Invalid ip_address."}, status=400) + + if ip.version != 4: + return JsonResponse({"status": "error", "error_message": "Only IPv4 ip_address is supported."}, status=400) + normalized_ip = str(ip) + + export_error = None + + def _export_after_commit(): + nonlocal export_error + try: + export_dns_configuration() + except Exception as exc: + export_error = str(exc) + + with transaction.atomic(): + dns_settings, dns_settings_created = DNSSettings.objects.select_for_update().get_or_create(name="dns_settings") + + if request.method == "POST": + if StaticHost.objects.filter(hostname=normalized_hostname).exists(): + return JsonResponse( + {"status": "error", "error_message": "DNS record already exists for the provided hostname."}, + status=400, + ) + record = StaticHost.objects.create(hostname=normalized_hostname, ip_address=normalized_ip) + action_message = "DNS record created successfully." + status_code = 201 + elif request.method == "PUT": + record = StaticHost.objects.filter(hostname=normalized_hostname).first() + if record: + record.ip_address = normalized_ip + record.save(update_fields=["ip_address", "updated"]) + action_message = "DNS record updated successfully." + status_code = 200 + else: + record = StaticHost.objects.create(hostname=normalized_hostname, ip_address=normalized_ip) + action_message = "DNS record created successfully." + status_code = 201 + else: + record = StaticHost.objects.filter(hostname=normalized_hostname).first() + if not record: + return JsonResponse( + {"status": "error", "error_message": "DNS record not found for the provided hostname."}, + status=404, + ) + record.delete() + action_message = "DNS record deleted successfully." + status_code = 200 + + if skip_reload: + dns_settings.pending_changes = True + dns_settings.save(update_fields=["pending_changes", "updated"]) + apply_success = True + apply_message = "Changes saved. Apply skipped (pending_changes set to True)." + else: + # Mark pending inside transaction; actual export runs only after commit. + dns_settings.pending_changes = True + dns_settings.save(update_fields=["pending_changes", "updated"]) + transaction.on_commit(_export_after_commit) + apply_success = True + apply_message = "DNS configuration applied successfully." + + if not skip_reload and export_error: + return JsonResponse( + { + "status": "error", + "error_message": f"DNS changes were saved but apply failed: {export_error}", + }, + status=500, + ) + + response_data = { + "status": "success", + "message": action_message, + "hostname": normalized_hostname, + "apply": {"success": apply_success, "message": apply_message}, + } + if request.method in ("POST", "PUT"): + response_data["ip_address"] = normalized_ip + + return JsonResponse(response_data, status=status_code) + + @csrf_exempt @api_doc( summary="List peers for a specific instance (required)", diff --git a/dns/views.py b/dns/views.py index 3932381..ad97645 100644 --- a/dns/views.py +++ b/dns/views.py @@ -18,13 +18,13 @@ from .models import StaticHost def export_dns_configuration(): - dns_settings, _ = DNSSettings.objects.get_or_create(name='dns_settings') - dns_settings.pending_changes = False - dns_settings.save() dnsmasq_config = generate_dnsmasq_config() with open(settings.DNS_CONFIG_FILE, 'w') as f: f.write(dnsmasq_config) compress_dnsmasq_config() + dns_settings, dns_settings_created = DNSSettings.objects.get_or_create(name='dns_settings') + dns_settings.pending_changes = False + dns_settings.save(update_fields=['pending_changes', 'updated']) return @@ -280,4 +280,4 @@ def view_toggle_dns_list(request): dns_list.save() export_dns_configuration() messages.success(request, _('DNS Filter List disabled successfully')) - return redirect('/dns/') \ No newline at end of file + return redirect('/dns/')