Refactored some of the codes

This commit is contained in:
Donald Zou 2024-11-25 22:11:51 +08:00
parent 6a4d16fae9
commit 578a1db62f
2 changed files with 220 additions and 209 deletions

76
src/Utilities.py Normal file
View File

@ -0,0 +1,76 @@
import re, ipaddress
import subprocess
def RegexMatch(regex, text) -> bool:
"""
Regex Match
@param regex: Regex patter
@param text: Text to match
@return: Boolean indicate if the text match the regex pattern
"""
pattern = re.compile(regex)
return pattern.search(text) is not None
def GetRemoteEndpoint() -> str:
"""
Using socket to determine default interface IP address. Thanks, @NOXICS
@return:
"""
import socket
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(("1.1.1.1", 80)) # Connecting to a public IP
wgd_remote_endpoint = s.getsockname()[0]
return str(wgd_remote_endpoint)
def StringToBoolean(value: str):
"""
Convert string boolean to boolean
@param value: Boolean value in string came from Configuration file
@return: Boolean value
"""
return (value.strip().replace(" ", "").lower() in
("yes", "true", "t", "1", 1))
def ValidateIPAddressesWithRange(ips: str) -> bool:
s = ips.replace(" ", "").split(",")
for ip in s:
try:
ipaddress.ip_network(ip)
except ValueError as e:
return False
return True
def ValidateIPAddresses(ips) -> bool:
s = ips.replace(" ", "").split(",")
for ip in s:
try:
ipaddress.ip_address(ip)
except ValueError as e:
return False
return True
def ValidateDNSAddress(addresses) -> tuple[bool, str]:
s = addresses.replace(" ", "").split(",")
for address in s:
if not ValidateIPAddresses(address) and not RegexMatch(
r"(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z][a-z]{0,61}[a-z]", address):
return False, f"{address} does not appear to be an valid DNS address"
return True, ""
def GenerateWireguardPublicKey(privateKey: str) -> tuple[bool, str] | tuple[bool, None]:
try:
publicKey = subprocess.check_output(f"wg pubkey", input=privateKey.encode(), shell=True,
stderr=subprocess.STDOUT)
return True, publicKey.decode().strip('\n')
except subprocess.CalledProcessError:
return False, None
def GenerateWireguardPrivateKey() -> tuple[bool, str] | tuple[bool, None]:
try:
publicKey = subprocess.check_output(f"wg genkey", shell=True,
stderr=subprocess.STDOUT)
return True, publicKey.decode().strip('\n')
except subprocess.CalledProcessError:
return False, None

View File

