diff --git a/README.md b/README.md index 9bc044f..c7eda8d 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ The configuration portal supports using a database (SQLite, MySQL, MsSQL, or Pos * Docker ready * Can be used with existing WireGuard setups * Support for multiple WireGuard interfaces -* Supports multiple WireGuard backends (wgctrl or MikroTik) +* Supports multiple WireGuard backends (wgctrl, MikroTik, or pfSense) * Peer Expiry Feature * Handles route and DNS settings like wg-quick does * Exposes Prometheus metrics for monitoring and alerting diff --git a/config.yml.sample b/config.yml.sample index 039d034..638cf9b 100644 --- a/config.yml.sample +++ b/config.yml.sample @@ -93,4 +93,16 @@ auth: admin_value_regex: ^true$ admin_group_regex: ^admin-group-name$ registration_enabled: true - log_user_info: true \ No newline at end of file + log_user_info: true + +backend: + default: local + pfsense: + - id: pfsense1 + display_name: "Main pfSense Firewall" + api_url: "https://pfsense.example.com" # Base URL without /api/v2 (endpoints already include it) + api_key: "your-api-key" # Generate in pfSense under 'System' -> 'REST API' -> 'Keys' + api_verify_tls: true + api_timeout: 30s + concurrency: 5 + debug: false \ No newline at end of file diff --git a/docs/documentation/usage/backends.md b/docs/documentation/usage/backends.md index e891d95..aeac9d6 100644 --- a/docs/documentation/usage/backends.md +++ b/docs/documentation/usage/backends.md @@ -8,6 +8,7 @@ A global default backend determines where newly created interfaces go (unless yo **Supported backends:** - **Local** (default): Manages interfaces on the host running WireGuard Portal (Linux WireGuard via wgctrl). Use this when the portal should directly configure wg devices on the same server. - **MikroTik** RouterOS (_beta_): Manages interfaces and peers on MikroTik devices via the RouterOS REST API. Use this to control WG interfaces on RouterOS v7+. +- **pfSense** (_alpha_): Manages interfaces and peers on pfSense firewalls via the pfSense REST API. How backend selection works: - The default backend is configured at `backend.default` (_local_ or the id of a defined MikroTik backend). @@ -54,4 +55,37 @@ backend: ### Known limitations: - The MikroTik backend is still in beta. Some features may not work as expected. -- Not all WireGuard Portal features are supported yet (e.g., no support for interface hooks) \ No newline at end of file +- Not all WireGuard Portal features are supported yet (e.g., no support for interface hooks) + +## Configuring pfSense backends + +> :warning: The pfSense backend is currently **alpha**. Only basic interface and peer CRUD are supported. Traffic statistics (rx/tx, last handshake) are not exposed by the pfSense REST API and will show as empty. + +The pfSense backend talks to the pfSense REST API (pfSense Plus / CE with the REST API package installed). Point the backend at the appliance hostname without appending `/api/v2` — the portal appends `/api/v2` automatically. + +### Prerequisites on pfSense: +- pfSense with the REST API package enabled (`System -> API`) and WireGuard configured. +- An API key with permissions for WireGuard endpoints. If you use a read-only key, set `core.restore_state: false` in `config.yml` to avoid write attempts at startup. +- HTTPS recommended; set `api_verify_tls: false` only for lab/self-signed setups. + +Example WireGuard Portal configuration: + +```yaml +backend: + # default backend decides where new interfaces are created + default: pfsense1 + + pfsense: + - id: pfsense1 # unique id, not "local" + display_name: Main pfSense # optional nice name + api_url: https://pfsense.example.com # no trailing /api/v2 + api_key: your-api-key + api_verify_tls: true + api_timeout: 30s + concurrency: 5 + debug: false +``` + +### Known limitations: +- Alpha quality: behavior and API coverage may change. +- Statistics (rx/tx bytes, last handshake) are not available from the pfSense REST API today. diff --git a/internal/adapters/wgcontroller/pfsense.go b/internal/adapters/wgcontroller/pfsense.go new file mode 100644 index 0000000..89a0a0f --- /dev/null +++ b/internal/adapters/wgcontroller/pfsense.go @@ -0,0 +1,979 @@ +package wgcontroller + +import ( + "context" + "fmt" + "log/slog" + "strconv" + "strings" + "sync" + "time" + + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" + "github.com/h44z/wg-portal/internal/lowlevel" +) + +// PfsenseController implements the InterfaceController interface for pfSense firewalls. +// It uses the pfSense REST API (https://pfrest.org/) to manage WireGuard interfaces and peers. +// API endpoint paths and field names should be verified against the Swagger documentation: +// https://pfrest.org/api-docs/ + +type PfsenseController struct { + coreCfg *config.Config + cfg *config.BackendPfsense + + client *lowlevel.PfsenseApiClient + + // Add mutexes to prevent race conditions + interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex + peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex + coreMutex sync.Mutex // for updating the core configuration such as routing table or DNS settings +} + +func NewPfsenseController(coreCfg *config.Config, cfg *config.BackendPfsense) (*PfsenseController, error) { + client, err := lowlevel.NewPfsenseApiClient(coreCfg, cfg) + if err != nil { + return nil, fmt.Errorf("failed to create pfSense API client: %w", err) + } + + return &PfsenseController{ + coreCfg: coreCfg, + cfg: cfg, + + client: client, + + interfaceMutexes: sync.Map{}, + peerMutexes: sync.Map{}, + coreMutex: sync.Mutex{}, + }, nil +} + +func (c *PfsenseController) GetId() domain.InterfaceBackend { + return domain.InterfaceBackend(c.cfg.Id) +} + +// getInterfaceMutex returns a mutex for the given interface to prevent concurrent modifications +func (c *PfsenseController) getInterfaceMutex(id domain.InterfaceIdentifier) *sync.Mutex { + mutex, _ := c.interfaceMutexes.LoadOrStore(id, &sync.Mutex{}) + return mutex.(*sync.Mutex) +} + +// getPeerMutex returns a mutex for the given peer to prevent concurrent modifications +func (c *PfsenseController) getPeerMutex(id domain.PeerIdentifier) *sync.Mutex { + mutex, _ := c.peerMutexes.LoadOrStore(id, &sync.Mutex{}) + return mutex.(*sync.Mutex) +} + +// region wireguard-related + +func (c *PfsenseController) GetInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) { + // Query WireGuard tunnels from pfSense API + // Using pfSense REST API v2 endpoints: GET /api/v2/vpn/wireguard/tunnels + // Field names should be verified against Swagger docs: https://pfrest.org/api-docs/ + wgReply := c.client.Query(ctx, "/api/v2/vpn/wireguard/tunnels", &lowlevel.PfsenseRequestOptions{}) + if wgReply.Status != lowlevel.PfsenseApiStatusOk { + return nil, fmt.Errorf("failed to query interfaces: %v", wgReply.Error) + } + + // Parallelize loading of interface details to speed up overall latency. + // Use a bounded semaphore to avoid overloading the pfSense device. + maxConcurrent := c.cfg.GetConcurrency() + sem := make(chan struct{}, maxConcurrent) + + interfaces := make([]domain.PhysicalInterface, 0, len(wgReply.Data)) + var mu sync.Mutex + var wgWait sync.WaitGroup + var firstErr error + ctx2, cancel := context.WithCancel(ctx) + defer cancel() + + for _, wgObj := range wgReply.Data { + wgWait.Add(1) + sem <- struct{}{} // block if more than maxConcurrent requests are processing + go func(wg lowlevel.GenericJsonObject) { + defer wgWait.Done() + defer func() { <-sem }() // read from the semaphore and make space for the next entry + if firstErr != nil { + return + } + pi, err := c.loadInterfaceData(ctx2, wg) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + cancel() + } + mu.Unlock() + return + } + mu.Lock() + interfaces = append(interfaces, *pi) + mu.Unlock() + }(wgObj) + } + + wgWait.Wait() + if firstErr != nil { + return nil, firstErr + } + + return interfaces, nil +} + +func (c *PfsenseController) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) ( + *domain.PhysicalInterface, + error, +) { + // First, get the tunnel ID by querying by name + wgReply := c.client.Query(ctx, "/api/v2/vpn/wireguard/tunnels", &lowlevel.PfsenseRequestOptions{ + Filters: map[string]string{ + "name": string(id), + }, + }) + if wgReply.Status != lowlevel.PfsenseApiStatusOk { + return nil, fmt.Errorf("failed to query interface %s: %v", id, wgReply.Error) + } + + if len(wgReply.Data) == 0 { + return nil, fmt.Errorf("interface %s not found", id) + } + + tunnelId := wgReply.Data[0].GetString("id") + + // Query the specific tunnel endpoint to get full details including addresses + // Endpoint: GET /api/v2/vpn/wireguard/tunnel?id={id} + if tunnelId != "" { + tunnelReply := c.client.Get(ctx, "/api/v2/vpn/wireguard/tunnel", &lowlevel.PfsenseRequestOptions{ + Filters: map[string]string{ + "id": tunnelId, + }, + }) + if tunnelReply.Status == lowlevel.PfsenseApiStatusOk && tunnelReply.Data != nil { + // Use the detailed tunnel response which includes addresses + return c.loadInterfaceData(ctx, tunnelReply.Data) + } + // Fall back to list response if detail query fails + if c.cfg.Debug { + slog.Debug("failed to query detailed tunnel info, using list response", "interface", id, "tunnel_id", tunnelId) + } + } + + return c.loadInterfaceData(ctx, wgReply.Data[0]) +} + +func (c *PfsenseController) loadInterfaceData( + ctx context.Context, + wireGuardObj lowlevel.GenericJsonObject, +) (*domain.PhysicalInterface, error) { + deviceName := wireGuardObj.GetString("name") + deviceId := wireGuardObj.GetString("id") + + // Extract addresses from the tunnel data + // The tunnel response may include an "addresses" array when queried via /tunnel?id={id} + addresses := c.extractAddresses(wireGuardObj, nil) + + // If addresses weren't found in the tunnel object and we have a tunnel ID, + // query the specific tunnel endpoint to get full details including addresses + // Endpoint: GET /api/v2/vpn/wireguard/tunnel?id={id} + if len(addresses) == 0 && deviceId != "" { + tunnelReply := c.client.Get(ctx, "/api/v2/vpn/wireguard/tunnel", &lowlevel.PfsenseRequestOptions{ + Filters: map[string]string{ + "id": deviceId, + }, + }) + if tunnelReply.Status == lowlevel.PfsenseApiStatusOk && tunnelReply.Data != nil { + // Extract addresses from the detailed tunnel response + parsedAddrs := c.extractAddresses(tunnelReply.Data, nil) + if len(parsedAddrs) > 0 { + addresses = parsedAddrs + if c.cfg.Debug { + slog.Debug("loaded addresses from detailed tunnel query", "interface", deviceName, "count", len(addresses)) + } + } + } + } + + interfaceModel, err := c.convertWireGuardInterface(wireGuardObj, nil, addresses) + if err != nil { + return nil, fmt.Errorf("interface convert failed for %s: %w", deviceName, err) + } + return &interfaceModel, nil +} + +func (c *PfsenseController) extractAddresses( + wgObj lowlevel.GenericJsonObject, + ifaceObj lowlevel.GenericJsonObject, +) []domain.Cidr { + addresses := make([]domain.Cidr, 0) + + // Try to get addresses from ifaceObj first + if ifaceObj != nil { + addrStr := ifaceObj.GetString("addresses") + if addrStr != "" { + // Addresses might be comma-separated or in an array + addrs, _ := domain.CidrsFromString(addrStr) + addresses = append(addresses, addrs...) + } + } + + // Try to get addresses from wgObj - check if it's an array first + if len(addresses) == 0 { + if addressesValue, ok := wgObj["addresses"]; ok && addressesValue != nil { + if addressesArray, ok := addressesValue.([]any); ok { + // Parse addresses array (from /tunnel?id={id} response) + // Each object has "address" and "mask" fields + for _, addrItem := range addressesArray { + if addrObj, ok := addrItem.(map[string]any); ok { + address := "" + mask := 0 + + // Extract address + if addrVal, ok := addrObj["address"]; ok { + if addrStr, ok := addrVal.(string); ok { + address = addrStr + } else { + address = fmt.Sprintf("%v", addrVal) + } + } + + // Extract mask + if maskVal, ok := addrObj["mask"]; ok { + if maskInt, ok := maskVal.(int); ok { + mask = maskInt + } else if maskFloat, ok := maskVal.(float64); ok { + mask = int(maskFloat) + } else if maskStr, ok := maskVal.(string); ok { + if maskInt, err := strconv.Atoi(maskStr); err == nil { + mask = maskInt + } + } + } + + // Convert to CIDR format + if address != "" && mask > 0 { + cidrStr := fmt.Sprintf("%s/%d", address, mask) + if cidr, err := domain.CidrFromString(cidrStr); err == nil { + addresses = append(addresses, cidr) + } + } else if address != "" { + // Try parsing as CIDR string directly + if cidr, err := domain.CidrFromString(address); err == nil { + addresses = append(addresses, cidr) + } + } + } + } + } else if addrStr, ok := addressesValue.(string); ok { + // Fallback: try parsing as comma-separated string + addrs, _ := domain.CidrsFromString(addrStr) + addresses = append(addresses, addrs...) + } + } else { + // Try as string field + addrStr := wgObj.GetString("addresses") + if addrStr != "" { + addrs, _ := domain.CidrsFromString(addrStr) + addresses = append(addresses, addrs...) + } + } + } + + return addresses +} + +// parseAddressArray parses an array of address objects from the pfSense API +// Each object has "address" and "mask" fields (similar to allowedips structure) +func (c *PfsenseController) parseAddressArray(addressArray []lowlevel.GenericJsonObject) []domain.Cidr { + addresses := make([]domain.Cidr, 0, len(addressArray)) + + for _, addrObj := range addressArray { + address := addrObj.GetString("address") + mask := addrObj.GetInt("mask") + + if address != "" && mask > 0 { + cidrStr := fmt.Sprintf("%s/%d", address, mask) + if cidr, err := domain.CidrFromString(cidrStr); err == nil { + addresses = append(addresses, cidr) + } + } else if address != "" { + // Try parsing as CIDR string directly + if cidr, err := domain.CidrFromString(address); err == nil { + addresses = append(addresses, cidr) + } + } + } + + return addresses +} + +func (c *PfsenseController) convertWireGuardInterface( + wg, iface lowlevel.GenericJsonObject, + addresses []domain.Cidr, +) ( + domain.PhysicalInterface, + error, +) { + // Map pfSense field names to our domain model + // Field names should be verified against the Swagger UI: https://pfrest.org/api-docs/ + // The implementation attempts to handle both camelCase and kebab-case variations + privateKey := wg.GetString("privatekey") + if privateKey == "" { + privateKey = wg.GetString("private-key") + } + publicKey := wg.GetString("publickey") + if publicKey == "" { + publicKey = wg.GetString("public-key") + } + + listenPort := wg.GetInt("listenport") + if listenPort == 0 { + listenPort = wg.GetInt("listen-port") + } + + mtu := wg.GetInt("mtu") + running := wg.GetBool("running") + disabled := wg.GetBool("disabled") + + // TODO: Interface statistics (rx/tx bytes) are not currently supported + // by the pfSense REST API. This functionality is reserved for future implementation. + var rxBytes, txBytes uint64 + + pi := domain.PhysicalInterface{ + Identifier: domain.InterfaceIdentifier(wg.GetString("name")), + KeyPair: domain.KeyPair{ + PrivateKey: privateKey, + PublicKey: publicKey, + }, + ListenPort: listenPort, + Addresses: addresses, + Mtu: mtu, + FirewallMark: 0, + DeviceUp: running && !disabled, + ImportSource: domain.ControllerTypePfsense, + DeviceType: domain.ControllerTypePfsense, + BytesUpload: txBytes, + BytesDownload: rxBytes, + } + + // Extract description - pfSense API uses "descr" field + description := wg.GetString("descr") + if description == "" { + description = wg.GetString("description") + } + if description == "" { + description = wg.GetString("comment") + } + + pi.SetExtras(domain.PfsenseInterfaceExtras{ + Id: wg.GetString("id"), + Comment: description, + Disabled: disabled, + }) + + return pi, nil +} + +func (c *PfsenseController) GetPeers(ctx context.Context, deviceId domain.InterfaceIdentifier) ( + []domain.PhysicalPeer, + error, +) { + // Query all peers and filter by interface client-side + // Using pfSense REST API v2 endpoints (https://pfrest.org/) + // The API uses query parameters like ?id=0 for specific items, but we need to filter + // by interface (tun field), so we fetch all peers and filter client-side + wgReply := c.client.Query(ctx, "/api/v2/vpn/wireguard/peers", &lowlevel.PfsenseRequestOptions{}) + if wgReply.Status != lowlevel.PfsenseApiStatusOk { + return nil, fmt.Errorf("failed to query peers for %s: %v", deviceId, wgReply.Error) + } + + if len(wgReply.Data) == 0 { + return nil, nil + } + + // Filter peers client-side by checking the "tun" field in each peer + // pfSense peer responses use "tun" field to indicate which tunnel/interface the peer belongs to + peers := make([]domain.PhysicalPeer, 0, len(wgReply.Data)) + for _, peer := range wgReply.Data { + // Check if this peer belongs to the requested interface + // pfSense uses "tun" field with the interface name (e.g., "tun_wg0") + peerTun := peer.GetString("tun") + if peerTun == "" { + // Try alternative field names as fallback + peerTun = peer.GetString("interface") + if peerTun == "" { + peerTun = peer.GetString("tunnel") + } + } + + // Only include peers that match the requested interface name + if peerTun != string(deviceId) { + if c.cfg.Debug { + slog.Debug("skipping peer - interface mismatch", + "peer", peer.GetString("name"), + "peer_tun", peerTun, + "requested_interface", deviceId, + "peer_id", peer.GetString("id")) + } + continue + } + + // Use peer data directly from the list response + peerModel, err := c.convertWireGuardPeer(peer) + if err != nil { + return nil, fmt.Errorf("peer convert failed for %v: %w", peer.GetString("name"), err) + } + peers = append(peers, peerModel) + } + + if c.cfg.Debug { + slog.Debug("filtered peers for interface", + "interface", deviceId, + "total_peers_from_api", len(wgReply.Data), + "filtered_peers", len(peers)) + } + + return peers, nil +} + +func (c *PfsenseController) convertWireGuardPeer(peer lowlevel.GenericJsonObject) ( + domain.PhysicalPeer, + error, +) { + publicKey := peer.GetString("publickey") + if publicKey == "" { + publicKey = peer.GetString("public-key") + } + + privateKey := peer.GetString("privatekey") + if privateKey == "" { + privateKey = peer.GetString("private-key") + } + + presharedKey := peer.GetString("presharedkey") + if presharedKey == "" { + presharedKey = peer.GetString("preshared-key") + } + + // pfSense returns allowedips as an array of objects with "address" and "mask" fields + // Example: [{"address": "10.1.2.3", "mask": 32, ...}, ...] + var allowedAddresses []domain.Cidr + if allowedIPsValue, ok := peer["allowedips"]; ok { + if allowedIPsArray, ok := allowedIPsValue.([]any); ok { + // Parse array of objects + for _, item := range allowedIPsArray { + if itemObj, ok := item.(map[string]any); ok { + address := "" + mask := 0 + + // Extract address + if addrVal, ok := itemObj["address"]; ok { + if addrStr, ok := addrVal.(string); ok { + address = addrStr + } else { + address = fmt.Sprintf("%v", addrVal) + } + } + + // Extract mask + if maskVal, ok := itemObj["mask"]; ok { + if maskInt, ok := maskVal.(int); ok { + mask = maskInt + } else if maskFloat, ok := maskVal.(float64); ok { + mask = int(maskFloat) + } else if maskStr, ok := maskVal.(string); ok { + if maskInt, err := strconv.Atoi(maskStr); err == nil { + mask = maskInt + } + } + } + + // Convert to CIDR format (e.g., "10.1.2.3/32") + if address != "" && mask > 0 { + cidrStr := fmt.Sprintf("%s/%d", address, mask) + if cidr, err := domain.CidrFromString(cidrStr); err == nil { + allowedAddresses = append(allowedAddresses, cidr) + } + } + } + } + } else if allowedIPsStr, ok := allowedIPsValue.(string); ok { + // Fallback: try parsing as comma-separated string + allowedAddresses, _ = domain.CidrsFromString(allowedIPsStr) + } + } + + // Fallback to string parsing if array parsing didn't work + if len(allowedAddresses) == 0 { + allowedIPsStr := peer.GetString("allowedips") + if allowedIPsStr == "" { + allowedIPsStr = peer.GetString("allowed-ips") + } + if allowedIPsStr != "" { + allowedAddresses, _ = domain.CidrsFromString(allowedIPsStr) + } + } + + endpoint := peer.GetString("endpoint") + port := peer.GetString("port") + + // Combine endpoint and port if both are available + if endpoint != "" && port != "" { + // Check if endpoint already contains a port + if !strings.Contains(endpoint, ":") { + endpoint = fmt.Sprintf("%s:%s", endpoint, port) + } + } else if endpoint == "" && port != "" { + // If only port is available, we can't construct a full endpoint + // This might be used with the interface's listenport + } + + keepAliveSeconds := 0 + keepAliveStr := peer.GetString("persistentkeepalive") + if keepAliveStr == "" { + keepAliveStr = peer.GetString("persistent-keepalive") + } + if keepAliveStr != "" { + duration, err := time.ParseDuration(keepAliveStr) + if err == nil { + keepAliveSeconds = int(duration.Seconds()) + } else { + // Try parsing as integer (seconds) + if secs, err := strconv.Atoi(keepAliveStr); err == nil { + keepAliveSeconds = secs + } + } + } + + // TODO: Peer statistics (last handshake, rx/tx bytes) are not currently supported + // by the pfSense REST API. This functionality is reserved for future implementation + // when the API adds support for these fields. + // See: https://github.com/jaredhendrickson13/pfsense-api/issues (issue opened by user) + // + // When supported, extract fields like: + // - lastHandshake: peer.GetString("lasthandshake") or peer.GetString("last-handshake") + // - rxBytes: peer.GetInt("rxbytes") or peer.GetInt("rx-bytes") + // - txBytes: peer.GetInt("txbytes") or peer.GetInt("tx-bytes") + lastHandshakeTime := time.Time{} + rxBytes := uint64(0) + txBytes := uint64(0) + + peerModel := domain.PhysicalPeer{ + Identifier: domain.PeerIdentifier(publicKey), + Endpoint: endpoint, + AllowedIPs: allowedAddresses, + KeyPair: domain.KeyPair{ + PublicKey: publicKey, + PrivateKey: privateKey, + }, + PresharedKey: domain.PreSharedKey(presharedKey), + PersistentKeepalive: keepAliveSeconds, + LastHandshake: lastHandshakeTime, + ProtocolVersion: 0, // pfSense may not expose protocol version + BytesUpload: txBytes, + BytesDownload: rxBytes, + ImportSource: domain.ControllerTypePfsense, + } + + // Extract description/name - pfSense API uses "descr" field + description := peer.GetString("descr") + if description == "" { + description = peer.GetString("description") + } + if description == "" { + description = peer.GetString("comment") + } + + // Extract name - pfSense API may use "name" or "descr" + name := peer.GetString("name") + if name == "" { + name = peer.GetString("descr") + } + if name == "" { + name = description // fallback to description if name is not available + } + + peerModel.SetExtras(domain.PfsensePeerExtras{ + Id: peer.GetString("id"), + Name: name, + Comment: description, + Disabled: peer.GetBool("disabled"), + ClientEndpoint: "", // pfSense may handle this differently + ClientAddress: "", // pfSense may handle this differently + ClientDns: "", // pfSense may handle this differently + ClientKeepalive: 0, // pfSense may handle this differently + }) + + return peerModel, nil +} + +func (c *PfsenseController) SaveInterface( + ctx context.Context, + id domain.InterfaceIdentifier, + updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), +) error { + // Lock the interface to prevent concurrent modifications + mutex := c.getInterfaceMutex(id) + mutex.Lock() + defer mutex.Unlock() + + physicalInterface, err := c.getOrCreateInterface(ctx, id) + if err != nil { + return err + } + + deviceId := "" + if physicalInterface.GetExtras() != nil { + if extras, ok := physicalInterface.GetExtras().(domain.PfsenseInterfaceExtras); ok { + deviceId = extras.Id + } + } + + if updateFunc != nil { + physicalInterface, err = updateFunc(physicalInterface) + if err != nil { + return err + } + if deviceId != "" { + // Ensure the ID is preserved + if extras, ok := physicalInterface.GetExtras().(domain.PfsenseInterfaceExtras); ok { + extras.Id = deviceId + physicalInterface.SetExtras(extras) + } + } + } + + if err := c.updateInterface(ctx, physicalInterface); err != nil { + return err + } + + return nil +} + +func (c *PfsenseController) getOrCreateInterface( + ctx context.Context, + id domain.InterfaceIdentifier, +) (*domain.PhysicalInterface, error) { + wgReply := c.client.Query(ctx, "/api/v2/vpn/wireguard/tunnels", &lowlevel.PfsenseRequestOptions{ + Filters: map[string]string{ + "name": string(id), + }, + }) + if wgReply.Status == lowlevel.PfsenseApiStatusOk && len(wgReply.Data) > 0 { + return c.loadInterfaceData(ctx, wgReply.Data[0]) + } + + // create a new tunnel if it does not exist + // Actual endpoint: POST /api/v2/vpn/wireguard/tunnel (singular) + createReply := c.client.Create(ctx, "/api/v2/vpn/wireguard/tunnel", lowlevel.GenericJsonObject{ + "name": string(id), + }) + if createReply.Status == lowlevel.PfsenseApiStatusOk { + return c.loadInterfaceData(ctx, createReply.Data) + } + + return nil, fmt.Errorf("failed to create interface %s: %v", id, createReply.Error) +} + +func (c *PfsenseController) updateInterface(ctx context.Context, pi *domain.PhysicalInterface) error { + extras := pi.GetExtras().(domain.PfsenseInterfaceExtras) + interfaceId := extras.Id + + payload := lowlevel.GenericJsonObject{ + "name": string(pi.Identifier), + "description": extras.Comment, + "mtu": strconv.Itoa(pi.Mtu), + "listenport": strconv.Itoa(pi.ListenPort), + "privatekey": pi.KeyPair.PrivateKey, + "disabled": strconv.FormatBool(!pi.DeviceUp), + } + + // Add addresses if present + if len(pi.Addresses) > 0 { + addresses := make([]string, 0, len(pi.Addresses)) + for _, addr := range pi.Addresses { + addresses = append(addresses, addr.String()) + } + payload["addresses"] = strings.Join(addresses, ",") + } + + // Actual endpoint: PATCH /api/v2/vpn/wireguard/tunnel?id={id} + wgReply := c.client.Update(ctx, "/api/v2/vpn/wireguard/tunnel?id="+interfaceId, payload) + if wgReply.Status != lowlevel.PfsenseApiStatusOk { + return fmt.Errorf("failed to update interface %s: %v", pi.Identifier, wgReply.Error) + } + + return nil +} + +func (c *PfsenseController) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error { + // Lock the interface to prevent concurrent modifications + mutex := c.getInterfaceMutex(id) + mutex.Lock() + defer mutex.Unlock() + + // Find the tunnel ID + wgReply := c.client.Query(ctx, "/api/v2/vpn/wireguard/tunnels", &lowlevel.PfsenseRequestOptions{ + Filters: map[string]string{ + "name": string(id), + }, + }) + if wgReply.Status != lowlevel.PfsenseApiStatusOk { + return fmt.Errorf("unable to find WireGuard tunnel %s: %v", id, wgReply.Error) + } + if len(wgReply.Data) == 0 { + return nil // tunnel does not exist, nothing to delete + } + + interfaceId := wgReply.Data[0].GetString("id") + // Actual endpoint: DELETE /api/v2/vpn/wireguard/tunnel?id={id} + deleteReply := c.client.Delete(ctx, "/api/v2/vpn/wireguard/tunnel?id="+interfaceId) + if deleteReply.Status != lowlevel.PfsenseApiStatusOk { + return fmt.Errorf("failed to delete WireGuard interface %s: %v", id, deleteReply.Error) + } + + return nil +} + +func (c *PfsenseController) SavePeer( + ctx context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, + updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), +) error { + // Lock the peer to prevent concurrent modifications + mutex := c.getPeerMutex(id) + mutex.Lock() + defer mutex.Unlock() + + physicalPeer, err := c.getOrCreatePeer(ctx, deviceId, id) + if err != nil { + return err + } + + peerId := "" + if physicalPeer.GetExtras() != nil { + if extras, ok := physicalPeer.GetExtras().(domain.PfsensePeerExtras); ok { + peerId = extras.Id + } + } + + physicalPeer, err = updateFunc(physicalPeer) + if err != nil { + return err + } + if peerId != "" { + // Ensure the ID is preserved + if extras, ok := physicalPeer.GetExtras().(domain.PfsensePeerExtras); ok { + extras.Id = peerId + physicalPeer.SetExtras(extras) + } + } + + if err := c.updatePeer(ctx, deviceId, physicalPeer); err != nil { + return err + } + + return nil +} + +func (c *PfsenseController) getOrCreatePeer( + ctx context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, +) (*domain.PhysicalPeer, error) { + // Query for peer by publickey and interface (tun field) + // The API uses query parameters like ?publickey=...&tun=... + wgReply := c.client.Query(ctx, "/api/v2/vpn/wireguard/peers", &lowlevel.PfsenseRequestOptions{ + Filters: map[string]string{ + "publickey": string(id), + "tun": string(deviceId), // Use "tun" field name as that's what the API uses + }, + }) + if wgReply.Status == lowlevel.PfsenseApiStatusOk && len(wgReply.Data) > 0 { + slog.Debug("found existing pfSense peer", "peer", id, "interface", deviceId) + existingPeer, err := c.convertWireGuardPeer(wgReply.Data[0]) + if err != nil { + return nil, err + } + return &existingPeer, nil + } + + // create a new peer if it does not exist + // Actual endpoint: POST /api/v2/vpn/wireguard/peer (singular) + slog.Debug("creating new pfSense peer", "peer", id, "interface", deviceId) + createReply := c.client.Create(ctx, "/api/v2/vpn/wireguard/peer", lowlevel.GenericJsonObject{ + "name": fmt.Sprintf("wg-%s", id[0:8]), + "interface": string(deviceId), + "publickey": string(id), + "allowedips": "0.0.0.0/0", // Use 0.0.0.0/0 as default, will be updated by updatePeer + }) + if createReply.Status == lowlevel.PfsenseApiStatusOk { + newPeer, err := c.convertWireGuardPeer(createReply.Data) + if err != nil { + return nil, err + } + slog.Debug("successfully created pfSense peer", "peer", id, "interface", deviceId) + return &newPeer, nil + } + + return nil, fmt.Errorf("failed to create peer %s for interface %s: %v", id, deviceId, createReply.Error) +} + +func (c *PfsenseController) updatePeer( + ctx context.Context, + deviceId domain.InterfaceIdentifier, + pp *domain.PhysicalPeer, +) error { + extras := pp.GetExtras().(domain.PfsensePeerExtras) + peerId := extras.Id + + allowedIPsStr := domain.CidrsToString(pp.AllowedIPs) + + slog.Debug("updating pfSense peer", + "peer", pp.Identifier, + "interface", deviceId, + "allowed-ips", allowedIPsStr, + "allowed-ips-count", len(pp.AllowedIPs), + "disabled", extras.Disabled) + + payload := lowlevel.GenericJsonObject{ + "name": extras.Name, + "description": extras.Comment, + "presharedkey": string(pp.PresharedKey), + "publickey": pp.KeyPair.PublicKey, + "privatekey": pp.KeyPair.PrivateKey, + "persistentkeepalive": strconv.Itoa(pp.PersistentKeepalive), + "disabled": strconv.FormatBool(extras.Disabled), + "allowedips": allowedIPsStr, + } + + if pp.Endpoint != "" { + payload["endpoint"] = pp.Endpoint + } + + // Actual endpoint: PATCH /api/v2/vpn/wireguard/peer?id={id} + wgReply := c.client.Update(ctx, "/api/v2/vpn/wireguard/peer?id="+peerId, payload) + if wgReply.Status != lowlevel.PfsenseApiStatusOk { + return fmt.Errorf("failed to update peer %s on interface %s: %v", pp.Identifier, deviceId, wgReply.Error) + } + + if extras.Disabled { + slog.Debug("successfully disabled pfSense peer", "peer", pp.Identifier, "interface", deviceId) + } else { + slog.Debug("successfully updated pfSense peer", "peer", pp.Identifier, "interface", deviceId) + } + + return nil +} + +func (c *PfsenseController) DeletePeer( + ctx context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, +) error { + // Lock the peer to prevent concurrent modifications + mutex := c.getPeerMutex(id) + mutex.Lock() + defer mutex.Unlock() + + // Query for peer by publickey and interface (tun field) + // The API uses query parameters like ?publickey=...&tun=... + wgReply := c.client.Query(ctx, "/api/v2/vpn/wireguard/peers", &lowlevel.PfsenseRequestOptions{ + Filters: map[string]string{ + "publickey": string(id), + "tun": string(deviceId), // Use "tun" field name as that's what the API uses + }, + }) + if wgReply.Status != lowlevel.PfsenseApiStatusOk { + return fmt.Errorf("unable to find WireGuard peer %s for interface %s: %v", id, deviceId, wgReply.Error) + } + if len(wgReply.Data) == 0 { + return nil // peer does not exist, nothing to delete + } + + peerId := wgReply.Data[0].GetString("id") + // Actual endpoint: DELETE /api/v2/vpn/wireguard/peer?id={id} + deleteReply := c.client.Delete(ctx, "/api/v2/vpn/wireguard/peer?id="+peerId) + if deleteReply.Status != lowlevel.PfsenseApiStatusOk { + return fmt.Errorf("failed to delete WireGuard peer %s for interface %s: %v", id, deviceId, deleteReply.Error) + } + + return nil +} + +// endregion wireguard-related + +// region wg-quick-related + +func (c *PfsenseController) ExecuteInterfaceHook( + _ context.Context, + _ domain.InterfaceIdentifier, + _ string, +) error { + // TODO implement me + slog.Error("interface hooks are not yet supported for pfSense backends, please open an issue on GitHub") + return nil +} + +func (c *PfsenseController) SetDNS( + ctx context.Context, + _ domain.InterfaceIdentifier, + dnsStr, _ string, +) error { + // Lock the interface to prevent concurrent modifications + c.coreMutex.Lock() + defer c.coreMutex.Unlock() + + // pfSense DNS configuration is typically managed at the system level + // This may need to be implemented based on pfSense API capabilities + slog.Warn("DNS setting is not yet fully supported for pfSense backends") + return nil +} + +func (c *PfsenseController) UnsetDNS( + ctx context.Context, + _ domain.InterfaceIdentifier, + dnsStr, _ string, +) error { + // Lock the interface to prevent concurrent modifications + c.coreMutex.Lock() + defer c.coreMutex.Unlock() + + // pfSense DNS configuration is typically managed at the system level + slog.Warn("DNS unsetting is not yet fully supported for pfSense backends") + return nil +} + +// endregion wg-quick-related + +// region routing-related + +func (c *PfsenseController) SetRoutes(_ context.Context, info domain.RoutingTableInfo) error { + // pfSense routing is typically managed through the firewall rules and routing tables + // This may need to be implemented based on pfSense API capabilities + slog.Warn("route setting is not yet fully supported for pfSense backends") + return nil +} + +func (c *PfsenseController) RemoveRoutes(_ context.Context, info domain.RoutingTableInfo) error { + // pfSense routing is typically managed through the firewall rules and routing tables + slog.Warn("route removal is not yet fully supported for pfSense backends") + return nil +} + +// endregion routing-related + +// region statistics-related + +func (c *PfsenseController) PingAddresses( + ctx context.Context, + addr string, +) (*domain.PingerResult, error) { + // Use pfSense API to ping if available, otherwise return error + // This may need to be implemented based on pfSense API capabilities + return nil, fmt.Errorf("ping functionality is not yet implemented for pfSense backends") +} + +// endregion statistics-related + diff --git a/internal/app/wireguard/controller_manager.go b/internal/app/wireguard/controller_manager.go index 0f6bd23..c8be1ba 100644 --- a/internal/app/wireguard/controller_manager.go +++ b/internal/app/wireguard/controller_manager.go @@ -44,6 +44,10 @@ func (c *ControllerManager) init() error { return err } + if err := c.registerPfsenseControllers(); err != nil { + return err + } + c.logRegisteredControllers() return nil @@ -86,6 +90,26 @@ func (c *ControllerManager) registerMikrotikControllers() error { return nil } +func (c *ControllerManager) registerPfsenseControllers() error { + for _, backendConfig := range c.cfg.Backend.Pfsense { + if backendConfig.Id == config.LocalBackendName { + slog.Warn("skipping registration of pfSense controller with reserved ID", "id", config.LocalBackendName) + continue + } + + controller, err := wgcontroller.NewPfsenseController(c.cfg, &backendConfig) + if err != nil { + return fmt.Errorf("failed to create pfSense controller for backend %s: %w", backendConfig.Id, err) + } + + c.controllers[domain.InterfaceBackend(backendConfig.Id)] = backendInstance{ + Config: backendConfig.BackendBase, + Implementation: controller, + } + } + return nil +} + func (c *ControllerManager) logRegisteredControllers() { for backend, controller := range c.controllers { slog.Debug("backend controller registered", diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go index 3dbe53f..6f1f32b 100644 --- a/internal/app/wireguard/wireguard_interfaces.go +++ b/internal/app/wireguard/wireguard_interfaces.go @@ -7,6 +7,7 @@ import ( "log/slog" "os" "slices" + "strings" "time" "github.com/h44z/wg-portal/internal/app" @@ -867,6 +868,17 @@ func (m Manager) importInterface( iface.Backend = backend.GetId() iface.PeerDefAllowedIPsStr = iface.AddressStr() + // For pfSense backends, extract endpoint and DNS from peers + if backend.GetId() == domain.ControllerTypePfsense { + endpoint, dns := extractPfsenseDefaultsFromPeers(peers, iface.ListenPort) + if endpoint != "" { + iface.PeerDefEndpoint = endpoint + } + if dns != "" { + iface.PeerDefDnsStr = dns + } + } + // try to predict the interface type based on the number of peers switch len(peers) { case 0: @@ -904,6 +916,61 @@ func (m Manager) importInterface( return nil } +// extractPfsenseDefaultsFromPeers extracts common endpoint and DNS information from peers +// For server interfaces, peers typically have endpoints pointing to the server, so we use the most common one +func extractPfsenseDefaultsFromPeers(peers []domain.PhysicalPeer, listenPort int) (endpoint, dns string) { + if len(peers) == 0 { + return "", "" + } + + // Count endpoint occurrences to find the most common one + endpointCounts := make(map[string]int) + dnsValues := make(map[string]int) + + for _, peer := range peers { + // Extract endpoint from peer + if peer.Endpoint != "" { + endpointCounts[peer.Endpoint]++ + } + + // Extract DNS from peer extras if available + if extras := peer.GetExtras(); extras != nil { + if pfsenseExtras, ok := extras.(domain.PfsensePeerExtras); ok { + if pfsenseExtras.ClientDns != "" { + dnsValues[pfsenseExtras.ClientDns]++ + } + } + } + } + + // Find the most common endpoint + maxCount := 0 + for ep, count := range endpointCounts { + if count > maxCount { + maxCount = count + endpoint = ep + } + } + + // If endpoint doesn't have a port and we have a listenPort, add it + if endpoint != "" && listenPort > 0 { + if !strings.Contains(endpoint, ":") { + endpoint = fmt.Sprintf("%s:%d", endpoint, listenPort) + } + } + + // Find the most common DNS + maxDnsCount := 0 + for dnsVal, count := range dnsValues { + if count > maxDnsCount { + maxDnsCount = count + dns = dnsVal + } + } + + return endpoint, dns +} + func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain.PhysicalPeer) error { now := time.Now() peer := domain.ConvertPhysicalPeer(p) diff --git a/internal/config/backend.go b/internal/config/backend.go index cee7c8b..c02058f 100644 --- a/internal/config/backend.go +++ b/internal/config/backend.go @@ -18,6 +18,7 @@ type Backend struct { // External Backend-specific configuration Mikrotik []BackendMikrotik `yaml:"mikrotik"` + Pfsense []BackendPfsense `yaml:"pfsense"` } // Validate checks the backend configuration for errors. @@ -36,6 +37,15 @@ func (b *Backend) Validate() error { } uniqueMap[backend.Id] = struct{}{} } + for _, backend := range b.Pfsense { + if backend.Id == LocalBackendName { + return fmt.Errorf("backend ID %q is a reserved keyword", LocalBackendName) + } + if _, exists := uniqueMap[backend.Id]; exists { + return fmt.Errorf("backend ID %q is not unique", backend.Id) + } + uniqueMap[backend.Id] = struct{}{} + } if b.Default != LocalBackendName { if _, ok := uniqueMap[b.Default]; !ok { @@ -101,3 +111,42 @@ func (b *BackendMikrotik) GetApiTimeout() time.Duration { } return b.ApiTimeout } + +type BackendPfsense struct { + BackendBase `yaml:",inline"` // Embed the base fields + + ApiUrl string `yaml:"api_url"` // The base URL of the pfSense REST API (e.g., "https://pfsense.example.com/api/v2") + ApiKey string `yaml:"api_key"` // API key for authentication (generated in pfSense under 'System' -> 'REST API' -> 'Keys') + ApiVerifyTls bool `yaml:"api_verify_tls"` // Whether to verify the TLS certificate of the pfSense API + ApiTimeout time.Duration `yaml:"api_timeout"` // Timeout for API requests (default: 30 seconds) + + // Concurrency controls the maximum number of concurrent API requests that this backend will issue + // when enumerating interfaces and their details. If 0 or negative, a default of 5 is used. + Concurrency int `yaml:"concurrency"` + + Debug bool `yaml:"debug"` // Enable debug logging for the pfSense backend +} + +// GetConcurrency returns the configured concurrency for this backend or a sane default (5) +// when the configured value is zero or negative. +func (b *BackendPfsense) GetConcurrency() int { + if b == nil { + return 5 + } + if b.Concurrency <= 0 { + return 5 + } + return b.Concurrency +} + +// GetApiTimeout returns the configured API timeout or a sane default (30 seconds) +// when the configured value is zero or negative. +func (b *BackendPfsense) GetApiTimeout() time.Duration { + if b == nil { + return 30 * time.Second + } + if b.ApiTimeout <= 0 { + return 30 * time.Second + } + return b.ApiTimeout +} diff --git a/internal/domain/controller.go b/internal/domain/controller.go index eaefe32..3aec4f0 100644 --- a/internal/domain/controller.go +++ b/internal/domain/controller.go @@ -5,6 +5,7 @@ package domain const ( ControllerTypeMikrotik = "mikrotik" ControllerTypeLocal = "wgctrl" + ControllerTypePfsense = "pfsense" ) // Controller extras can be used to store additional information available for specific controllers only. @@ -30,3 +31,20 @@ type MikrotikPeerExtras struct { type LocalPeerExtras struct { Disabled bool } + +type PfsenseInterfaceExtras struct { + Id string // internal pfSense ID + Comment string + Disabled bool +} + +type PfsensePeerExtras struct { + Id string // internal pfSense ID + Name string + Comment string + Disabled bool + ClientEndpoint string + ClientAddress string + ClientDns string + ClientKeepalive int +} diff --git a/internal/domain/interface.go b/internal/domain/interface.go index 01c720c..b71fe16 100644 --- a/internal/domain/interface.go +++ b/internal/domain/interface.go @@ -240,7 +240,8 @@ func (p *PhysicalInterface) GetExtras() any { func (p *PhysicalInterface) SetExtras(extras any) { switch extras.(type) { case MikrotikInterfaceExtras: // OK - default: // we only support MikrotikInterfaceExtras for now + case PfsenseInterfaceExtras: // OK + default: // we only support MikrotikInterfaceExtras and PfsenseInterfaceExtras for now panic(fmt.Sprintf("unsupported interface backend extras type %T", extras)) } @@ -303,6 +304,14 @@ func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface { } else { iface.Disabled = nil } + case ControllerTypePfsense: + extras := pi.GetExtras().(PfsenseInterfaceExtras) + iface.DisplayName = extras.Comment + if extras.Disabled { + iface.Disabled = &now + } else { + iface.Disabled = nil + } } return iface @@ -325,6 +334,12 @@ func MergeToPhysicalInterface(pi *PhysicalInterface, i *Interface) { Disabled: i.IsDisabled(), } pi.SetExtras(extras) + case ControllerTypePfsense: + extras := PfsenseInterfaceExtras{ + Comment: i.DisplayName, + Disabled: i.IsDisabled(), + } + pi.SetExtras(extras) } } diff --git a/internal/domain/peer.go b/internal/domain/peer.go index 5fb0eb1..b265c20 100644 --- a/internal/domain/peer.go +++ b/internal/domain/peer.go @@ -240,7 +240,8 @@ func (p *PhysicalPeer) SetExtras(extras any) { switch extras.(type) { case MikrotikPeerExtras: // OK case LocalPeerExtras: // OK - default: // we only support MikrotikPeerExtras and LocalPeerExtras for now + case PfsensePeerExtras: // OK + default: // we only support MikrotikPeerExtras, LocalPeerExtras, and PfsensePeerExtras for now panic(fmt.Sprintf("unsupported peer backend extras type %T", extras)) } @@ -301,6 +302,26 @@ func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer { peer.Disabled = nil peer.DisabledReason = "" } + case ControllerTypePfsense: + extras := pp.GetExtras().(PfsensePeerExtras) + peer.Notes = extras.Comment + peer.DisplayName = extras.Name + if extras.ClientEndpoint != "" { // if the client endpoint is set, we assume that this is a client peer + peer.Endpoint = NewConfigOption(extras.ClientEndpoint, true) + peer.Interface.Type = InterfaceTypeClient + peer.Interface.Addresses, _ = CidrsFromString(extras.ClientAddress) + peer.Interface.DnsStr = NewConfigOption(extras.ClientDns, true) + peer.PersistentKeepalive = NewConfigOption(extras.ClientKeepalive, true) + } else { + peer.Interface.Type = InterfaceTypeServer + } + if extras.Disabled { + peer.Disabled = &now + peer.DisabledReason = "Disabled by pfSense controller" + } else { + peer.Disabled = nil + peer.DisabledReason = "" + } } return peer @@ -355,6 +376,18 @@ func MergeToPhysicalPeer(pp *PhysicalPeer, p *Peer) { Disabled: p.IsDisabled(), } pp.SetExtras(extras) + case ControllerTypePfsense: + extras := PfsensePeerExtras{ + Id: "", + Name: p.DisplayName, + Comment: p.Notes, + Disabled: p.IsDisabled(), + ClientEndpoint: p.Endpoint.GetValue(), + ClientAddress: CidrsToString(p.Interface.Addresses), + ClientDns: p.Interface.DnsStr.GetValue(), + ClientKeepalive: p.PersistentKeepalive.GetValue(), + } + pp.SetExtras(extras) } } diff --git a/internal/lowlevel/pfsense.go b/internal/lowlevel/pfsense.go new file mode 100644 index 0000000..e58471a --- /dev/null +++ b/internal/lowlevel/pfsense.go @@ -0,0 +1,428 @@ +package lowlevel + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "time" + + "github.com/h44z/wg-portal/internal" + "github.com/h44z/wg-portal/internal/config" +) + +// PfsenseApiClient provides HTTP client functionality for interacting with the pfSense REST API. +// Documentation: https://pfrest.org/ +// Swagger UI: https://pfrest.org/api-docs/ + +// region models + +const ( + PfsenseApiStatusOk = "ok" // pfSense REST API uses "ok" in response + PfsenseApiStatusError = "error" +) + +const ( + PfsenseApiErrorCodeUnknown = iota + 700 + PfsenseApiErrorCodeRequestPreparationFailed + PfsenseApiErrorCodeRequestFailed + PfsenseApiErrorCodeResponseDecodeFailed +) + +type PfsenseApiResponse[T any] struct { + Status string + Code int + Data T `json:"data,omitempty"` + Error *PfsenseApiError `json:"error,omitempty"` +} + +type PfsenseApiError struct { + Code int `json:"error,omitempty"` + Message string `json:"message,omitempty"` + Details string `json:"detail,omitempty"` +} + +func (e *PfsenseApiError) String() string { + if e == nil { + return "no error" + } + return fmt.Sprintf("API error %d: %s - %s", e.Code, e.Message, e.Details) +} + +type PfsenseRequestOptions struct { + Filters map[string]string `json:"filters,omitempty"` + PropList []string `json:"proplist,omitempty"` +} + +func (o *PfsenseRequestOptions) GetPath(base string) string { + if o == nil { + return base + } + + path, err := url.Parse(base) + if err != nil { + return base + } + + query := path.Query() + // pfSense REST API uses standard query parameters for filtering + for k, v := range o.Filters { + query.Set(k, v) + } + // Note: PropList may not be supported by pfSense REST API in the same way as Mikrotik + // pfSense typically returns all fields by default, but we keep this for potential future use + // Verify the correct parameter name in Swagger docs if field selection is needed + if len(o.PropList) > 0 { + // pfSense might use different parameter name - verify in Swagger docs + // For now, we'll skip it as pfSense may return all fields by default + // query.Set("fields", strings.Join(o.PropList, ",")) + } + path.RawQuery = query.Encode() + return path.String() +} + +// endregion models + +// region API-client + +type PfsenseApiClient struct { + coreCfg *config.Config + cfg *config.BackendPfsense + + client *http.Client + log *slog.Logger +} + +func NewPfsenseApiClient(coreCfg *config.Config, cfg *config.BackendPfsense) (*PfsenseApiClient, error) { + c := &PfsenseApiClient{ + coreCfg: coreCfg, + cfg: cfg, + } + + err := c.setup() + if err != nil { + return nil, err + } + + c.debugLog("pfSense api client created", "api_url", cfg.ApiUrl) + + return c, nil +} + +func (p *PfsenseApiClient) setup() error { + p.client = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: !p.cfg.ApiVerifyTls, + }, + }, + Timeout: p.cfg.GetApiTimeout(), + } + + if p.cfg.Debug { + p.log = slog.New(internal.GetLoggingHandler("debug", + p.coreCfg.Advanced.LogPretty, + p.coreCfg.Advanced.LogJson). + WithAttrs([]slog.Attr{ + { + Key: "pfsense-bid", Value: slog.StringValue(p.cfg.Id), + }, + })) + } + + return nil +} + +func (p *PfsenseApiClient) debugLog(msg string, args ...any) { + if p.log != nil { + p.log.Debug("[PFS-API] "+msg, args...) + } +} + +func (p *PfsenseApiClient) getFullPath(command string) string { + path, err := url.JoinPath(p.cfg.ApiUrl, command) + if err != nil { + return "" + } + return path +} + +func (p *PfsenseApiClient) prepareGetRequest(ctx context.Context, fullUrl string) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullUrl, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Accept", "application/json") + if p.cfg.ApiKey != "" { + // pfSense REST API API Key authentication (https://pfrest.org/AUTHENTICATION_AND_AUTHORIZATION/) + // Uses X-API-Key header for API key authentication + req.Header.Set("X-API-Key", p.cfg.ApiKey) + } + + return req, nil +} + +func (p *PfsenseApiClient) prepareDeleteRequest(ctx context.Context, fullUrl string) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, fullUrl, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Accept", "application/json") + if p.cfg.ApiKey != "" { + // pfSense REST API API Key authentication (https://pfrest.org/AUTHENTICATION_AND_AUTHORIZATION/) + // Uses X-API-Key header for API key authentication + req.Header.Set("X-API-Key", p.cfg.ApiKey) + } + + return req, nil +} + +func (p *PfsenseApiClient) preparePayloadRequest( + ctx context.Context, + method string, + fullUrl string, + payload GenericJsonObject, +) (*http.Request, error) { + // marshal the payload to JSON + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, method, fullUrl, bytes.NewReader(payloadBytes)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + if p.cfg.ApiKey != "" { + // pfSense REST API API Key authentication (https://pfrest.org/AUTHENTICATION_AND_AUTHORIZATION/) + // Uses X-API-Key header for API key authentication + req.Header.Set("X-API-Key", p.cfg.ApiKey) + } + + return req, nil +} + +func errToPfsenseApiResponse[T any](code int, message string, err error) PfsenseApiResponse[T] { + return PfsenseApiResponse[T]{ + Status: PfsenseApiStatusError, + Code: code, + Error: &PfsenseApiError{ + Code: code, + Message: message, + Details: err.Error(), + }, + } +} + +func parsePfsenseHttpResponse[T any](resp *http.Response, err error) PfsenseApiResponse[T] { + if err != nil { + return errToPfsenseApiResponse[T](PfsenseApiErrorCodeRequestFailed, "failed to execute request", err) + } + + // pfSense REST API wraps responses in {code, status, data} or {code, status, error} structure + var wrapper struct { + Code int `json:"code"` + Status string `json:"status"` + Data T `json:"data,omitempty"` + Error *struct { + Code int `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Detail string `json:"detail,omitempty"` + } `json:"error,omitempty"` + } + + // Read the entire body first + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return errToPfsenseApiResponse[T](PfsenseApiErrorCodeResponseDecodeFailed, "failed to read response body", err) + } + + // Close the body after reading + defer func() { + if err := resp.Body.Close(); err != nil { + slog.Error("failed to close response body", "error", err) + } + }() + + if len(bodyBytes) == 0 { + // Empty response for DELETE operations + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return PfsenseApiResponse[T]{Status: PfsenseApiStatusOk, Code: resp.StatusCode} + } + return errToPfsenseApiResponse[T](resp.StatusCode, "empty error response", fmt.Errorf("HTTP %d", resp.StatusCode)) + } + + if err := json.Unmarshal(bodyBytes, &wrapper); err != nil { + // Log the actual response for debugging when JSON parsing fails + contentType := resp.Header.Get("Content-Type") + bodyPreview := string(bodyBytes) + if len(bodyPreview) > 500 { + bodyPreview = bodyPreview[:500] + "..." + } + slog.Error("failed to decode pfSense API response", + "status_code", resp.StatusCode, + "content_type", contentType, + "url", resp.Request.URL.String(), + "method", resp.Request.Method, + "body_preview", bodyPreview, + "error", err) + return errToPfsenseApiResponse[T](PfsenseApiErrorCodeResponseDecodeFailed, + fmt.Sprintf("failed to decode response (status %d, content-type: %s): %v", resp.StatusCode, contentType, err), err) + } + + // Check if response indicates success + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + // Map pfSense status to our status + status := PfsenseApiStatusOk + if wrapper.Status != "ok" && wrapper.Status != "success" { + status = PfsenseApiStatusError + } + + // Handle EmptyResponse type + if _, ok := any(wrapper.Data).(EmptyResponse); ok { + return PfsenseApiResponse[T]{Status: status, Code: wrapper.Code} + } + + return PfsenseApiResponse[T]{Status: status, Code: wrapper.Code, Data: wrapper.Data} + } + + // Handle error response + if wrapper.Error != nil { + return PfsenseApiResponse[T]{ + Status: PfsenseApiStatusError, + Code: wrapper.Code, + Error: &PfsenseApiError{ + Code: wrapper.Error.Code, + Message: wrapper.Error.Message, + Details: wrapper.Error.Detail, + }, + } + } + + // Fallback error response + return errToPfsenseApiResponse[T](wrapper.Code, "unknown error", fmt.Errorf("HTTP %d: %s", wrapper.Code, wrapper.Status)) +} + +func (p *PfsenseApiClient) Query( + ctx context.Context, + command string, + opts *PfsenseRequestOptions, +) PfsenseApiResponse[[]GenericJsonObject] { + apiCtx, cancel := context.WithTimeout(ctx, p.cfg.GetApiTimeout()) + defer cancel() + + fullUrl := opts.GetPath(p.getFullPath(command)) + + req, err := p.prepareGetRequest(apiCtx, fullUrl) + if err != nil { + return errToPfsenseApiResponse[[]GenericJsonObject](PfsenseApiErrorCodeRequestPreparationFailed, + "failed to create request", err) + } + + start := time.Now() + p.debugLog("executing API query", "url", fullUrl) + response := parsePfsenseHttpResponse[[]GenericJsonObject](p.client.Do(req)) + p.debugLog("retrieved API query result", "url", fullUrl, "duration", time.Since(start).String()) + return response +} + +func (p *PfsenseApiClient) Get( + ctx context.Context, + command string, + opts *PfsenseRequestOptions, +) PfsenseApiResponse[GenericJsonObject] { + apiCtx, cancel := context.WithTimeout(ctx, p.cfg.GetApiTimeout()) + defer cancel() + + fullUrl := opts.GetPath(p.getFullPath(command)) + + req, err := p.prepareGetRequest(apiCtx, fullUrl) + if err != nil { + return errToPfsenseApiResponse[GenericJsonObject](PfsenseApiErrorCodeRequestPreparationFailed, + "failed to create request", err) + } + + start := time.Now() + p.debugLog("executing API get", "url", fullUrl) + response := parsePfsenseHttpResponse[GenericJsonObject](p.client.Do(req)) + p.debugLog("retrieved API get result", "url", fullUrl, "duration", time.Since(start).String()) + return response +} + +func (p *PfsenseApiClient) Create( + ctx context.Context, + command string, + payload GenericJsonObject, +) PfsenseApiResponse[GenericJsonObject] { + apiCtx, cancel := context.WithTimeout(ctx, p.cfg.GetApiTimeout()) + defer cancel() + + fullUrl := p.getFullPath(command) + + req, err := p.preparePayloadRequest(apiCtx, http.MethodPost, fullUrl, payload) + if err != nil { + return errToPfsenseApiResponse[GenericJsonObject](PfsenseApiErrorCodeRequestPreparationFailed, + "failed to create request", err) + } + + start := time.Now() + p.debugLog("executing API post", "url", fullUrl) + response := parsePfsenseHttpResponse[GenericJsonObject](p.client.Do(req)) + p.debugLog("retrieved API post result", "url", fullUrl, "duration", time.Since(start).String()) + return response +} + +func (p *PfsenseApiClient) Update( + ctx context.Context, + command string, + payload GenericJsonObject, +) PfsenseApiResponse[GenericJsonObject] { + apiCtx, cancel := context.WithTimeout(ctx, p.cfg.GetApiTimeout()) + defer cancel() + + fullUrl := p.getFullPath(command) + + req, err := p.preparePayloadRequest(apiCtx, http.MethodPatch, fullUrl, payload) + if err != nil { + return errToPfsenseApiResponse[GenericJsonObject](PfsenseApiErrorCodeRequestPreparationFailed, + "failed to create request", err) + } + + start := time.Now() + p.debugLog("executing API patch", "url", fullUrl) + response := parsePfsenseHttpResponse[GenericJsonObject](p.client.Do(req)) + p.debugLog("retrieved API patch result", "url", fullUrl, "duration", time.Since(start).String()) + return response +} + +func (p *PfsenseApiClient) Delete( + ctx context.Context, + command string, +) PfsenseApiResponse[EmptyResponse] { + apiCtx, cancel := context.WithTimeout(ctx, p.cfg.GetApiTimeout()) + defer cancel() + + fullUrl := p.getFullPath(command) + + req, err := p.prepareDeleteRequest(apiCtx, fullUrl) + if err != nil { + return errToPfsenseApiResponse[EmptyResponse](PfsenseApiErrorCodeRequestPreparationFailed, + "failed to create request", err) + } + + start := time.Now() + p.debugLog("executing API delete", "url", fullUrl) + response := parsePfsenseHttpResponse[EmptyResponse](p.client.Do(req)) + p.debugLog("retrieved API delete result", "url", fullUrl, "duration", time.Since(start).String()) + return response +} + +// endregion API-client +