Refactored some of the codes

This commit is contained in:
Donald Zou 2024-11-25 22:11:51 +08:00
parent 6a4d16fae9
commit 578a1db62f
2 changed files with 220 additions and 209 deletions

76
src/Utilities.py Normal file
View File

@ -0,0 +1,76 @@
import re, ipaddress
import subprocess
def RegexMatch(regex, text) -> bool:
"""
Regex Match
@param regex: Regex patter
@param text: Text to match
@return: Boolean indicate if the text match the regex pattern
"""
pattern = re.compile(regex)
return pattern.search(text) is not None
def GetRemoteEndpoint() -> str:
"""
Using socket to determine default interface IP address. Thanks, @NOXICS
@return:
"""
import socket
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(("1.1.1.1", 80)) # Connecting to a public IP
wgd_remote_endpoint = s.getsockname()[0]
return str(wgd_remote_endpoint)
def StringToBoolean(value: str):
"""
Convert string boolean to boolean
@param value: Boolean value in string came from Configuration file
@return: Boolean value
"""
return (value.strip().replace(" ", "").lower() in
("yes", "true", "t", "1", 1))
def ValidateIPAddressesWithRange(ips: str) -> bool:
s = ips.replace(" ", "").split(",")
for ip in s:
try:
ipaddress.ip_network(ip)
except ValueError as e:
return False
return True
def ValidateIPAddresses(ips) -> bool:
s = ips.replace(" ", "").split(",")
for ip in s:
try:
ipaddress.ip_address(ip)
except ValueError as e:
return False
return True
def ValidateDNSAddress(addresses) -> tuple[bool, str]:
s = addresses.replace(" ", "").split(",")
for address in s:
if not ValidateIPAddresses(address) and not RegexMatch(
r"(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z][a-z]{0,61}[a-z]", address):
return False, f"{address} does not appear to be an valid DNS address"
return True, ""
def GenerateWireguardPublicKey(privateKey: str) -> tuple[bool, str] | tuple[bool, None]:
try:
publicKey = subprocess.check_output(f"wg pubkey", input=privateKey.encode(), shell=True,
stderr=subprocess.STDOUT)
return True, publicKey.decode().strip('\n')
except subprocess.CalledProcessError:
return False, None
def GenerateWireguardPrivateKey() -> tuple[bool, str] | tuple[bool, None]:
try:
publicKey = subprocess.check_output(f"wg genkey", shell=True,
stderr=subprocess.STDOUT)
return True, publicKey.decode().strip('\n')
except subprocess.CalledProcessError:
return False, None

View File

