diff --git a/internal/adapters/wgcontroller/mikrotik.go b/internal/adapters/wgcontroller/mikrotik.go index 2e15e78..aff49c9 100644 --- a/internal/adapters/wgcontroller/mikrotik.go +++ b/internal/adapters/wgcontroller/mikrotik.go @@ -3,6 +3,8 @@ package wgcontroller import ( "context" "fmt" + "slices" + "strconv" "time" "github.com/h44z/wg-portal/internal/config" @@ -97,9 +99,26 @@ func (c MikrotikController) loadInterfaceData( return nil, fmt.Errorf("failed to query interface %s: %v", deviceId, ifaceReply.Error) } + ipv4, ipv6, err := c.loadIpAddresses(ctx, deviceName) + if err != nil { + return nil, fmt.Errorf("failed to query IP addresses for interface %s: %v", deviceId, err) + } + addresses := c.convertIpAddresses(ipv4, ipv6) + + interfaceModel, err := c.convertWireGuardInterface(wireGuardObj, ifaceReply.Data, addresses) + if err != nil { + return nil, fmt.Errorf("interface convert failed for %s: %w", deviceName, err) + } + return &interfaceModel, nil +} + +func (c MikrotikController) loadIpAddresses( + ctx context.Context, + deviceName string, +) (ipv4 []lowlevel.GenericJsonObject, ipv6 []lowlevel.GenericJsonObject, err error) { addrV4Reply := c.client.Query(ctx, "/ip/address", &lowlevel.MikrotikRequestOptions{ PropList: []string{ - "address", "network", + ".id", "address", "network", }, Filters: map[string]string{ "interface": deviceName, @@ -108,12 +127,13 @@ func (c MikrotikController) loadInterfaceData( }, }) if addrV4Reply.Status != lowlevel.MikrotikApiStatusOk { - return nil, fmt.Errorf("failed to query IPv4 addresses for interface %s: %v", deviceId, addrV4Reply.Error) + return nil, nil, fmt.Errorf("failed to query IPv4 addresses for interface %s: %v", deviceName, + addrV4Reply.Error) } addrV6Reply := c.client.Query(ctx, "/ipv6/address", &lowlevel.MikrotikRequestOptions{ PropList: []string{ - "address", "network", + ".id", "address", "network", }, Filters: map[string]string{ "interface": deviceName, @@ -122,26 +142,16 @@ func (c MikrotikController) loadInterfaceData( }, }) if addrV6Reply.Status != lowlevel.MikrotikApiStatusOk { - return nil, fmt.Errorf("failed to query IPv6 addresses for interface %s: %v", deviceId, addrV6Reply.Error) + return nil, nil, fmt.Errorf("failed to query IPv6 addresses for interface %s: %v", deviceName, + addrV6Reply.Error) } - interfaceModel, err := c.convertWireGuardInterface(wireGuardObj, ifaceReply.Data, addrV4Reply.Data, - addrV6Reply.Data) - if err != nil { - return nil, fmt.Errorf("interface convert failed for %s: %w", deviceName, err) - } - return &interfaceModel, nil + return addrV4Reply.Data, addrV6Reply.Data, nil } -func (c MikrotikController) convertWireGuardInterface( - wg, iface lowlevel.GenericJsonObject, +func (c MikrotikController) convertIpAddresses( ipv4, ipv6 []lowlevel.GenericJsonObject, -) ( - domain.PhysicalInterface, - error, -) { - // read data from wgctrl interface - +) []domain.Cidr { addresses := make([]domain.Cidr, 0, len(ipv4)+len(ipv6)) for _, addr := range append(ipv4, ipv6...) { addrStr := addr.GetString("address") @@ -156,6 +166,16 @@ func (c MikrotikController) convertWireGuardInterface( addresses = append(addresses, cidr) } + return addresses +} + +func (c MikrotikController) convertWireGuardInterface( + wg, iface lowlevel.GenericJsonObject, + addresses []domain.Cidr, +) ( + domain.PhysicalInterface, + error, +) { pi := domain.PhysicalInterface{ Identifier: domain.InterfaceIdentifier(wg.GetString("name")), KeyPair: domain.KeyPair{ @@ -174,6 +194,7 @@ func (c MikrotikController) convertWireGuardInterface( } pi.SetExtras(domain.MikrotikInterfaceExtras{ + Id: wg.GetString(".id"), Comment: wg.GetString("comment"), Disabled: wg.GetBool("disabled"), }) @@ -270,16 +291,202 @@ func (c MikrotikController) convertWireGuardPeer(peer lowlevel.GenericJsonObject } func (c MikrotikController) SaveInterface( - _ context.Context, + ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), ) error { - // TODO implement me + physicalInterface, err := c.getOrCreateInterface(ctx, id) + if err != nil { + return err + } + + deviceId := physicalInterface.GetExtras().(domain.MikrotikInterfaceExtras).Id + if updateFunc != nil { + physicalInterface, err = updateFunc(physicalInterface) + if err != nil { + return err + } + newExtras := physicalInterface.GetExtras().(domain.MikrotikInterfaceExtras) + newExtras.Id = deviceId // ensure the ID is not changed + physicalInterface.SetExtras(newExtras) + } + + if err := c.updateInterface(ctx, physicalInterface); err != nil { + return err + } + return nil } -func (c MikrotikController) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error { - // TODO implement me +func (c MikrotikController) getOrCreateInterface( + ctx context.Context, + id domain.InterfaceIdentifier, +) (*domain.PhysicalInterface, error) { + wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{ + PropList: []string{ + ".id", "name", "public-key", "private-key", "listen-port", "mtu", "disabled", "running", + }, + Filters: map[string]string{ + "name": string(id), + }, + }) + if wgReply.Status == lowlevel.MikrotikApiStatusOk && len(wgReply.Data) > 0 { + return c.loadInterfaceData(ctx, wgReply.Data[0]) + } + + // create a new interface if it does not exist + createReply := c.client.Create(ctx, "/interface/wireguard", lowlevel.GenericJsonObject{ + "name": string(id), + }) + if wgReply.Status == lowlevel.MikrotikApiStatusOk { + return c.loadInterfaceData(ctx, createReply.Data) + } + + return nil, fmt.Errorf("failed to create interface %s: %v", id, createReply.Error) +} + +func (c MikrotikController) updateInterface(ctx context.Context, pi *domain.PhysicalInterface) error { + extras := pi.GetExtras().(domain.MikrotikInterfaceExtras) + interfaceId := extras.Id + wgReply := c.client.Update(ctx, "/interface/wireguard/"+interfaceId, lowlevel.GenericJsonObject{ + "name": pi.Identifier, + "comment": extras.Comment, + "mtu": strconv.Itoa(pi.Mtu), + "listen-port": strconv.Itoa(pi.ListenPort), + "private-key": pi.KeyPair.PrivateKey, + "disabled": strconv.FormatBool(!pi.DeviceUp), + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to update interface %s: %v", pi.Identifier, wgReply.Error) + } + + // update the interface's addresses + currentV4, currentV6, err := c.loadIpAddresses(ctx, string(pi.Identifier)) + if err != nil { + return fmt.Errorf("failed to load current addresses for interface %s: %v", pi.Identifier, err) + } + currentAddresses := c.convertIpAddresses(currentV4, currentV6) + + // get all addresses that are currently not in the interface, only in pi + newAddresses := make([]domain.Cidr, 0, len(pi.Addresses)) + for _, addr := range pi.Addresses { + if slices.Contains(currentAddresses, addr) { + continue + } + newAddresses = append(newAddresses, addr) + } + // get obsolete addresses that are in the interface, but not in pi + obsoleteAddresses := make([]domain.Cidr, 0, len(currentAddresses)) + for _, addr := range currentAddresses { + if slices.Contains(pi.Addresses, addr) { + continue + } + obsoleteAddresses = append(obsoleteAddresses, addr) + } + + // update the IP addresses for the interface + if err := c.updateIpAddresses(ctx, string(pi.Identifier), currentV4, currentV6, + newAddresses, obsoleteAddresses); err != nil { + return fmt.Errorf("failed to update IP addresses for interface %s: %v", pi.Identifier, err) + } + + return nil +} + +func (c MikrotikController) updateIpAddresses( + ctx context.Context, + deviceName string, + currentV4, currentV6 []lowlevel.GenericJsonObject, + new, obsolete []domain.Cidr, +) error { + // first, delete all obsolete addresses + for _, addr := range obsolete { + // find ID of the address to delete + if addr.IsV4() { + for _, a := range currentV4 { + if a.GetString("address") == addr.String() { + // delete the address + reply := c.client.Delete(ctx, "/ip/address/"+a.GetString(".id")) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to delete obsolete IPv4 address %s: %v", addr, reply.Error) + } + break + } + } + } else { + for _, a := range currentV6 { + if a.GetString("address") == addr.String() { + // delete the address + reply := c.client.Delete(ctx, "/ipv6/address/"+a.GetString(".id")) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to delete obsolete IPv6 address %s: %v", addr, reply.Error) + } + break + } + } + } + } + + // then, add all new addresses + for _, addr := range new { + var createPath string + if addr.IsV4() { + createPath = "/ip/address" + } else { + createPath = "/ipv6/address" + } + + // create the address + reply := c.client.Create(ctx, createPath, lowlevel.GenericJsonObject{ + "address": addr.String(), + "interface": deviceName, + }) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to create new address %s: %v", addr, reply.Error) + } + } + + return nil +} + +func (c MikrotikController) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error { + // delete the interface's addresses + currentV4, currentV6, err := c.loadIpAddresses(ctx, string(id)) + if err != nil { + return fmt.Errorf("failed to load current addresses for interface %s: %v", id, err) + } + for _, a := range currentV4 { + // delete the address + reply := c.client.Delete(ctx, "/ip/address/"+a.GetString(".id")) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to delete IPv4 address %s: %v", a.GetString("address"), reply.Error) + } + } + for _, a := range currentV6 { + // delete the address + reply := c.client.Delete(ctx, "/ipv6/address/"+a.GetString(".id")) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to delete IPv6 address %s: %v", a.GetString("address"), reply.Error) + } + } + + // delete the WireGuard interface + wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{ + PropList: []string{".id"}, + Filters: map[string]string{ + "name": string(id), + }, + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk || len(wgReply.Data) == 0 { + return fmt.Errorf("unable to find WireGuard interface %s: %v", id, wgReply.Error) + } + + deviceId := wgReply.Data[0].GetString(".id") + deleteReply := c.client.Delete(ctx, "/interface/wireguard/"+deviceId) + if deleteReply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to delete WireGuard interface %s: %v", id, deleteReply.Error) + } + return nil } diff --git a/internal/domain/controller.go b/internal/domain/controller.go index 26d8b07..791741a 100644 --- a/internal/domain/controller.go +++ b/internal/domain/controller.go @@ -10,6 +10,7 @@ const ( // Controller extras can be used to store additional information available for specific controllers only. type MikrotikInterfaceExtras struct { + Id string // internal mikrotik ID Comment string Disabled bool } diff --git a/internal/domain/interface.go b/internal/domain/interface.go index cb57c50..32fc1c0 100644 --- a/internal/domain/interface.go +++ b/internal/domain/interface.go @@ -227,6 +227,11 @@ func (p *PhysicalInterface) SetExtras(extras any) { } func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface { + networks := make([]Cidr, 0, len(pi.Addresses)) + for _, addr := range pi.Addresses { + networks = append(networks, addr.NetworkAddr()) + } + // create a new basic interface with the data from the physical interface iface := &Interface{ Identifier: pi.Identifier, @@ -247,11 +252,11 @@ func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface { Type: InterfaceTypeAny, DriverType: pi.DeviceType, Disabled: nil, - PeerDefNetworkStr: "", + PeerDefNetworkStr: CidrsToString(networks), PeerDefDnsStr: "", PeerDefDnsSearchStr: "", PeerDefEndpoint: "", - PeerDefAllowedIPsStr: "", + PeerDefAllowedIPsStr: CidrsToString(networks), PeerDefMtu: pi.Mtu, PeerDefPersistentKeepalive: 0, PeerDefFirewallMark: 0, @@ -291,6 +296,15 @@ func MergeToPhysicalInterface(pi *PhysicalInterface, i *Interface) { pi.FirewallMark = i.FirewallMark pi.DeviceUp = !i.IsDisabled() pi.Addresses = i.Addresses + + switch pi.ImportSource { + case ControllerTypeMikrotik: + extras := MikrotikInterfaceExtras{ + Comment: i.DisplayName, + Disabled: i.IsDisabled(), + } + pi.SetExtras(extras) + } } type RoutingTableInfo struct { diff --git a/internal/lowlevel/mikrotik.go b/internal/lowlevel/mikrotik.go index f0bdaf8..bcad3e5 100644 --- a/internal/lowlevel/mikrotik.go +++ b/internal/lowlevel/mikrotik.go @@ -45,11 +45,15 @@ type MikrotikApiError struct { Details string `json:"details,omitempty"` } -func (e MikrotikApiError) String() string { +func (e *MikrotikApiError) String() string { + if e == nil { + return "no error" + } return fmt.Sprintf("API error %d: %s - %s", e.Code, e.Message, e.Details) } type GenericJsonObject map[string]any +type EmptyResponse struct{} func (JsonObject GenericJsonObject) GetString(key string) string { if value, ok := JsonObject[key]; ok { @@ -211,8 +215,22 @@ func (m *MikrotikApiClient) prepareGetRequest(ctx context.Context, fullUrl strin return req, nil } -func (m *MikrotikApiClient) preparePostRequest( +func (m *MikrotikApiClient) prepareDeleteRequest(ctx context.Context, fullUrl string) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, fullUrl, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Accept", "application/json") + if m.cfg.ApiUser != "" && m.cfg.ApiPassword != "" { + req.SetBasicAuth(m.cfg.ApiUser, m.cfg.ApiPassword) + } + + return req, nil +} + +func (m *MikrotikApiClient) preparePayloadRequest( ctx context.Context, + method string, fullUrl string, payload GenericJsonObject, ) (*http.Request, error) { @@ -222,7 +240,7 @@ func (m *MikrotikApiClient) preparePostRequest( return nil, fmt.Errorf("failed to marshal payload: %w", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullUrl, bytes.NewReader(payloadBytes)) + req, err := http.NewRequestWithContext(ctx, method, fullUrl, bytes.NewReader(payloadBytes)) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } @@ -261,6 +279,12 @@ func parseHttpResponse[T any](resp *http.Response, err error) MikrotikApiRespons if resp.StatusCode >= 200 && resp.StatusCode < 300 { var data T + + // if the type of T is EmptyResponse, we can return an empty response with just the status + if _, ok := any(data).(EmptyResponse); ok { + return MikrotikApiResponse[T]{Status: MikrotikApiStatusOk, Code: resp.StatusCode} + } + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { return errToApiResponse[T](MikrotikApiErrorCodeResponseDecodeFailed, "failed to decode response", err) } @@ -321,6 +345,74 @@ func (m *MikrotikApiClient) Get( return response } +func (m *MikrotikApiClient) Create( + ctx context.Context, + command string, + payload GenericJsonObject, +) MikrotikApiResponse[GenericJsonObject] { + apiCtx, cancel := context.WithTimeout(ctx, m.cfg.ApiTimeout) + defer cancel() + + fullUrl := m.getFullPath(command) + + req, err := m.preparePayloadRequest(apiCtx, http.MethodPut, fullUrl, payload) + if err != nil { + return errToApiResponse[GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed, + "failed to create request", err) + } + + start := time.Now() + m.debugLog("executing API put", "url", fullUrl) + response := parseHttpResponse[GenericJsonObject](m.client.Do(req)) + m.debugLog("retrieved API put result", "url", fullUrl, "duration", time.Since(start).String()) + return response +} + +func (m *MikrotikApiClient) Update( + ctx context.Context, + command string, + payload GenericJsonObject, +) MikrotikApiResponse[GenericJsonObject] { + apiCtx, cancel := context.WithTimeout(ctx, m.cfg.ApiTimeout) + defer cancel() + + fullUrl := m.getFullPath(command) + + req, err := m.preparePayloadRequest(apiCtx, http.MethodPatch, fullUrl, payload) + if err != nil { + return errToApiResponse[GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed, + "failed to create request", err) + } + + start := time.Now() + m.debugLog("executing API patch", "url", fullUrl) + response := parseHttpResponse[GenericJsonObject](m.client.Do(req)) + m.debugLog("retrieved API patch result", "url", fullUrl, "duration", time.Since(start).String()) + return response +} + +func (m *MikrotikApiClient) Delete( + ctx context.Context, + command string, +) MikrotikApiResponse[EmptyResponse] { + apiCtx, cancel := context.WithTimeout(ctx, m.cfg.ApiTimeout) + defer cancel() + + fullUrl := m.getFullPath(command) + + req, err := m.prepareDeleteRequest(apiCtx, fullUrl) + if err != nil { + return errToApiResponse[EmptyResponse](MikrotikApiErrorCodeRequestPreparationFailed, + "failed to create request", err) + } + + start := time.Now() + m.debugLog("executing API delete", "url", fullUrl) + response := parseHttpResponse[EmptyResponse](m.client.Do(req)) + m.debugLog("retrieved API delete result", "url", fullUrl, "duration", time.Since(start).String()) + return response +} + func (m *MikrotikApiClient) ExecList( ctx context.Context, command string, @@ -331,16 +423,16 @@ func (m *MikrotikApiClient) ExecList( fullUrl := m.getFullPath(command) - req, err := m.preparePostRequest(apiCtx, fullUrl, payload) + req, err := m.preparePayloadRequest(apiCtx, http.MethodPost, fullUrl, payload) if err != nil { return errToApiResponse[[]GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed, "failed to create request", err) } start := time.Now() - m.debugLog("executing API get", "url", fullUrl) + m.debugLog("executing API post", "url", fullUrl) response := parseHttpResponse[[]GenericJsonObject](m.client.Do(req)) - m.debugLog("retrieved API get result", "url", fullUrl, "duration", time.Since(start).String()) + m.debugLog("retrieved API post result", "url", fullUrl, "duration", time.Since(start).String()) return response }