@ -1,33 +1,20 @@
import itertools, random import itertools, random, shutil, sqlite3, configparser, hashlib, ipaddress, json, traceback, os, secrets, subprocess
import shutil import time, re, urllib.error, uuid, bcrypt, psutil, pyotp, threading
import sqlite3
import configparser
import hashlib
import ipaddress
import json
import traceback
import os
import secrets
import subprocess
import time
import re
import urllib.error
import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any from typing import Any
import bcrypt
# import ifcfg
import psutil
import pyotp
from flask import Flask, request, render_template, session, g from flask import Flask, request, render_template, session, g
from json import JSONEncoder from json import JSONEncoder
from flask_cors import CORS from flask_cors import CORS
from icmplib import ping, traceroute from icmplib import ping, traceroute
import threading
from flask.json.provider import DefaultJSONProvider from flask.json.provider import DefaultJSONProvider
from Utilities import (
RegexMatch, GetRemoteEndpoint, StringToBoolean,
ValidateIPAddressesWithRange, ValidateIPAddresses, ValidateDNSAddress,
GenerateWireguardPublicKey, GenerateWireguardPrivateKey
)
DASHBOARD_VERSION = 'v4.1.1'
DASHBOARD_VERSION = 'v4.2.0'
CONFIGURATION_PATH = os.getenv('CONFIGURATION_PATH', '.') CONFIGURATION_PATH = os.getenv('CONFIGURATION_PATH', '.')
DB_PATH = os.path.join(CONFIGURATION_PATH, 'db') DB_PATH = os.path.join(CONFIGURATION_PATH, 'db')
if not os.path.isdir(DB_PATH): if not os.path.isdir(DB_PATH):
@ -46,10 +33,24 @@ class ModelEncoder(JSONEncoder):
else: else:
return super(ModelEncoder, self).default(o) return super(ModelEncoder, self).default(o)
''' class CustomJsonEncoder(DefaultJSONProvider):
Classes def __init__(self, app):
''' super().__init__(app)
def default(self, o):
if (isinstance(o, WireguardConfiguration)
or isinstance(o, Peer)
or isinstance(o, PeerJob)
or isinstance(o, Log)
or isinstance(o, DashboardAPIKey)
or isinstance(o, PeerShareLink)):
return o.toJson()
return super().default(self, o)
app.json = CustomJsonEncoder(app)
'''
Response Object
'''
def ResponseObject(status=True, message=None, data=None) -> Flask.response_class: def ResponseObject(status=True, message=None, data=None) -> Flask.response_class:
response = Flask.make_response(app, { response = Flask.make_response(app, {
"status": status, "status": status,
@ -59,24 +60,9 @@ def ResponseObject(status=True, message=None, data=None) -> Flask.response_class
response.content_type = "application/json" response.content_type = "application/json"
return response return response
"""
class CustomJsonEncoder(DefaultJSONProvider): Log Class
def __init__(self, app): """
super().__init__(app)
def default(self, o):
if (isinstance(o, WireguardConfiguration)
or isinstance(o, Peer)
or isinstance(o, PeerJob)
or isinstance(o, Log)
or isinstance(o, DashboardAPIKey)
or isinstance(o, PeerShareLink)):
return o.toJson()
return super().default(self, o)
app.json = CustomJsonEncoder(app)
class Log: class Log:
def __init__(self, LogID: str, JobID: str, LogDate: str, Status: str, Message: str): def __init__(self, LogID: str, JobID: str, LogDate: str, Status: str, Message: str):
self.LogID = LogID self.LogID = LogID
@ -96,7 +82,10 @@ class Log:
def __dict__(self): def __dict__(self):
return self.toJson() return self.toJson()
"""
Dashboard Logger Class
"""
class DashboardLogger: class DashboardLogger:
def __init__(self): def __init__(self):
self.loggerdb = sqlite3.connect(os.path.join(CONFIGURATION_PATH, 'db', 'wgdashboard_log.db'), self.loggerdb = sqlite3.connect(os.path.join(CONFIGURATION_PATH, 'db', 'wgdashboard_log.db'),
@ -128,7 +117,10 @@ class DashboardLogger:
except Exception as e: except Exception as e:
print(f"[WGDashboard] Access Log Error: {str(e)}") print(f"[WGDashboard] Access Log Error: {str(e)}")
return False return False
"""
Peer Job Logger
"""
class PeerJobLogger: class PeerJobLogger:
def __init__(self): def __init__(self):
self.loggerdb = sqlite3.connect(os.path.join(CONFIGURATION_PATH, 'db', 'wgdashboard_log.db'), self.loggerdb = sqlite3.connect(os.path.join(CONFIGURATION_PATH, 'db', 'wgdashboard_log.db'),
@ -176,7 +168,10 @@ class PeerJobLogger:
except Exception as e: except Exception as e:
return logs return logs
return logs return logs
"""
Peer Job
"""
class PeerJob: class PeerJob:
def __init__(self, JobID: str, Configuration: str, Peer: str, def __init__(self, JobID: str, Configuration: str, Peer: str,
Field: str, Operator: str, Value: str, CreationDate: datetime, ExpireDate: datetime, Action: str): Field: str, Operator: str, Value: str, CreationDate: datetime, ExpireDate: datetime, Action: str):
@ -206,6 +201,9 @@ class PeerJob:
def __dict__(self): def __dict__(self):
return self.toJson() return self.toJson()
"""
Peer Jobs
"""
class PeerJobs: class PeerJobs:
def __init__(self): def __init__(self):
@ -367,7 +365,10 @@ class PeerJobs:
return x > y return x > y
if operator == "lst": if operator == "lst":
return x < y return x < y
"""
Peer Share Link
"""
class PeerShareLink: class PeerShareLink:
def __init__(self, ShareID:str, Configuration: str, Peer: str, ExpireDate: datetime, ShareDate: datetime): def __init__(self, ShareID:str, Configuration: str, Peer: str, ExpireDate: datetime, ShareDate: datetime):
self.ShareID = ShareID self.ShareID = ShareID
@ -385,6 +386,9 @@ class PeerShareLink:
"ExpireDate": self.ExpireDate "ExpireDate": self.ExpireDate
} }
"""
Peer Share Links
"""
class PeerShareLinks: class PeerShareLinks:
def __init__(self): def __init__(self):
self.Links: list[PeerShareLink] = [] self.Links: list[PeerShareLink] = []
@ -429,7 +433,10 @@ class PeerShareLinks:
sqlUpdate("UPDATE PeerShareLinks SET ExpireDate = ? WHERE ShareID = ?;", (ExpireDate, ShareID, )) sqlUpdate("UPDATE PeerShareLinks SET ExpireDate = ? WHERE ShareID = ?;", (ExpireDate, ShareID, ))
self.__getSharedLinks() self.__getSharedLinks()
return True, "" return True, ""
"""
WireGuard Configuration
"""
class WireguardConfiguration: class WireguardConfiguration:
class InvalidConfigurationFileException(Exception): class InvalidConfigurationFileException(Exception):
def __init__(self, m): def __init__(self, m):
@ -482,7 +489,7 @@ class WireguardConfiguration:
for i in dir(self): for i in dir(self):
if str(i) in data.keys(): if str(i) in data.keys():
if isinstance(getattr(self, i), bool): if isinstance(getattr(self, i), bool):
setattr(self, i, _strToBool(data[i])) setattr(self, i, StringToBoolean(data[i]))
else: else:
setattr(self, i, str(data[i])) setattr(self, i, str(data[i]))
@ -507,8 +514,7 @@ class WireguardConfiguration:
if self.getAutostartStatus() and not self.getStatus() and startup: if self.getAutostartStatus() and not self.getStatus() and startup:
self.toggleConfiguration() self.toggleConfiguration()
print(f"[WGDashboard] Autostart Configuration: {name}") print(f"[WGDashboard] Autostart Configuration: {name}")
def __initPeersList(self): def __initPeersList(self):
self.Peers: list[Peer] = [] self.Peers: list[Peer] = []
self.getPeersList() self.getPeersList()
@ -524,7 +530,7 @@ class WireguardConfiguration:
for i in dir(self): for i in dir(self):
if str(i) in interfaceConfig.keys(): if str(i) in interfaceConfig.keys():
if isinstance(getattr(self, i), bool): if isinstance(getattr(self, i), bool):
setattr(self, i, _strToBool(interfaceConfig[i])) setattr(self, i, StringToBoolean(interfaceConfig[i]))
else: else:
setattr(self, i, interfaceConfig[i]) setattr(self, i, interfaceConfig[i])
if self.PrivateKey: if self.PrivateKey:
@ -620,7 +626,7 @@ class WireguardConfiguration:
return True return True
def __getPublicKey(self) -> str: def __getPublicKey(self) -> str:
return _generatePublicKey(self.PrivateKey)[1] return GenerateWireguardPublicKey(self.PrivateKey)[1]
def getStatus(self) -> bool: def getStatus(self) -> bool:
self.Status = self.Name in psutil.net_if_addrs().keys() self.Status = self.Name in psutil.net_if_addrs().keys()
@ -653,7 +659,7 @@ class WireguardConfiguration:
peerStarts = content.index("[Peer]") peerStarts = content.index("[Peer]")
content = content[peerStarts:] content = content[peerStarts:]
for i in content: for i in content:
if not regex_match("#(.*)", i) and not regex_match(";(.*)", i): if not RegexMatch("#(.*)", i) and not RegexMatch(";(.*)", i):
if i == "[Peer]": if i == "[Peer]":
pCounter += 1 pCounter += 1
p.append({}) p.append({})
@ -664,7 +670,7 @@ class WireguardConfiguration:
if len(split) == 2: if len(split) == 2:
p[pCounter][split[0]] = split[1] p[pCounter][split[0]] = split[1]
if regex_match("#Name# = (.*)", i): if RegexMatch("#Name# = (.*)", i):
split = re.split(r'\s*=\s*', i, 1) split = re.split(r'\s*=\s*', i, 1)
if len(split) == 2: if len(split) == 2:
p[pCounter]["name"] = split[1] p[pCounter]["name"] = split[1]
@ -1037,7 +1043,7 @@ class WireguardConfiguration:
files.sort(key=lambda x: x[1], reverse=True) files.sort(key=lambda x: x[1], reverse=True)
for f, ct in files: for f, ct in files:
if _regexMatch(f"^({self.Name})_(.*)\\.(conf)$", f): if RegexMatch(f"^({self.Name})_(.*)\\.(conf)$", f):
s = re.search(f"^({self.Name})_(.*)\\.(conf)$", f) s = re.search(f"^({self.Name})_(.*)\\.(conf)$", f)
date = s.group(2) date = s.group(2)
d = { d = {
@ -1107,7 +1113,7 @@ class WireguardConfiguration:
split[1] = newData[key] split[1] = newData[key]
original[line] = " = ".join(split) original[line] = " = ".join(split)
if isinstance(getattr(self, key), bool): if isinstance(getattr(self, key), bool):
setattr(self, key, _strToBool(newData[key])) setattr(self, key, StringToBoolean(newData[key]))
else: else:
setattr(self, key, str(newData[key])) setattr(self, key, str(newData[key]))
dataChanged = True dataChanged = True
@ -1150,7 +1156,46 @@ class WireguardConfiguration:
except Exception as e: except Exception as e:
return False, str(e) return False, str(e)
return True, None return True, None
def getAvailableIP(self, all: bool = False) -> tuple[bool, list[str]] | tuple[bool, None]:
if len(self.Address) < 0:
return False, None
address = self.Address.split(',')
existedAddress = []
availableAddress = []
for p in self.Peers:
if len(p.allowed_ip) > 0:
add = p.allowed_ip.split(',')
for i in add:
a, c = i.split('/')
try:
existedAddress.append(ipaddress.ip_address(a.replace(" ", "")))
except ValueError as error:
print(f"[WGDashboard] Error: {configName} peer {p.id} have invalid ip")
for p in self.getRestrictedPeersList():
if len(p.allowed_ip) > 0:
add = p.allowed_ip.split(',')
for i in add:
a, c = i.split('/')
existedAddress.append(ipaddress.ip_address(a.replace(" ", "")))
for i in address:
addressSplit, cidr = i.split('/')
existedAddress.append(ipaddress.ip_address(addressSplit.replace(" ", "")))
for i in address:
network = ipaddress.ip_network(i.replace(" ", ""), False)
count = 0
for h in network.hosts():
if h not in existedAddress:
availableAddress.append(ipaddress.ip_network(h).compressed)
count += 1
if not all:
if network.version == 6 and count > 255:
break
return True, availableAddress
"""
Peer
"""
class Peer: class Peer:
def __init__(self, tableData, configuration: WireguardConfiguration): def __init__(self, tableData, configuration: WireguardConfiguration):
self.configuration = configuration self.configuration = configuration
@ -1200,16 +1245,16 @@ class Peer:
if allowed_ip in existingAllowedIps: if allowed_ip in existingAllowedIps:
return ResponseObject(False, "Allowed IP already taken by another peer") return ResponseObject(False, "Allowed IP already taken by another peer")
if not _checkIPWithRange(endpoint_allowed_ip): if not ValidateIPAddressesWithRange(endpoint_allowed_ip):
return ResponseObject(False, f"Endpoint Allowed IPs format is incorrect") return ResponseObject(False, f"Endpoint Allowed IPs format is incorrect")
if len(dns_addresses) > 0 and not _checkDNS(dns_addresses): if len(dns_addresses) > 0 and not ValidateDNSAddress(dns_addresses):
return ResponseObject(False, f"DNS format is incorrect") return ResponseObject(False, f"DNS format is incorrect")
if mtu < 0 or mtu > 1460: if mtu < 0 or mtu > 1460:
return ResponseObject(False, "MTU format is not correct") return ResponseObject(False, "MTU format is not correct")
if keepalive < 0: if keepalive < 0:
return ResponseObject(False, "Persistent Keepalive format is not correct") return ResponseObject(False, "Persistent Keepalive format is not correct")
if len(private_key) > 0: if len(private_key) > 0:
pubKey = _generatePublicKey(private_key) pubKey = GenerateWireguardPublicKey(private_key)
if not pubKey[0] or pubKey[1] != self.id: if not pubKey[0] or pubKey[1] != self.id:
return ResponseObject(False, "Private key does not match with the public key") return ResponseObject(False, "Private key does not match with the public key")
try: try:
@ -1297,20 +1342,10 @@ PersistentKeepalive = {str(self.keepalive)}
except Exception as e: except Exception as e:
return False return False
return True return True
# Regex Match
def regex_match(regex, text):
pattern = re.compile(regex)
return pattern.search(text) is not None
def get_remote_endpoint():
# Thanks @NOXICS
import socket
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(("1.1.1.1", 80)) # Connecting to a public IP
wgd_remote_endpoint = s.getsockname()[0]
return str(wgd_remote_endpoint)
"""
Dashboard API Key
"""
class DashboardAPIKey: class DashboardAPIKey:
def __init__(self, Key: str, CreatedAt: str, ExpiredAt: str): def __init__(self, Key: str, CreatedAt: str, ExpiredAt: str):
self.Key = Key self.Key = Key
@ -1320,6 +1355,9 @@ class DashboardAPIKey:
def toJson(self): def toJson(self):
return self.__dict__ return self.__dict__
"""
Dashboard Configuration
"""
class DashboardConfig: class DashboardConfig:
def __init__(self): def __init__(self):
@ -1353,7 +1391,7 @@ class DashboardConfig:
"peer_global_DNS": "1.1.1.1", "peer_global_DNS": "1.1.1.1",
"peer_endpoint_allowed_ip": "0.0.0.0/0", "peer_endpoint_allowed_ip": "0.0.0.0/0",
"peer_display_mode": "grid", "peer_display_mode": "grid",
"remote_endpoint": get_remote_endpoint(), "remote_endpoint": GetRemoteEndpoint(),
"peer_MTU": "1420", "peer_MTU": "1420",
"peer_keep_alive": "21" "peer_keep_alive": "21"
}, },
@ -1406,7 +1444,7 @@ class DashboardConfig:
if type(value) is str and len(value) == 0: if type(value) is str and len(value) == 0:
return False, "Field cannot be empty!" return False, "Field cannot be empty!"
if key == "peer_global_dns": if key == "peer_global_dns":
return _checkDNS(value) return ValidateDNSAddress(value)
if key == "peer_endpoint_allowed_ip": if key == "peer_endpoint_allowed_ip":
value = value.split(",") value = value.split(",")
for i in value: for i in value:
@ -1505,128 +1543,9 @@ class DashboardConfig:
the_dict[section][key] = self.GetConfig(section, key)[1] the_dict[section][key] = self.GetConfig(section, key)[1]
return the_dict return the_dict
''' """
Private Functions Database Connection Functions
''' """
def _strToBool(value: str) -> bool:
return value.lower() in ("yes", "true", "t", "1", 1)
def _regexMatch(regex, text):
pattern = re.compile(regex)
return pattern.search(text) is not None
def _getConfigurationList(startup: bool = False):
confs = os.listdir(DashboardConfig.GetConfig("Server", "wg_conf_path")[1])
confs.sort()
for i in confs:
if _regexMatch("^(.{1,}).(conf)$", i):
i = i.replace('.conf', '')
try:
if i in WireguardConfigurations.keys():
if WireguardConfigurations[i].configurationFileChanged():
WireguardConfigurations[i] = WireguardConfiguration(i)
else:
WireguardConfigurations[i] = WireguardConfiguration(i, startup=startup)
except WireguardConfiguration.InvalidConfigurationFileException as e:
print(f"{i} have an invalid configuration file.")
def _checkIPWithRange(ip):
ip_patterns = (
r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(\.|\/)){4}([0-9]{1,2})(,|$)",
r"[0-9a-fA-F]{0,4}(:([0-9a-fA-F]{0,4})){1,7}\/([0-9]{1,3})(,|$)"
)
for match_pattern in ip_patterns:
match_result = regex_match(match_pattern, ip)
if match_result:
result = match_result
break
else:
result = None
return result
def _checkIP(ip):
ip_patterns = (
r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(\.|$)){4}",
r"[0-9a-fA-F]{0,4}(:([0-9a-fA-F]{0,4})){1,7}$"
)
for match_pattern in ip_patterns:
match_result = regex_match(match_pattern, ip)
if match_result:
result = match_result
break
else:
result = None
return result
def _checkDNS(dns):
dns = dns.replace(' ', '').split(',')
for i in dns:
if not _checkIP(i) and not regex_match(r"(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z][a-z]{0,61}[a-z]", i):
return False, f"{i} does not appear to be an valid DNS address"
return True, ""
def _generatePublicKey(privateKey) -> tuple[bool, str] | tuple[bool, None]:
try:
publicKey = subprocess.check_output(f"wg pubkey", input=privateKey.encode(), shell=True,
stderr=subprocess.STDOUT)
return True, publicKey.decode().strip('\n')
except subprocess.CalledProcessError:
return False, None
def _generatePrivateKey() -> [bool, str]:
try:
publicKey = subprocess.check_output(f"wg genkey", shell=True,
stderr=subprocess.STDOUT)
return True, publicKey.decode().strip('\n')
except subprocess.CalledProcessError:
return False, None
def _getWireguardConfigurationAvailableIP(configName: str, all: bool = False) -> tuple[bool, list[str]] | tuple[bool, None]:
if configName not in WireguardConfigurations.keys():
return False, None
configuration = WireguardConfigurations[configName]
if len(configuration.Address) > 0:
address = configuration.Address.split(',')
existedAddress = []
availableAddress = []
for p in configuration.Peers:
if len(p.allowed_ip) > 0:
add = p.allowed_ip.split(',')
for i in add:
a, c = i.split('/')
try:
existedAddress.append(ipaddress.ip_address(a.replace(" ", "")))
except ValueError as error:
print(f"[WGDashboard] Error: {configName} peer {p.id} have invalid ip")
for p in configuration.getRestrictedPeersList():
if len(p.allowed_ip) > 0:
add = p.allowed_ip.split(',')
for i in add:
a, c = i.split('/')
existedAddress.append(ipaddress.ip_address(a.replace(" ", "")))
for i in address:
addressSplit, cidr = i.split('/')
existedAddress.append(ipaddress.ip_address(addressSplit.replace(" ", "")))
for i in address:
network = ipaddress.ip_network(i.replace(" ", ""), False)
count = 0
for h in network.hosts():
if h not in existedAddress:
availableAddress.append(ipaddress.ip_network(h).compressed)
count += 1
if not all:
if network.version == 6 and count > 255:
break
return True, availableAddress
return False, None
sqldb = sqlite3.connect(os.path.join(CONFIGURATION_PATH, 'db', 'wgdashboard.db'), check_same_thread=False) sqldb = sqlite3.connect(os.path.join(CONFIGURATION_PATH, 'db', 'wgdashboard.db'), check_same_thread=False)
sqldb.row_factory = sqlite3.Row sqldb.row_factory = sqlite3.Row
@ -1642,7 +1561,6 @@ def sqlSelect(statement: str, paramters: tuple = ()) -> sqlite3.Cursor:
print("[WGDashboard] SQLite Error:" + str(error) + " | Statement: " + statement) print("[WGDashboard] SQLite Error:" + str(error) + " | Statement: " + statement)
return [] return []
def sqlUpdate(statement: str, paramters: tuple = ()) -> sqlite3.Cursor: def sqlUpdate(statement: str, paramters: tuple = ()) -> sqlite3.Cursor:
with sqldb: with sqldb:
cursor = sqldb.cursor() cursor = sqldb.cursor()
@ -1654,6 +1572,7 @@ def sqlUpdate(statement: str, paramters: tuple = ()) -> sqlite3.Cursor:
except sqlite3.OperationalError as error: except sqlite3.OperationalError as error:
print("[WGDashboard] SQLite Error:" + str(error) + " | Statement: " + statement) print("[WGDashboard] SQLite Error:" + str(error) + " | Statement: " + statement)
DashboardConfig = DashboardConfig() DashboardConfig = DashboardConfig()
_, APP_PREFIX = DashboardConfig.GetConfig("Server", "app_prefix") _, APP_PREFIX = DashboardConfig.GetConfig("Server", "app_prefix")
cors = CORS(app, resources={rf"{APP_PREFIX}/api/*": { cors = CORS(app, resources={rf"{APP_PREFIX}/api/*": {
@ -1780,7 +1699,7 @@ def API_SignOut():
@app.route(f'{APP_PREFIX}/api/getWireguardConfigurations', methods=["GET"]) @app.route(f'{APP_PREFIX}/api/getWireguardConfigurations', methods=["GET"])
def API_getWireguardConfigurations(): def API_getWireguardConfigurations():
_getConfigurationList() InitWireguardConfigurationsList()
return ResponseObject(data=[wc for wc in WireguardConfigurations.values()]) return ResponseObject(data=[wc for wc in WireguardConfigurations.values()])
@app.route(f'{APP_PREFIX}/api/addWireguardConfiguration', methods=["POST"]) @app.route(f'{APP_PREFIX}/api/addWireguardConfiguration', methods=["POST"])
@ -1905,7 +1824,7 @@ def API_getAllWireguardConfigurationBackup():
files.sort(key=lambda x: x[1], reverse=True) files.sort(key=lambda x: x[1], reverse=True)
for f, ct in files: for f, ct in files:
if _regexMatch(r"^(.*)_(.*)\.(conf)$", f): if RegexMatch(r"^(.*)_(.*)\.(conf)$", f):
s = re.search(r"^(.*)_(.*)\.(conf)$", f) s = re.search(r"^(.*)_(.*)\.(conf)$", f)
name = s.group(1) name = s.group(1)
if name not in existingConfiguration: if name not in existingConfiguration:
@ -1981,7 +1900,7 @@ def API_updateDashboardConfigurationItem():
if data['section'] == "Server": if data['section'] == "Server":
if data['key'] == 'wg_conf_path': if data['key'] == 'wg_conf_path':
WireguardConfigurations.clear() WireguardConfigurations.clear()
_getConfigurationList() InitWireguardConfigurationsList()
return ResponseObject(True, data=DashboardConfig.GetConfig(data["section"], data["key"])[1]) return ResponseObject(True, data=DashboardConfig.GetConfig(data["section"], data["key"])[1])
@ -2168,7 +2087,7 @@ def API_addPeers(configName):
return ResponseObject(False, "Please provide at least public_key and allowed_ips") return ResponseObject(False, "Please provide at least public_key and allowed_ips")
if not config.getStatus(): if not config.getStatus():
config.toggleConfiguration() config.toggleConfiguration()
availableIps = _getWireguardConfigurationAvailableIP(configName) availableIps = config.getAvailableIP()
if bulkAdd: if bulkAdd:
if type(preshared_key_bulkAdd) is not bool: if type(preshared_key_bulkAdd) is not bool:
preshared_key_bulkAdd = False preshared_key_bulkAdd = False
@ -2182,11 +2101,11 @@ def API_addPeers(configName):
f"The maximum number of peers can add is {len(availableIps[1])}") f"The maximum number of peers can add is {len(availableIps[1])}")
keyPairs = [] keyPairs = []
for i in range(bulkAddAmount): for i in range(bulkAddAmount):
newPrivateKey = _generatePrivateKey()[1] newPrivateKey = GenerateWireguardPrivateKey()[1]
keyPairs.append({ keyPairs.append({
"private_key": newPrivateKey, "private_key": newPrivateKey,
"id": _generatePublicKey(newPrivateKey)[1], "id": GenerateWireguardPublicKey(newPrivateKey)[1],
"preshared_key": (_generatePrivateKey()[1] if preshared_key_bulkAdd else ""), "preshared_key": (GenerateWireguardPrivateKey()[1] if preshared_key_bulkAdd else ""),
"allowed_ip": availableIps[1][i], "allowed_ip": availableIps[1][i],
"name": f"BulkPeer #{(i + 1)}_{datetime.now().strftime('%Y%m%d_%H%M%S')}", "name": f"BulkPeer #{(i + 1)}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
"DNS": dns_addresses, "DNS": dns_addresses,
@ -2257,7 +2176,9 @@ def API_downloadAllPeers(configName):
@app.get(f"{APP_PREFIX}/api/getAvailableIPs/<configName>") @app.get(f"{APP_PREFIX}/api/getAvailableIPs/<configName>")
def API_getAvailableIPs(configName): def API_getAvailableIPs(configName):
status, ips = _getWireguardConfigurationAvailableIP(configName) if configName not in WireguardConfigurations.keys():
return ResponseObject(False, "Configuration does not exist")
status, ips = WireguardConfigurations.get(configName).getAvailableIP()
return ResponseObject(status=status, data=ips) return ResponseObject(status=status, data=ips)
@app.get(f'{APP_PREFIX}/api/getWireguardConfigurationInfo') @app.get(f'{APP_PREFIX}/api/getWireguardConfigurationInfo')
@ -2556,7 +2477,6 @@ def API_Locale_Update():
return ResponseObject(False, "Please specify a lang_id") return ResponseObject(False, "Please specify a lang_id")
Locale.updateLanguage(data['lang_id']) Locale.updateLanguage(data['lang_id'])
return ResponseObject(data=Locale.getLanguage()) return ResponseObject(data=Locale.getLanguage())
@app.get(f'{APP_PREFIX}/') @app.get(f'{APP_PREFIX}/')
def index(): def index():
@ -2593,6 +2513,21 @@ def gunicornConfig():
_, app_port = DashboardConfig.GetConfig("Server", "app_port") _, app_port = DashboardConfig.GetConfig("Server", "app_port")
return app_ip, app_port return app_ip, app_port
def InitWireguardConfigurationsList(startup: bool = False):
confs = os.listdir(DashboardConfig.GetConfig("Server", "wg_conf_path")[1])
confs.sort()
for i in confs:
if RegexMatch("^(.{1,}).(conf)$", i):
i = i.replace('.conf', '')
try:
if i in WireguardConfigurations.keys():
if WireguardConfigurations[i].configurationFileChanged():
WireguardConfigurations[i] = WireguardConfiguration(i)
else:
WireguardConfigurations[i] = WireguardConfiguration(i, startup=startup)
except WireguardConfiguration.InvalidConfigurationFileException as e:
print(f"{i} have an invalid configuration file.")
AllPeerShareLinks: PeerShareLinks = PeerShareLinks() AllPeerShareLinks: PeerShareLinks = PeerShareLinks()
AllPeerJobs: PeerJobs = PeerJobs() AllPeerJobs: PeerJobs = PeerJobs()
JobLogger: PeerJobLogger = PeerJobLogger() JobLogger: PeerJobLogger = PeerJobLogger()
@ -2602,7 +2537,7 @@ _, app_port = DashboardConfig.GetConfig("Server", "app_port")
_, WG_CONF_PATH = DashboardConfig.GetConfig("Server", "wg_conf_path") _, WG_CONF_PATH = DashboardConfig.GetConfig("Server", "wg_conf_path")
WireguardConfigurations: dict[str, WireguardConfiguration] = {} WireguardConfigurations: dict[str, WireguardConfiguration] = {}
_getConfigurationList(startup=True) InitWireguardConfigurationsList(startup=True)
def startThreads(): def startThreads():
bgThread = threading.Thread(target=backGroundThread) bgThread = threading.Thread(target=backGroundThread)