@ -1,33 +1,20 @@
import itertools, random
import shutil
import sqlite3
import configparser
import hashlib
import ipaddress
import json
import traceback
import os
import secrets
import subprocess
import time
import re
import urllib.error
import uuid
import itertools, random, shutil, sqlite3, configparser, hashlib, ipaddress, json, traceback, os, secrets, subprocess
import time, re, urllib.error, uuid, bcrypt, psutil, pyotp, threading
from datetime import datetime, timedelta
from typing import Any
import bcrypt
# import ifcfg
import psutil
import pyotp
from flask import Flask, request, render_template, session, g
from json import JSONEncoder
from flask_cors import CORS
from icmplib import ping, traceroute
import threading
from flask.json.provider import DefaultJSONProvider
from Utilities import (
RegexMatch, GetRemoteEndpoint, StringToBoolean,
ValidateIPAddressesWithRange, ValidateIPAddresses, ValidateDNSAddress,
GenerateWireguardPublicKey, GenerateWireguardPrivateKey
)
DASHBOARD_VERSION = 'v4.1.1'
DASHBOARD_VERSION = 'v4.2.0'
CONFIGURATION_PATH = os.getenv('CONFIGURATION_PATH', '.')
DB_PATH = os.path.join(CONFIGURATION_PATH, 'db')
if not os.path.isdir(DB_PATH):
@ -46,10 +33,24 @@ class ModelEncoder(JSONEncoder):
else:
return super(ModelEncoder, self).default(o)
'''
Classes
'''
class CustomJsonEncoder(DefaultJSONProvider):
def __init__(self, app):
super().__init__(app)
def default(self, o):
if (isinstance(o, WireguardConfiguration)
or isinstance(o, Peer)
or isinstance(o, PeerJob)
or isinstance(o, Log)
or isinstance(o, DashboardAPIKey)
or isinstance(o, PeerShareLink)):
return o.toJson()
return super().default(self, o)
app.json = CustomJsonEncoder(app)
'''
Response Object
'''
def ResponseObject(status=True, message=None, data=None) -> Flask.response_class:
response = Flask.make_response(app, {
"status": status,
@ -59,24 +60,9 @@ def ResponseObject(status=True, message=None, data=None) -> Flask.response_class
response.content_type = "application/json"
return response
class CustomJsonEncoder(DefaultJSONProvider):
def __init__(self, app):
super().__init__(app)
def default(self, o):
if (isinstance(o, WireguardConfiguration)
or isinstance(o, Peer)
or isinstance(o, PeerJob)
or isinstance(o, Log)
or isinstance(o, DashboardAPIKey)
or isinstance(o, PeerShareLink)):
return o.toJson()
return super().default(self, o)
app.json = CustomJsonEncoder(app)
"""
Log Class
"""
class Log:
def __init__(self, LogID: str, JobID: str, LogDate: str, Status: str, Message: str):
self.LogID = LogID
@ -96,7 +82,10 @@ class Log:
def __dict__(self):
return self.toJson()
"""
Dashboard Logger Class
"""
class DashboardLogger:
def __init__(self):
self.loggerdb = sqlite3.connect(os.path.join(CONFIGURATION_PATH, 'db', 'wgdashboard_log.db'),
@ -128,7 +117,10 @@ class DashboardLogger:
except Exception as e:
print(f"[WGDashboard] Access Log Error: {str(e)}")
return False
"""
Peer Job Logger
"""
class PeerJobLogger:
def __init__(self):
self.loggerdb = sqlite3.connect(os.path.join(CONFIGURATION_PATH, 'db', 'wgdashboard_log.db'),
@ -176,7 +168,10 @@ class PeerJobLogger:
except Exception as e:
return logs
return logs
"""
Peer Job
"""
class PeerJob:
def __init__(self, JobID: str, Configuration: str, Peer: str,
Field: str, Operator: str, Value: str, CreationDate: datetime, ExpireDate: datetime, Action: str):
@ -206,6 +201,9 @@ class PeerJob:
def __dict__(self):
return self.toJson()
"""
Peer Jobs
"""
class PeerJobs:
def __init__(self):
@ -367,7 +365,10 @@ class PeerJobs:
return x > y
if operator == "lst":
return x < y
"""
Peer Share Link
"""
class PeerShareLink:
def __init__(self, ShareID:str, Configuration: str, Peer: str, ExpireDate: datetime, ShareDate: datetime):
self.ShareID = ShareID
@ -385,6 +386,9 @@ class PeerShareLink:
"ExpireDate": self.ExpireDate
}
"""
Peer Share Links
"""
class PeerShareLinks:
def __init__(self):
self.Links: list[PeerShareLink] = []
@ -429,7 +433,10 @@ class PeerShareLinks:
sqlUpdate("UPDATE PeerShareLinks SET ExpireDate = ? WHERE ShareID = ?;", (ExpireDate, ShareID, ))
self.__getSharedLinks()
return True, ""
"""
WireGuard Configuration
"""
class WireguardConfiguration:
class InvalidConfigurationFileException(Exception):
def __init__(self, m):
@ -482,7 +489,7 @@ class WireguardConfiguration:
for i in dir(self):
if str(i) in data.keys():
if isinstance(getattr(self, i), bool):
setattr(self, i, _strToBool(data[i]))
setattr(self, i, StringToBoolean(data[i]))
else:
setattr(self, i, str(data[i]))
@ -507,8 +514,7 @@ class WireguardConfiguration:
if self.getAutostartStatus() and not self.getStatus() and startup:
self.toggleConfiguration()
print(f"[WGDashboard] Autostart Configuration: {name}")
def __initPeersList(self):
self.Peers: list[Peer] = []
self.getPeersList()
@ -524,7 +530,7 @@ class WireguardConfiguration:
for i in dir(self):
if str(i) in interfaceConfig.keys():
if isinstance(getattr(self, i), bool):
setattr(self, i, _strToBool(interfaceConfig[i]))
setattr(self, i, StringToBoolean(interfaceConfig[i]))
else:
setattr(self, i, interfaceConfig[i])
if self.PrivateKey:
@ -620,7 +626,7 @@ class WireguardConfiguration:
return True
def __getPublicKey(self) -> str:
return _generatePublicKey(self.PrivateKey)[1]
return GenerateWireguardPublicKey(self.PrivateKey)[1]
def getStatus(self) -> bool:
self.Status = self.Name in psutil.net_if_addrs().keys()
@ -653,7 +659,7 @@ class WireguardConfiguration:
peerStarts = content.index("[Peer]")
content = content[peerStarts:]
for i in content:
if not regex_match("#(.*)", i) and not regex_match(";(.*)", i):
if not RegexMatch("#(.*)", i) and not RegexMatch(";(.*)", i):
if i == "[Peer]":
pCounter += 1
p.append({})
@ -664,7 +670,7 @@ class WireguardConfiguration:
if len(split) == 2:
p[pCounter][split[0]] = split[1]
if regex_match("#Name# = (.*)", i):
if RegexMatch("#Name# = (.*)", i):
split = re.split(r'\s*=\s*', i, 1)
if len(split) == 2:
p[pCounter]["name"] = split[1]
@ -1037,7 +1043,7 @@ class WireguardConfiguration:
files.sort(key=lambda x: x[1], reverse=True)
for f, ct in files:
if _regexMatch(f"^({self.Name})_(.*)\\.(conf)$", f):
if RegexMatch(f"^({self.Name})_(.*)\\.(conf)$", f):
s = re.search(f"^({self.Name})_(.*)\\.(conf)$", f)
date = s.group(2)
d = {
@ -1107,7 +1113,7 @@ class WireguardConfiguration:
split[1] = newData[key]
original[line] = " = ".join(split)
if isinstance(getattr(self, key), bool):
setattr(self, key, _strToBool(newData[key]))
setattr(self, key, StringToBoolean(newData[key]))
else:
setattr(self, key, str(newData[key]))
dataChanged = True
@ -1150,7 +1156,46 @@ class WireguardConfiguration:
except Exception as e:
return False, str(e)
return True, None
def getAvailableIP(self, all: bool = False) -> tuple[bool, list[str]] | tuple[bool, None]:
if len(self.Address) < 0:
return False, None
address = self.Address.split(',')
existedAddress = []
availableAddress = []
for p in self.Peers:
if len(p.allowed_ip) > 0:
add = p.allowed_ip.split(',')
for i in add:
a, c = i.split('/')
try:
existedAddress.append(ipaddress.ip_address(a.replace(" ", "")))
except ValueError as error:
print(f"[WGDashboard] Error: {configName} peer {p.id} have invalid ip")
for p in self.getRestrictedPeersList():
if len(p.allowed_ip) > 0:
add = p.allowed_ip.split(',')
for i in add:
a, c = i.split('/')
existedAddress.append(ipaddress.ip_address(a.replace(" ", "")))
for i in address:
addressSplit, cidr = i.split('/')
existedAddress.append(ipaddress.ip_address(addressSplit.replace(" ", "")))
for i in address:
network = ipaddress.ip_network(i.replace(" ", ""), False)
count = 0
for h in network.hosts():
if h not in existedAddress:
availableAddress.append(ipaddress.ip_network(h).compressed)
count += 1
if not all:
if network.version == 6 and count > 255:
break
return True, availableAddress
"""
Peer
"""
class Peer:
def __init__(self, tableData, configuration: WireguardConfiguration):
self.configuration = configuration
@ -1200,16 +1245,16 @@ class Peer:
if allowed_ip in existingAllowedIps:
return ResponseObject(False, "Allowed IP already taken by another peer")
if not _checkIPWithRange(endpoint_allowed_ip):
if not ValidateIPAddressesWithRange(endpoint_allowed_ip):
return ResponseObject(False, f"Endpoint Allowed IPs format is incorrect")
if len(dns_addresses) > 0 and not _checkDNS(dns_addresses):
if len(dns_addresses) > 0 and not ValidateDNSAddress(dns_addresses):
return ResponseObject(False, f"DNS format is incorrect")
if mtu < 0 or mtu > 1460:
return ResponseObject(False, "MTU format is not correct")
if keepalive < 0:
return ResponseObject(False, "Persistent Keepalive format is not correct")
if len(private_key) > 0:
pubKey = _generatePublicKey(private_key)
pubKey = GenerateWireguardPublicKey(private_key)
if not pubKey[0] or pubKey[1] != self.id:
return ResponseObject(False, "Private key does not match with the public key")
try:
@ -1297,20 +1342,10 @@ PersistentKeepalive = {str(self.keepalive)}
except Exception as e:
return False
return True
# Regex Match
def regex_match(regex, text):
pattern = re.compile(regex)
return pattern.search(text) is not None
def get_remote_endpoint():
# Thanks @NOXICS
import socket
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(("1.1.1.1", 80)) # Connecting to a public IP
wgd_remote_endpoint = s.getsockname()[0]
return str(wgd_remote_endpoint)
"""
Dashboard API Key
"""
class DashboardAPIKey:
def __init__(self, Key: str, CreatedAt: str, ExpiredAt: str):
self.Key = Key
@ -1320,6 +1355,9 @@ class DashboardAPIKey:
def toJson(self):
return self.__dict__
"""
Dashboard Configuration
"""
class DashboardConfig:
def __init__(self):
@ -1353,7 +1391,7 @@ class DashboardConfig:
"peer_global_DNS": "1.1.1.1",
"peer_endpoint_allowed_ip": "0.0.0.0/0",
"peer_display_mode": "grid",
"remote_endpoint": get_remote_endpoint(),
"remote_endpoint": GetRemoteEndpoint(),
"peer_MTU": "1420",
"peer_keep_alive": "21"
},
@ -1406,7 +1444,7 @@ class DashboardConfig:
if type(value) is str and len(value) == 0:
return False, "Field cannot be empty!"
if key == "peer_global_dns":
return _checkDNS(value)
return ValidateDNSAddress(value)
if key == "peer_endpoint_allowed_ip":
value = value.split(",")
for i in value:
@ -1505,128 +1543,9 @@ class DashboardConfig:
the_dict[section][key] = self.GetConfig(section, key)[1]
return the_dict
'''
Private Functions
'''
def _strToBool(value: str) -> bool:
return value.lower() in ("yes", "true", "t", "1", 1)
def _regexMatch(regex, text):
pattern = re.compile(regex)
return pattern.search(text) is not None
def _getConfigurationList(startup: bool = False):
confs = os.listdir(DashboardConfig.GetConfig("Server", "wg_conf_path")[1])
confs.sort()
for i in confs:
if _regexMatch("^(.{1,}).(conf)$", i):
i = i.replace('.conf', '')
try:
if i in WireguardConfigurations.keys():
if WireguardConfigurations[i].configurationFileChanged():
WireguardConfigurations[i] = WireguardConfiguration(i)
else:
WireguardConfigurations[i] = WireguardConfiguration(i, startup=startup)
except WireguardConfiguration.InvalidConfigurationFileException as e:
print(f"{i} have an invalid configuration file.")
def _checkIPWithRange(ip):
ip_patterns = (
r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(\.|\/)){4}([0-9]{1,2})(,|$)",
r"[0-9a-fA-F]{0,4}(:([0-9a-fA-F]{0,4})){1,7}\/([0-9]{1,3})(,|$)"
)
for match_pattern in ip_patterns:
match_result = regex_match(match_pattern, ip)
if match_result:
result = match_result
break
else:
result = None
return result
def _checkIP(ip):
ip_patterns = (
r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(\.|$)){4}",
r"[0-9a-fA-F]{0,4}(:([0-9a-fA-F]{0,4})){1,7}$"
)
for match_pattern in ip_patterns:
match_result = regex_match(match_pattern, ip)
if match_result:
result = match_result
break
else:
result = None
return result
def _checkDNS(dns):
dns = dns.replace(' ', '').split(',')
for i in dns:
if not _checkIP(i) and not regex_match(r"(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z][a-z]{0,61}[a-z]", i):
return False, f"{i} does not appear to be an valid DNS address"
return True, ""
def _generatePublicKey(privateKey) -> tuple[bool, str] | tuple[bool, None]:
try:
publicKey = subprocess.check_output(f"wg pubkey", input=privateKey.encode(), shell=True,
stderr=subprocess.STDOUT)
return True, publicKey.decode().strip('\n')
except subprocess.CalledProcessError:
return False, None
def _generatePrivateKey() -> [bool, str]:
try:
publicKey = subprocess.check_output(f"wg genkey", shell=True,
stderr=subprocess.STDOUT)
return True, publicKey.decode().strip('\n')
except subprocess.CalledProcessError:
return False, None
def _getWireguardConfigurationAvailableIP(configName: str, all: bool = False) -> tuple[bool, list[str]] | tuple[bool, None]:
if configName not in WireguardConfigurations.keys():
return False, None
configuration = WireguardConfigurations[configName]
if len(configuration.Address) > 0:
address = configuration.Address.split(',')
existedAddress = []
availableAddress = []
for p in configuration.Peers:
if len(p.allowed_ip) > 0:
add = p.allowed_ip.split(',')
for i in add:
a, c = i.split('/')
try:
existedAddress.append(ipaddress.ip_address(a.replace(" ", "")))
except ValueError as error:
print(f"[WGDashboard] Error: {configName} peer {p.id} have invalid ip")
for p in configuration.getRestrictedPeersList():
if len(p.allowed_ip) > 0:
add = p.allowed_ip.split(',')
for i in add:
a, c = i.split('/')
existedAddress.append(ipaddress.ip_address(a.replace(" ", "")))
for i in address:
addressSplit, cidr = i.split('/')
existedAddress.append(ipaddress.ip_address(addressSplit.replace(" ", "")))
for i in address:
network = ipaddress.ip_network(i.replace(" ", ""), False)
count = 0
for h in network.hosts():
if h not in existedAddress:
availableAddress.append(ipaddress.ip_network(h).compressed)
count += 1
if not all:
if network.version == 6 and count > 255:
break
return True, availableAddress
return False, None
"""
Database Connection Functions
"""
sqldb = sqlite3.connect(os.path.join(CONFIGURATION_PATH, 'db', 'wgdashboard.db'), check_same_thread=False)
sqldb.row_factory = sqlite3.Row
@ -1642,7 +1561,6 @@ def sqlSelect(statement: str, paramters: tuple = ()) -> sqlite3.Cursor:
print("[WGDashboard] SQLite Error:" + str(error) + " | Statement: " + statement)
return []
def sqlUpdate(statement: str, paramters: tuple = ()) -> sqlite3.Cursor:
with sqldb:
cursor = sqldb.cursor()
@ -1654,6 +1572,7 @@ def sqlUpdate(statement: str, paramters: tuple = ()) -> sqlite3.Cursor:
except sqlite3.OperationalError as error:
print("[WGDashboard] SQLite Error:" + str(error) + " | Statement: " + statement)
DashboardConfig = DashboardConfig()
_, APP_PREFIX = DashboardConfig.GetConfig("Server", "app_prefix")
cors = CORS(app, resources={rf"{APP_PREFIX}/api/*": {
@ -1780,7 +1699,7 @@ def API_SignOut():
@app.route(f'{APP_PREFIX}/api/getWireguardConfigurations', methods=["GET"])
def API_getWireguardConfigurations():
_getConfigurationList()
InitWireguardConfigurationsList()
return ResponseObject(data=[wc for wc in WireguardConfigurations.values()])
@app.route(f'{APP_PREFIX}/api/addWireguardConfiguration', methods=["POST"])
@ -1905,7 +1824,7 @@ def API_getAllWireguardConfigurationBackup():
files.sort(key=lambda x: x[1], reverse=True)
for f, ct in files:
if _regexMatch(r"^(.*)_(.*)\.(conf)$", f):
if RegexMatch(r"^(.*)_(.*)\.(conf)$", f):
s = re.search(r"^(.*)_(.*)\.(conf)$", f)
name = s.group(1)
if name not in existingConfiguration:
@ -1981,7 +1900,7 @@ def API_updateDashboardConfigurationItem():
if data['section'] == "Server":
if data['key'] == 'wg_conf_path':
WireguardConfigurations.clear()
_getConfigurationList()
InitWireguardConfigurationsList()
return ResponseObject(True, data=DashboardConfig.GetConfig(data["section"], data["key"])[1])
@ -2168,7 +2087,7 @@ def API_addPeers(configName):
return ResponseObject(False, "Please provide at least public_key and allowed_ips")
if not config.getStatus():
config.toggleConfiguration()
availableIps = _getWireguardConfigurationAvailableIP(configName)
availableIps = config.getAvailableIP()
if bulkAdd:
if type(preshared_key_bulkAdd) is not bool:
preshared_key_bulkAdd = False
@ -2182,11 +2101,11 @@ def API_addPeers(configName):
f"The maximum number of peers can add is {len(availableIps[1])}")
keyPairs = []
for i in range(bulkAddAmount):
newPrivateKey = _generatePrivateKey()[1]
newPrivateKey = GenerateWireguardPrivateKey()[1]
keyPairs.append({
"private_key": newPrivateKey,
"id": _generatePublicKey(newPrivateKey)[1],
"preshared_key": (_generatePrivateKey()[1] if preshared_key_bulkAdd else ""),
"id": GenerateWireguardPublicKey(newPrivateKey)[1],
"preshared_key": (GenerateWireguardPrivateKey()[1] if preshared_key_bulkAdd else ""),
"allowed_ip": availableIps[1][i],
"name": f"BulkPeer #{(i + 1)}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
"DNS": dns_addresses,
@ -2257,7 +2176,9 @@ def API_downloadAllPeers(configName):
@app.get(f"{APP_PREFIX}/api/getAvailableIPs/<configName>")
def API_getAvailableIPs(configName):
status, ips = _getWireguardConfigurationAvailableIP(configName)
if configName not in WireguardConfigurations.keys():
return ResponseObject(False, "Configuration does not exist")
status, ips = WireguardConfigurations.get(configName).getAvailableIP()
return ResponseObject(status=status, data=ips)
@app.get(f'{APP_PREFIX}/api/getWireguardConfigurationInfo')
@ -2556,7 +2477,6 @@ def API_Locale_Update():
return ResponseObject(False, "Please specify a lang_id")
Locale.updateLanguage(data['lang_id'])
return ResponseObject(data=Locale.getLanguage())
@app.get(f'{APP_PREFIX}/')
def index():
@ -2593,6 +2513,21 @@ def gunicornConfig():
_, app_port = DashboardConfig.GetConfig("Server", "app_port")
return app_ip, app_port
def InitWireguardConfigurationsList(startup: bool = False):
confs = os.listdir(DashboardConfig.GetConfig("Server", "wg_conf_path")[1])
confs.sort()
for i in confs:
if RegexMatch("^(.{1,}).(conf)$", i):
i = i.replace('.conf', '')
try:
if i in WireguardConfigurations.keys():
if WireguardConfigurations[i].configurationFileChanged():
WireguardConfigurations[i] = WireguardConfiguration(i)
else:
WireguardConfigurations[i] = WireguardConfiguration(i, startup=startup)
except WireguardConfiguration.InvalidConfigurationFileException as e:
print(f"{i} have an invalid configuration file.")
AllPeerShareLinks: PeerShareLinks = PeerShareLinks()
AllPeerJobs: PeerJobs = PeerJobs()
JobLogger: PeerJobLogger = PeerJobLogger()
@ -2602,7 +2537,7 @@ _, app_port = DashboardConfig.GetConfig("Server", "app_port")
_, WG_CONF_PATH = DashboardConfig.GetConfig("Server", "wg_conf_path")
WireguardConfigurations: dict[str, WireguardConfiguration] = {}
_getConfigurationList(startup=True)
InitWireguardConfigurationsList(startup=True)
def startThreads():
bgThread = threading.Thread(target=backGroundThread)