refac: some WGDashboard code

This commit is contained in:
Daan Selen
2026-02-07 03:03:35 +01:00
parent 686fc139ec
commit f742f9de49
75 changed files with 427 additions and 337 deletions

View File

@@ -4,19 +4,23 @@ AmneziaWG Configuration
import random, sqlalchemy, os, subprocess, re, uuid
from flask import current_app
from .PeerJobs import PeerJobs
from .AmneziaWGPeer import AmneziaWGPeer
from .AmneziaPeer import AmneziaPeer
from .PeerShareLinks import PeerShareLinks
from .Utilities import RegexMatch
from .WireguardConfiguration import WireguardConfiguration
from .DashboardWebHooks import DashboardWebHooks
class AmneziaWireguardConfiguration(WireguardConfiguration):
def __init__(self, DashboardConfig,
class AmneziaConfiguration(WireguardConfiguration):
def __init__(self,
DashboardConfig,
AllPeerJobs: PeerJobs,
AllPeerShareLinks: PeerShareLinks,
DashboardWebHooks: DashboardWebHooks,
name: str = None, data: dict = None, backup: dict = None, startup: bool = False):
name: str = None,
data: dict = None,
backup: dict = None,
startup: bool = False):
self.Jc = 0
self.Jmin = 0
self.Jmax = 0
@@ -79,58 +83,50 @@ class AmneziaWireguardConfiguration(WireguardConfiguration):
}
def createDatabase(self, dbName = None):
def generate_column_obj():
return [
sqlalchemy.Column('id', sqlalchemy.String(255), nullable=False, primary_key=True),
sqlalchemy.Column('private_key', sqlalchemy.String(255)),
sqlalchemy.Column('DNS', sqlalchemy.Text),
sqlalchemy.Column('endpoint_allowed_ip', sqlalchemy.Text),
sqlalchemy.Column('name', sqlalchemy.Text),
sqlalchemy.Column('total_receive', sqlalchemy.Float),
sqlalchemy.Column('total_sent', sqlalchemy.Float),
sqlalchemy.Column('total_data', sqlalchemy.Float),
sqlalchemy.Column('endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('status', sqlalchemy.String(255)),
sqlalchemy.Column('latest_handshake', sqlalchemy.String(255)),
sqlalchemy.Column('allowed_ip', sqlalchemy.String(255)),
sqlalchemy.Column('cumu_receive', sqlalchemy.Float),
sqlalchemy.Column('cumu_sent', sqlalchemy.Float),
sqlalchemy.Column('cumu_data', sqlalchemy.Float),
sqlalchemy.Column('mtu', sqlalchemy.Integer),
sqlalchemy.Column('keepalive', sqlalchemy.Integer),
sqlalchemy.Column('notes', sqlalchemy.Text),
sqlalchemy.Column('remote_endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('preshared_key', sqlalchemy.String(255))
]
if dbName is None:
dbName = self.Name
self.peersTable = sqlalchemy.Table(
dbName, self.metadata,
sqlalchemy.Column('id', sqlalchemy.String(255), nullable=False, primary_key=True),
sqlalchemy.Column('private_key', sqlalchemy.String(255)),
sqlalchemy.Column('DNS', sqlalchemy.Text),
sqlalchemy.Column('advanced_security', sqlalchemy.String(255)),
sqlalchemy.Column('endpoint_allowed_ip', sqlalchemy.Text),
sqlalchemy.Column('name', sqlalchemy.Text),
sqlalchemy.Column('total_receive', sqlalchemy.Float),
sqlalchemy.Column('total_sent', sqlalchemy.Float),
sqlalchemy.Column('total_data', sqlalchemy.Float),
sqlalchemy.Column('endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('status', sqlalchemy.String(255)),
sqlalchemy.Column('latest_handshake', sqlalchemy.String(255)),
sqlalchemy.Column('allowed_ip', sqlalchemy.String(255)),
sqlalchemy.Column('cumu_receive', sqlalchemy.Float),
sqlalchemy.Column('cumu_sent', sqlalchemy.Float),
sqlalchemy.Column('cumu_data', sqlalchemy.Float),
sqlalchemy.Column('mtu', sqlalchemy.Integer),
sqlalchemy.Column('keepalive', sqlalchemy.Integer),
sqlalchemy.Column('remote_endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('preshared_key', sqlalchemy.String(255)),
extend_existing=True
f'{dbName}', self.metadata, *generate_column_obj(), extend_existing=True
)
self.peersRestrictedTable = sqlalchemy.Table(
f'{dbName}_restrict_access', self.metadata,
sqlalchemy.Column('id', sqlalchemy.String(255), nullable=False, primary_key=True),
sqlalchemy.Column('private_key', sqlalchemy.String(255)),
sqlalchemy.Column('DNS', sqlalchemy.Text),
sqlalchemy.Column('advanced_security', sqlalchemy.String(255)),
sqlalchemy.Column('endpoint_allowed_ip', sqlalchemy.Text),
sqlalchemy.Column('name', sqlalchemy.Text),
sqlalchemy.Column('total_receive', sqlalchemy.Float),
sqlalchemy.Column('total_sent', sqlalchemy.Float),
sqlalchemy.Column('total_data', sqlalchemy.Float),
sqlalchemy.Column('endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('status', sqlalchemy.String(255)),
sqlalchemy.Column('latest_handshake', sqlalchemy.String(255)),
sqlalchemy.Column('allowed_ip', sqlalchemy.String(255)),
sqlalchemy.Column('cumu_receive', sqlalchemy.Float),
sqlalchemy.Column('cumu_sent', sqlalchemy.Float),
sqlalchemy.Column('cumu_data', sqlalchemy.Float),
sqlalchemy.Column('mtu', sqlalchemy.Integer),
sqlalchemy.Column('keepalive', sqlalchemy.Integer),
sqlalchemy.Column('remote_endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('preshared_key', sqlalchemy.String(255)),
extend_existing=True
f'{dbName}_restrict_access', self.metadata, *generate_column_obj(), extend_existing=True
)
self.peersDeletedTable = sqlalchemy.Table(
f'{dbName}_deleted', self.metadata, *generate_column_obj(), extend_existing=True
)
if self.DashboardConfig.GetConfig("Database", "type")[1] == 'sqlite':
time_col_type = sqlalchemy.DATETIME
else:
time_col_type = sqlalchemy.TIMESTAMP
self.peersTransferTable = sqlalchemy.Table(
f'{dbName}_transfer', self.metadata,
sqlalchemy.Column('id', sqlalchemy.String(255), nullable=False),
@@ -140,38 +136,7 @@ class AmneziaWireguardConfiguration(WireguardConfiguration):
sqlalchemy.Column('cumu_receive', sqlalchemy.Float),
sqlalchemy.Column('cumu_sent', sqlalchemy.Float),
sqlalchemy.Column('cumu_data', sqlalchemy.Float),
sqlalchemy.Column('time', (sqlalchemy.DATETIME if self.DashboardConfig.GetConfig("Database", "type")[1] == 'sqlite' else sqlalchemy.TIMESTAMP),
server_default=sqlalchemy.func.now()),
extend_existing=True
)
self.peersDeletedTable = sqlalchemy.Table(
f'{dbName}_deleted', self.metadata,
sqlalchemy.Column('id', sqlalchemy.String(255), nullable=False),
sqlalchemy.Column('private_key', sqlalchemy.String(255)),
sqlalchemy.Column('DNS', sqlalchemy.Text),
sqlalchemy.Column('advanced_security', sqlalchemy.String(255)),
sqlalchemy.Column('endpoint_allowed_ip', sqlalchemy.Text),
sqlalchemy.Column('name', sqlalchemy.Text),
sqlalchemy.Column('total_receive', sqlalchemy.Float),
sqlalchemy.Column('total_sent', sqlalchemy.Float),
sqlalchemy.Column('total_data', sqlalchemy.Float),
sqlalchemy.Column('endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('status', sqlalchemy.String(255)),
sqlalchemy.Column('latest_handshake', sqlalchemy.String(255)),
sqlalchemy.Column('allowed_ip', sqlalchemy.String(255)),
sqlalchemy.Column('cumu_receive', sqlalchemy.Float),
sqlalchemy.Column('cumu_sent', sqlalchemy.Float),
sqlalchemy.Column('cumu_data', sqlalchemy.Float),
sqlalchemy.Column('mtu', sqlalchemy.Integer),
sqlalchemy.Column('keepalive', sqlalchemy.Integer),
sqlalchemy.Column('remote_endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('preshared_key', sqlalchemy.String(255)),
extend_existing=True
)
self.infoTable = sqlalchemy.Table(
'ConfigurationsInfo', self.metadata,
sqlalchemy.Column('ID', sqlalchemy.String(255), primary_key=True),
sqlalchemy.Column('Info', sqlalchemy.Text),
sqlalchemy.Column('time', time_col_type, server_default=sqlalchemy.func.now()),
extend_existing=True
)
@@ -179,15 +144,20 @@ class AmneziaWireguardConfiguration(WireguardConfiguration):
f'{dbName}_history_endpoint', self.metadata,
sqlalchemy.Column('id', sqlalchemy.String(255), nullable=False),
sqlalchemy.Column('endpoint', sqlalchemy.String(255), nullable=False),
sqlalchemy.Column('time',
(sqlalchemy.DATETIME if self.DashboardConfig.GetConfig("Database", "type")[1] == 'sqlite' else sqlalchemy.TIMESTAMP)),
sqlalchemy.Column('time', time_col_type)
)
self.infoTable = sqlalchemy.Table(
'ConfigurationsInfo', self.metadata,
sqlalchemy.Column('ID', sqlalchemy.String(255), primary_key=True),
sqlalchemy.Column('Info', sqlalchemy.Text),
extend_existing=True
)
self.metadata.create_all(self.engine)
def getPeers(self):
self.Peers.clear()
self.Peers.clear()
if self.configurationFileChanged():
with open(self.configPath, 'r') as configFile:
p = []
@@ -225,11 +195,9 @@ class AmneziaWireguardConfiguration(WireguardConfiguration):
if tempPeer is None:
tempPeer = {
"id": i['PublicKey'],
"advanced_security": i.get('AdvancedSecurity', 'off'),
"private_key": "",
"DNS": self.DashboardConfig.GetConfig("Peers", "peer_global_DNS")[1],
"endpoint_allowed_ip": self.DashboardConfig.GetConfig("Peers", "peer_endpoint_allowed_ip")[
1],
"endpoint_allowed_ip": self.DashboardConfig.GetConfig("Peers", "peer_endpoint_allowed_ip")[1],
"name": i.get("name"),
"total_receive": 0,
"total_sent": 0,
@@ -243,6 +211,7 @@ class AmneziaWireguardConfiguration(WireguardConfiguration):
"cumu_data": 0,
"mtu": self.DashboardConfig.GetConfig("Peers", "peer_mtu")[1],
"keepalive": self.DashboardConfig.GetConfig("Peers", "peer_keep_alive")[1],
"notes": "",
"remote_endpoint": self.DashboardConfig.GetConfig("Peers", "remote_endpoint")[1],
"preshared_key": i["PresharedKey"] if "PresharedKey" in i.keys() else ""
}
@@ -257,14 +226,14 @@ class AmneziaWireguardConfiguration(WireguardConfiguration):
self.peersTable.columns.id == i['PublicKey']
)
)
self.Peers.append(AmneziaWGPeer(tempPeer, self))
self.Peers.append(AmneziaPeer(tempPeer, self))
except Exception as e:
current_app.logger.error(f"{self.Name} getPeers() Error", e)
else:
with self.engine.connect() as conn:
existingPeers = conn.execute(self.peersTable.select()).mappings().fetchall()
for i in existingPeers:
self.Peers.append(AmneziaWGPeer(i, self))
self.Peers.append(AmneziaPeer(i, self))
def addPeers(self, peers: list) -> tuple[bool, list, str]:
result = {
@@ -292,9 +261,9 @@ class AmneziaWireguardConfiguration(WireguardConfiguration):
"cumu_data": 0,
"mtu": i['mtu'],
"keepalive": i['keepalive'],
"notes": i.get('notes', ''),
"remote_endpoint": self.DashboardConfig.GetConfig("Peers", "remote_endpoint")[1],
"preshared_key": i["preshared_key"],
"advanced_security": i['advanced_security']
"preshared_key": i["preshared_key"]
}
conn.execute(
self.peersTable.insert().values(newPeer)
@@ -333,4 +302,4 @@ class AmneziaWireguardConfiguration(WireguardConfiguration):
with self.engine.connect() as conn:
restricted = conn.execute(self.peersRestrictedTable.select()).mappings().fetchall()
for i in restricted:
self.RestrictedPeers.append(AmneziaWGPeer(i, self))
self.RestrictedPeers.append(AmneziaPeer(i, self))

View File

@@ -4,65 +4,84 @@ import re
import subprocess
import uuid
from flask import current_app
from .Peer import Peer
from .Utilities import ValidateIPAddressesWithRange, ValidateDNSAddress, GenerateWireguardPublicKey
from .Utilities import CheckAddress, ValidateDNSAddress, GenerateWireguardPublicKey
class AmneziaWGPeer(Peer):
class AmneziaPeer(Peer):
def __init__(self, tableData, configuration):
self.advanced_security = tableData["advanced_security"]
super().__init__(tableData, configuration)
def updatePeer(self, name: str, private_key: str,
preshared_key: str,
dns_addresses: str, allowed_ip: str, endpoint_allowed_ip: str, mtu: int,
keepalive: int, advanced_security: str) -> tuple[bool, str] or tuple[bool, None]:
dns_addresses: str,
allowed_ip: str,
endpoint_allowed_ip: str,
mtu: int,
keepalive: int,
notes: str
) -> tuple[bool, str | None]:
if not self.configuration.getStatus():
self.configuration.toggleConfiguration()
existingAllowedIps = [item for row in list(
map(lambda x: [q.strip() for q in x.split(',')],
map(lambda y: y.allowed_ip,
list(filter(lambda k: k.id != self.id, self.configuration.getPeersList()))))) for item in row]
if allowed_ip in existingAllowedIps:
return False, "Allowed IP already taken by another peer"
if not ValidateIPAddressesWithRange(endpoint_allowed_ip):
# Before we do any compute, let us check if the given endpoint allowed ip is valid at all
if not CheckAddress(endpoint_allowed_ip):
return False, f"Endpoint Allowed IPs format is incorrect"
if len(dns_addresses) > 0 and not ValidateDNSAddress(dns_addresses):
return False, f"DNS format is incorrect"
if type(mtu) is str:
peers = []
for peer in self.configuration.getPeersList():
# Make sure to exclude your own data when updating since its not really relevant
if peer.id != self.id:
continue
peers.append(peer)
used_allowed_ips = []
for peer in peers:
ips = peer.allowed_ip.split(',')
ips = [ip.strip() for ip in ips]
used_allowed_ips.append(ips)
if allowed_ip in used_allowed_ips:
return False, "Allowed IP already taken by another peer"
if not ValidateDNSAddress(dns_addresses):
return False, f"DNS IP-Address or FQDN is incorrect"
if isinstance(mtu, str):
mtu = 0
if type(keepalive) is str:
if isinstance(keepalive, str):
keepalive = 0
if mtu < 0 or mtu > 1460:
if mtu not in range(0, 1461):
return False, "MTU format is not correct"
if keepalive < 0:
return False, "Persistent Keepalive format is not correct"
if advanced_security != "on" and advanced_security != "off":
return False, "Advanced Security can only be on or off"
if len(private_key) > 0:
pubKey = GenerateWireguardPublicKey(private_key)
if not pubKey[0] or pubKey[1] != self.id:
return False, "Private key does not match with the public key"
try:
rd = random.Random()
uid = str(uuid.UUID(int=rd.getrandbits(128), version=4))
pskExist = len(preshared_key) > 0
rand = random.Random()
uid = str(uuid.UUID(int=rand.getrandbits(128), version=4))
psk_exist = len(preshared_key) > 0
if pskExist:
if psk_exist:
with open(uid, "w+") as f:
f.write(preshared_key)
newAllowedIPs = allowed_ip.replace(" ", "")
updateAllowedIp = subprocess.check_output(
f"{self.configuration.Protocol} set {self.configuration.Name} peer {self.id} allowed-ips {newAllowedIPs} {f'preshared-key {uid}' if pskExist else 'preshared-key /dev/null'}",
f"{self.configuration.Protocol} set {self.configuration.Name} peer {self.id} allowed-ips {newAllowedIPs} {f'preshared-key {uid}' if psk_exist else 'preshared-key /dev/null'}",
shell=True, stderr=subprocess.STDOUT)
if pskExist: os.remove(uid)
if psk_exist: os.remove(uid)
if len(updateAllowedIp.decode().strip("\n")) != 0:
return False, "Update peer failed when updating Allowed IPs"
@@ -80,8 +99,8 @@ class AmneziaWGPeer(Peer):
"endpoint_allowed_ip": endpoint_allowed_ip,
"mtu": mtu,
"keepalive": keepalive,
"preshared_key": preshared_key,
"advanced_security": advanced_security
"notes": notes,
"preshared_key": preshared_key
}).where(
self.configuration.peersTable.c.id == self.id
)

View File

@@ -8,7 +8,7 @@ import pyotp
import sqlalchemy as db
import requests
from .ConnectionString import ConnectionString
from .DatabaseConnection import ConnectionString
from .DashboardClientsPeerAssignment import DashboardClientsPeerAssignment
from .DashboardClientsTOTP import DashboardClientsTOTP
from .DashboardOIDC import DashboardOIDC

View File

@@ -1,7 +1,7 @@
import datetime
import uuid
from .ConnectionString import ConnectionString
from .DatabaseConnection import ConnectionString
from .DashboardLogger import DashboardLogger
import sqlalchemy as db
from .WireguardConfiguration import WireguardConfiguration

View File

@@ -3,7 +3,7 @@ import hashlib
import uuid
import sqlalchemy as db
from .ConnectionString import ConnectionString
from .DatabaseConnection import ConnectionString
class DashboardClientsTOTP:

View File

@@ -7,14 +7,10 @@ import sqlalchemy as db
from datetime import datetime
from typing import Any
from flask import current_app
from .ConnectionString import ConnectionString
from .Utilities import (
GetRemoteEndpoint, ValidateDNSAddress
)
from .DatabaseConnection import ConnectionString
from .Utilities import (GetRemoteEndpoint, ValidateDNSAddress)
from .DashboardAPIKey import DashboardAPIKey
class DashboardConfig:
DashboardVersion = 'v4.3.2'
ConfigurationPath = os.getenv('CONFIGURATION_PATH', '.')
@@ -104,6 +100,52 @@ class DashboardConfig:
self.APIAccessed = False
self.SetConfig("Server", "version", DashboardConfig.DashboardVersion)
def EnsureDatabaseIntegrity(self, wireguardConfigurations):
expected_columns = {
'id': db.String(255),
'private_key': db.String(255),
'DNS': db.Text,
'endpoint_allowed_ip': db.Text,
'name': db.Text,
'total_receive': db.Float,
'total_sent': db.Float,
'total_data': db.Float,
'endpoint': db.String(255),
'status': db.String(255),
'latest_handshake': db.String(255),
'allowed_ip': db.String(255),
'cumu_receive': db.Float,
'cumu_sent': db.Float,
'cumu_data': db.Float,
'mtu': db.Integer,
'keepalive': db.Integer,
'notes': db.Text,
'remote_endpoint': db.String(255),
'preshared_key': db.String(255)
}
inspector = db.inspect(self.engine)
with self.engine.begin() as conn:
for cfg_name, cfg_obj in wireguardConfigurations.items():
tables_to_check = [
cfg_name,
f'{cfg_name}_restrict_access',
f'{cfg_name}_deleted'
]
for table_name in tables_to_check:
if not inspector.has_table(table_name):
continue
existing_columns = [c['name'] for c in inspector.get_columns(table_name)]
for col_name, col_type in expected_columns.items():
if col_name not in existing_columns:
type_str = col_type().compile(dialect=self.engine.dialect)
current_app.logger.info(f"Adding missing column '{col_name}' to table '{table_name}'")
conn.execute(db.text(f'ALTER TABLE "{table_name}" ADD COLUMN "{col_name}" {type_str}'))
def getConnectionString(self, database) -> str or None:
sqlitePath = os.path.join(DashboardConfig.ConfigurationPath, "db")
@@ -118,7 +160,7 @@ class DashboardConfig:
cn = f'sqlite:///{os.path.join(sqlitePath, f"{database}.db")}'
if not database_exists(cn):
create_database(cn)
return cn
return cn
def __createAPIKeyTable(self):
self.apiKeyTable = db.Table('DashboardAPIKeys', self.dbMetadata,

View File

@@ -4,7 +4,7 @@ Dashboard Logger Class
import uuid
import sqlalchemy as db
from flask import current_app
from .ConnectionString import ConnectionString
from .DatabaseConnection import ConnectionString
class DashboardLogger:

View File

@@ -8,7 +8,7 @@ from datetime import datetime, timedelta
import requests
from pydantic import BaseModel, field_serializer
import sqlalchemy as db
from .ConnectionString import ConnectionString
from .DatabaseConnection import ConnectionString
from flask import current_app
WebHookActions = ['peer_created', 'peer_deleted', 'peer_updated']

View File

@@ -6,9 +6,11 @@ from flask import current_app
def ConnectionString(database) -> str:
parser = configparser.ConfigParser(strict=False)
parser.read_file(open('wg-dashboard.ini', "r+"))
sqlitePath = os.path.join("db")
if not os.path.isdir(sqlitePath):
os.mkdir(sqlitePath)
if parser.get("Database", "type") == "postgresql":
cn = f'postgresql+psycopg://{parser.get("Database", "username")}:{parser.get("Database", "password")}@{parser.get("Database", "host")}/{database}'
elif parser.get("Database", "type") == "mysql":
@@ -18,8 +20,9 @@ def ConnectionString(database) -> str:
try:
if not database_exists(cn):
create_database(cn)
current_app.logger.info(f"Database {database} created.")
except Exception as e:
current_app.logger.error("Database error. Terminating...", e)
exit(1)
return cn

View File

@@ -2,7 +2,7 @@ import uuid
from pydantic import BaseModel, field_serializer
import sqlalchemy as db
from .ConnectionString import ConnectionString
from .DatabaseConnection import ConnectionString
class NewConfigurationTemplate(BaseModel):

View File

@@ -11,7 +11,7 @@ import jinja2
import sqlalchemy as db
from .PeerJob import PeerJob
from .PeerShareLink import PeerShareLink
from .Utilities import GenerateWireguardPublicKey, ValidateIPAddressesWithRange, ValidateDNSAddress
from .Utilities import GenerateWireguardPublicKey, CheckAddress, ValidateDNSAddress
class Peer:
@@ -34,6 +34,7 @@ class Peer:
self.cumu_data = tableData["cumu_data"]
self.mtu = tableData["mtu"]
self.keepalive = tableData["keepalive"]
self.notes = tableData.get("notes", "")
self.remote_endpoint = tableData["remote_endpoint"]
self.preshared_key = tableData["preshared_key"]
self.jobs: list[PeerJob] = []
@@ -49,56 +50,75 @@ class Peer:
def __repr__(self):
return str(self.toJson())
def updatePeer(self, name: str, private_key: str,
def updatePeer(self, name: str,
private_key: str,
preshared_key: str,
dns_addresses: str, allowed_ip: str, endpoint_allowed_ip: str, mtu: int,
keepalive: int) -> tuple[bool, str] or tuple[bool, None]:
dns_addresses: str,
allowed_ip: str,
endpoint_allowed_ip: str,
mtu: int,
keepalive: int,
notes: str
) -> tuple[bool, str | None]:
if not self.configuration.getStatus():
self.configuration.toggleConfiguration()
existingAllowedIps = [item for row in list(
map(lambda x: [q.strip() for q in x.split(',')],
map(lambda y: y.allowed_ip,
list(filter(lambda k: k.id != self.id, self.configuration.getPeersList()))))) for item in row]
# Before we do any compute, let us check if the given endpoint allowed ip is valid at all
if not CheckAddress(endpoint_allowed_ip):
return False, f"Endpoint Allowed IPs format is incorrect"
if allowed_ip in existingAllowedIps:
peers = []
for peer in self.configuration.getPeersList():
# Make sure to exclude your own data when updating since its not really relevant
if peer.id != self.id:
continue
peers.append(peer)
used_allowed_ips = []
for peer in peers:
ips = peer.allowed_ip.split(',')
ips = [ip.strip() for ip in ips]
used_allowed_ips.append(ips)
if allowed_ip in used_allowed_ips:
return False, "Allowed IP already taken by another peer"
if not ValidateIPAddressesWithRange(endpoint_allowed_ip):
return False, f"Endpoint Allowed IPs format is incorrect"
if not ValidateDNSAddress(dns_addresses):
return False, f"DNS IP-Address or FQDN is incorrect"
if len(dns_addresses) > 0 and not ValidateDNSAddress(dns_addresses):
return False, f"DNS format is incorrect"
if type(mtu) is str or mtu is None:
if isinstance(mtu, str):
mtu = 0
if mtu < 0 or mtu > 1460:
return False, "MTU format is not correct"
if type(keepalive) is str or keepalive is None:
if isinstance(keepalive, str):
keepalive = 0
if mtu not in range(0, 1461):
return False, "MTU format is not correct"
if keepalive < 0:
return False, "Persistent Keepalive format is not correct"
if len(private_key) > 0:
pubKey = GenerateWireguardPublicKey(private_key)
if not pubKey[0] or pubKey[1] != self.id:
return False, "Private key does not match with the public key"
try:
rd = random.Random()
uid = str(uuid.UUID(int=rd.getrandbits(128), version=4))
pskExist = len(preshared_key) > 0
if pskExist:
try:
rand = random.Random()
uid = str(uuid.UUID(int=rand.getrandbits(128), version=4))
psk_exist = len(preshared_key) > 0
if psk_exist:
with open(uid, "w+") as f:
f.write(preshared_key)
newAllowedIPs = allowed_ip.replace(" ", "")
updateAllowedIp = subprocess.check_output(
f"{self.configuration.Protocol} set {self.configuration.Name} peer {self.id} allowed-ips {newAllowedIPs} {f'preshared-key {uid}' if pskExist else 'preshared-key /dev/null'}",
f"{self.configuration.Protocol} set {self.configuration.Name} peer {self.id} allowed-ips {newAllowedIPs} {f'preshared-key {uid}' if psk_exist else 'preshared-key /dev/null'}",
shell=True, stderr=subprocess.STDOUT)
if pskExist: os.remove(uid)
if psk_exist: os.remove(uid)
if len(updateAllowedIp.decode().strip("\n")) != 0:
return False, "Update peer failed when updating Allowed IPs"
saveConfig = subprocess.check_output(f"{self.configuration.Protocol}-quick save {self.configuration.Name}",
@@ -114,6 +134,7 @@ class Peer:
"endpoint_allowed_ip": endpoint_allowed_ip,
"mtu": mtu,
"keepalive": keepalive,
"notes": notes,
"preshared_key": preshared_key
}).where(
self.configuration.peersTable.c.id == self.id

View File

@@ -8,7 +8,7 @@ import sqlalchemy as db
from flask import current_app
from sqlalchemy import RowMapping
from .ConnectionString import ConnectionString
from .DatabaseConnection import ConnectionString
from .Log import Log
class PeerJobLogger:

View File

@@ -3,7 +3,7 @@ Peer Jobs
"""
import sqlalchemy
from .ConnectionString import ConnectionString
from .DatabaseConnection import ConnectionString
from .PeerJob import PeerJob
from .PeerJobLogger import PeerJobLogger
import sqlalchemy as db

View File

@@ -1,4 +1,4 @@
from .ConnectionString import ConnectionString
from .DatabaseConnection import ConnectionString
from .PeerShareLink import PeerShareLink
import sqlalchemy as db
from datetime import datetime

View File

@@ -1,6 +1,6 @@
import re, ipaddress
import subprocess
import sqlalchemy
def RegexMatch(regex, text) -> bool:
"""
@@ -41,31 +41,32 @@ def StringToBoolean(value: str):
return (value.strip().replace(" ", "").lower() in
("yes", "true", "t", "1", 1))
def ValidateIPAddressesWithRange(ips: str) -> bool:
s = ips.replace(" ", "").split(",")
for ip in s:
def CheckAddress(ips_str: str) -> bool:
if len(ips_str) == 0:
return False
for ip in ips_str.split(','):
stripped_ip = ip.strip()
try:
ipaddress.ip_network(ip)
except ValueError as e:
# Verify the IP-address, with the strict flag as false also allows for /32 and /128
ipaddress.ip_network(stripped_ip, strict=False)
except ValueError:
return False
return True
def ValidateIPAddresses(ips) -> bool:
s = ips.replace(" ", "").split(",")
for ip in s:
try:
ipaddress.ip_address(ip)
except ValueError as e:
return False
return True
def ValidateDNSAddress(addresses_str: str) -> tuple[bool, str | None]:
if len(addresses_str) == 0:
return False, "Got an empty list/string to check for valid DNS-addresses"
addresses = addresses_str.split(',')
for address in addresses:
stripped_address = address.strip()
if not CheckAddress(stripped_address) and not RegexMatch(r"(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z][a-z]{0,61}[a-z]", stripped_address):
return False, f"{stripped_address} does not appear to be a valid IP-address or FQDN"
return True, None
def ValidateDNSAddress(addresses) -> tuple[bool, str]:
s = addresses.replace(" ", "").split(",")
for address in s:
if not ValidateIPAddresses(address) and not RegexMatch(
r"(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z][a-z]{0,61}[a-z]", address):
return False, f"{address} does not appear to be an valid DNS address"
return True, ""
def ValidateEndpointAllowedIPs(IPs) -> tuple[bool, str] | tuple[bool, None]:
ips = IPs.replace(" ", "").split(",")

View File

@@ -10,12 +10,15 @@ from datetime import datetime, timedelta
from itertools import islice
from flask import current_app
from .ConnectionString import ConnectionString
from .DatabaseConnection import ConnectionString
from .DashboardConfig import DashboardConfig
from .Peer import Peer
from .PeerJobs import PeerJobs
from .PeerShareLinks import PeerShareLinks
from .Utilities import StringToBoolean, GenerateWireguardPublicKey, RegexMatch, ValidateDNSAddress, \
from .Utilities import StringToBoolean, \
GenerateWireguardPublicKey, \
RegexMatch, \
ValidateDNSAddress, \
ValidateEndpointAllowedIPs
from .WireguardConfigurationInfo import WireguardConfigurationInfo, PeerGroupsClass
from .DashboardWebHooks import DashboardWebHooks
@@ -61,13 +64,14 @@ class WireguardConfiguration:
self.Protocol = "wg" if wg else "awg"
self.AllPeerJobs = AllPeerJobs
self.DashboardConfig = DashboardConfig
self.DashboardConfig.EnsureDatabaseIntegrity({self.Name: self})
self.AllPeerShareLinks = AllPeerShareLinks
self.DashboardWebHooks = DashboardWebHooks
self.configPath = os.path.join(self.__getProtocolPath(), f'{self.Name}.conf')
self.engine: sqlalchemy.Engine = sqlalchemy.create_engine(ConnectionString("wgdashboard"))
self.metadata: sqlalchemy.MetaData = sqlalchemy.MetaData()
self.dbType = self.DashboardConfig.GetConfig("Database", "type")[1]
if name is not None:
if data is not None and "Backup" in data.keys():
db = self.__importDatabase(
@@ -150,7 +154,6 @@ class WireguardConfiguration:
if self.Status:
self.addAutostart()
def __getProtocolPath(self) -> str:
_, path = self.DashboardConfig.GetConfig("Server", "wg_conf_path") if self.Protocol == "wg" \
@@ -242,54 +245,50 @@ class WireguardConfiguration:
return True
def createDatabase(self, dbName = None):
def generate_column_obj():
return [
sqlalchemy.Column('id', sqlalchemy.String(255), nullable=False, primary_key=True),
sqlalchemy.Column('private_key', sqlalchemy.String(255)),
sqlalchemy.Column('DNS', sqlalchemy.Text),
sqlalchemy.Column('endpoint_allowed_ip', sqlalchemy.Text),
sqlalchemy.Column('name', sqlalchemy.Text),
sqlalchemy.Column('total_receive', sqlalchemy.Float),
sqlalchemy.Column('total_sent', sqlalchemy.Float),
sqlalchemy.Column('total_data', sqlalchemy.Float),
sqlalchemy.Column('endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('status', sqlalchemy.String(255)),
sqlalchemy.Column('latest_handshake', sqlalchemy.String(255)),
sqlalchemy.Column('allowed_ip', sqlalchemy.String(255)),
sqlalchemy.Column('cumu_receive', sqlalchemy.Float),
sqlalchemy.Column('cumu_sent', sqlalchemy.Float),
sqlalchemy.Column('cumu_data', sqlalchemy.Float),
sqlalchemy.Column('mtu', sqlalchemy.Integer),
sqlalchemy.Column('keepalive', sqlalchemy.Integer),
sqlalchemy.Column('notes', sqlalchemy.Text),
sqlalchemy.Column('remote_endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('preshared_key', sqlalchemy.String(255))
]
if dbName is None:
dbName = self.Name
self.peersTable = sqlalchemy.Table(
dbName, self.metadata,
sqlalchemy.Column('id', sqlalchemy.String(255), nullable=False, primary_key=True),
sqlalchemy.Column('private_key', sqlalchemy.String(255)),
sqlalchemy.Column('DNS', sqlalchemy.Text),
sqlalchemy.Column('endpoint_allowed_ip', sqlalchemy.Text),
sqlalchemy.Column('name', sqlalchemy.Text),
sqlalchemy.Column('total_receive', sqlalchemy.Float),
sqlalchemy.Column('total_sent', sqlalchemy.Float),
sqlalchemy.Column('total_data', sqlalchemy.Float),
sqlalchemy.Column('endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('status', sqlalchemy.String(255)),
sqlalchemy.Column('latest_handshake', sqlalchemy.String(255)),
sqlalchemy.Column('allowed_ip', sqlalchemy.String(255)),
sqlalchemy.Column('cumu_receive', sqlalchemy.Float),
sqlalchemy.Column('cumu_sent', sqlalchemy.Float),
sqlalchemy.Column('cumu_data', sqlalchemy.Float),
sqlalchemy.Column('mtu', sqlalchemy.Integer),
sqlalchemy.Column('keepalive', sqlalchemy.Integer),
sqlalchemy.Column('remote_endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('preshared_key', sqlalchemy.String(255)),
extend_existing=True
f'{dbName}', self.metadata, *generate_column_obj(), extend_existing=True
)
self.peersRestrictedTable = sqlalchemy.Table(
f'{dbName}_restrict_access', self.metadata,
sqlalchemy.Column('id', sqlalchemy.String(255), nullable=False, primary_key=True),
sqlalchemy.Column('private_key', sqlalchemy.String(255)),
sqlalchemy.Column('DNS', sqlalchemy.Text),
sqlalchemy.Column('endpoint_allowed_ip', sqlalchemy.Text),
sqlalchemy.Column('name', sqlalchemy.Text),
sqlalchemy.Column('total_receive', sqlalchemy.Float),
sqlalchemy.Column('total_sent', sqlalchemy.Float),
sqlalchemy.Column('total_data', sqlalchemy.Float),
sqlalchemy.Column('endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('status', sqlalchemy.String(255)),
sqlalchemy.Column('latest_handshake', sqlalchemy.String(255)),
sqlalchemy.Column('allowed_ip', sqlalchemy.String(255)),
sqlalchemy.Column('cumu_receive', sqlalchemy.Float),
sqlalchemy.Column('cumu_sent', sqlalchemy.Float),
sqlalchemy.Column('cumu_data', sqlalchemy.Float),
sqlalchemy.Column('mtu', sqlalchemy.Integer),
sqlalchemy.Column('keepalive', sqlalchemy.Integer),
sqlalchemy.Column('remote_endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('preshared_key', sqlalchemy.String(255)),
extend_existing=True
f'{dbName}_restrict_access', self.metadata, *generate_column_obj(), extend_existing=True
)
self.peersDeletedTable = sqlalchemy.Table(
f'{dbName}_deleted', self.metadata, *generate_column_obj(), extend_existing=True
)
if self.DashboardConfig.GetConfig("Database", "type")[1] == 'sqlite':
time_col_type = sqlalchemy.DATETIME
else:
time_col_type = sqlalchemy.TIMESTAMP
self.peersTransferTable = sqlalchemy.Table(
f'{dbName}_transfer', self.metadata,
sqlalchemy.Column('id', sqlalchemy.String(255), nullable=False),
@@ -299,8 +298,7 @@ class WireguardConfiguration:
sqlalchemy.Column('cumu_receive', sqlalchemy.Float),
sqlalchemy.Column('cumu_sent', sqlalchemy.Float),
sqlalchemy.Column('cumu_data', sqlalchemy.Float),
sqlalchemy.Column('time', (sqlalchemy.DATETIME if self.DashboardConfig.GetConfig("Database", "type")[1] == 'sqlite' else sqlalchemy.TIMESTAMP),
server_default=sqlalchemy.func.now()),
sqlalchemy.Column('time', time_col_type, server_default=sqlalchemy.func.now()),
extend_existing=True
)
@@ -308,34 +306,9 @@ class WireguardConfiguration:
f'{dbName}_history_endpoint', self.metadata,
sqlalchemy.Column('id', sqlalchemy.String(255), nullable=False),
sqlalchemy.Column('endpoint', sqlalchemy.String(255), nullable=False),
sqlalchemy.Column('time',
(sqlalchemy.DATETIME if self.DashboardConfig.GetConfig("Database", "type")[1] == 'sqlite' else sqlalchemy.TIMESTAMP)),
extend_existing=True
sqlalchemy.Column('time', time_col_type)
)
self.peersDeletedTable = sqlalchemy.Table(
f'{dbName}_deleted', self.metadata,
sqlalchemy.Column('id', sqlalchemy.String(255), nullable=False, primary_key=True),
sqlalchemy.Column('private_key', sqlalchemy.String(255)),
sqlalchemy.Column('DNS', sqlalchemy.Text),
sqlalchemy.Column('endpoint_allowed_ip', sqlalchemy.Text),
sqlalchemy.Column('name', sqlalchemy.Text),
sqlalchemy.Column('total_receive', sqlalchemy.Float),
sqlalchemy.Column('total_sent', sqlalchemy.Float),
sqlalchemy.Column('total_data', sqlalchemy.Float),
sqlalchemy.Column('endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('status', sqlalchemy.String(255)),
sqlalchemy.Column('latest_handshake', sqlalchemy.String(255)),
sqlalchemy.Column('allowed_ip', sqlalchemy.String(255)),
sqlalchemy.Column('cumu_receive', sqlalchemy.Float),
sqlalchemy.Column('cumu_sent', sqlalchemy.Float),
sqlalchemy.Column('cumu_data', sqlalchemy.Float),
sqlalchemy.Column('mtu', sqlalchemy.Integer),
sqlalchemy.Column('keepalive', sqlalchemy.Integer),
sqlalchemy.Column('remote_endpoint', sqlalchemy.String(255)),
sqlalchemy.Column('preshared_key', sqlalchemy.String(255)),
extend_existing=True
)
self.infoTable = sqlalchemy.Table(
'ConfigurationsInfo', self.metadata,
sqlalchemy.Column('ID', sqlalchemy.String(255), primary_key=True),
@@ -450,8 +423,7 @@ class WireguardConfiguration:
"id": i['PublicKey'],
"private_key": "",
"DNS": self.DashboardConfig.GetConfig("Peers", "peer_global_DNS")[1],
"endpoint_allowed_ip": self.DashboardConfig.GetConfig("Peers", "peer_endpoint_allowed_ip")[
1],
"endpoint_allowed_ip": self.DashboardConfig.GetConfig("Peers", "peer_endpoint_allowed_ip")[1],
"name": i.get("name"),
"total_receive": 0,
"total_sent": 0,
@@ -465,6 +437,7 @@ class WireguardConfiguration:
"cumu_data": 0,
"mtu": self.DashboardConfig.GetConfig("Peers", "peer_mtu")[1] if len(self.DashboardConfig.GetConfig("Peers", "peer_mtu")[1]) > 0 else None,
"keepalive": self.DashboardConfig.GetConfig("Peers", "peer_keep_alive")[1] if len(self.DashboardConfig.GetConfig("Peers", "peer_keep_alive")[1]) > 0 else None,
"notes": "",
"remote_endpoint": self.DashboardConfig.GetConfig("Peers", "remote_endpoint")[1],
"preshared_key": i["PresharedKey"] if "PresharedKey" in i.keys() else ""
}
@@ -557,6 +530,7 @@ class WireguardConfiguration:
"cumu_data": 0,
"mtu": i['mtu'],
"keepalive": i['keepalive'],
"notes": i.get("notes", ""),
"remote_endpoint": self.DashboardConfig.GetConfig("Peers", "remote_endpoint")[1],
"preshared_key": i["preshared_key"]
}