Update dashboard.py

- Updated `DashboardConfig` class to use SqlAlchemy, tested with SQLite and Postgresql
This commit is contained in:
Donald Zou
2025-05-07 18:40:24 +08:00
parent 409acc9f1a
commit 922d8eab58

View File

@@ -25,7 +25,10 @@ from modules.PeerJob import PeerJob
from modules.SystemStatus import SystemStatus from modules.SystemStatus import SystemStatus
SystemStatus = SystemStatus() SystemStatus = SystemStatus()
DASHBOARD_VERSION = 'v4.2.5' from sqlalchemy_utils import database_exists, create_database
import sqlalchemy as db
DASHBOARD_VERSION = 'v4.2.3'
CONFIGURATION_PATH = os.getenv('CONFIGURATION_PATH', '.') CONFIGURATION_PATH = os.getenv('CONFIGURATION_PATH', '.')
DB_PATH = os.path.join(CONFIGURATION_PATH, 'db') DB_PATH = os.path.join(CONFIGURATION_PATH, 'db')
@@ -180,7 +183,6 @@ class PeerJobs:
def runJob(self): def runJob(self):
needToDelete = [] needToDelete = []
self.__getJobs()
for job in self.Jobs: for job in self.Jobs:
c = WireguardConfigurations.get(job.Configuration) c = WireguardConfigurations.get(job.Configuration)
if c is not None: if c is not None:
@@ -433,6 +435,7 @@ class WireguardConfiguration:
original = [l.rstrip("\n") for l in f.readlines()] original = [l.rstrip("\n") for l in f.readlines()]
try: try:
start = original.index("[Interface]") start = original.index("[Interface]")
# Clean # Clean
for i in range(start, len(original)): for i in range(start, len(original)):
if original[i] == "[Peer]": if original[i] == "[Peer]":
@@ -459,10 +462,7 @@ class WireguardConfiguration:
setattr(self, key, StringToBoolean(value)) setattr(self, key, StringToBoolean(value))
else: else:
if len(getattr(self, key)) > 0: if len(getattr(self, key)) > 0:
if key not in ["PostUp", "PostDown", "PreUp", "PreDown"]: setattr(self, key, f"{getattr(self, key)}, {value}")
setattr(self, key, f"{getattr(self, key)}, {value}")
else:
setattr(self, key, f"{getattr(self, key)}; {value}")
else: else:
setattr(self, key, value) setattr(self, key, value)
except ValueError as e: except ValueError as e:
@@ -1212,15 +1212,15 @@ AmneziaWG Configuration
""" """
class AmneziaWireguardConfiguration(WireguardConfiguration): class AmneziaWireguardConfiguration(WireguardConfiguration):
def __init__(self, name: str = None, data: dict = None, backup: dict = None, startup: bool = False): def __init__(self, name: str = None, data: dict = None, backup: dict = None, startup: bool = False):
self.Jc = "" self.Jc = 0
self.Jmin = "" self.Jmin = 0
self.Jmax = "" self.Jmax = 0
self.S1 = "" self.S1 = 0
self.S2 = "" self.S2 = 0
self.H1 = "" self.H1 = 1
self.H2 = "" self.H2 = 2
self.H3 = "" self.H3 = 3
self.H4 = "" self.H4 = 4
super().__init__(name, data, backup, startup, wg=False) super().__init__(name, data, backup, startup, wg=False)
@@ -1245,7 +1245,6 @@ class AmneziaWireguardConfiguration(WireguardConfiguration):
}, },
"ConnectedPeers": len(list(filter(lambda x: x.status == "running", self.Peers))), "ConnectedPeers": len(list(filter(lambda x: x.status == "running", self.Peers))),
"TotalPeers": len(self.Peers), "TotalPeers": len(self.Peers),
"Table": self.Table,
"Protocol": self.Protocol, "Protocol": self.Protocol,
"Jc": self.Jc, "Jc": self.Jc,
"Jmin": self.Jmin, "Jmin": self.Jmin,
@@ -1765,7 +1764,7 @@ class DashboardConfig:
def __init__(self): def __init__(self):
if not os.path.exists(DASHBOARD_CONF): if not os.path.exists(DASHBOARD_CONF):
open(DASHBOARD_CONF, "x") open(DASHBOARD_CONF, "x")
self.__config = configparser.RawConfigParser(strict=False) self.__config = configparser.ConfigParser(strict=False)
self.__config.read_file(open(DASHBOARD_CONF, "r+")) self.__config.read_file(open(DASHBOARD_CONF, "r+"))
self.hiddenAttribute = ["totp_key", "auth_req"] self.hiddenAttribute = ["totp_key", "auth_req"]
self.__default = { self.__default = {
@@ -1828,39 +1827,80 @@ class DashboardConfig:
exist, currentData = self.GetConfig(section, key) exist, currentData = self.GetConfig(section, key)
if not exist: if not exist:
self.SetConfig(section, key, value, True) self.SetConfig(section, key, value, True)
self.engine = db.create_engine(self.getConnectionString('wgdashboard'))
self.db = self.engine.connect()
self.dbMetadata = db.MetaData()
self.__createAPIKeyTable() self.__createAPIKeyTable()
self.DashboardAPIKeys = self.__getAPIKeys() self.DashboardAPIKeys = self.__getAPIKeys()
self.APIAccessed = False self.APIAccessed = False
self.SetConfig("Server", "version", DASHBOARD_VERSION) self.SetConfig("Server", "version", DASHBOARD_VERSION)
def getConnectionString(self, database) -> str or None: def getConnectionString(self, database) -> str or None:
cn = None
if self.GetConfig("Database", "type")[1] == "sqlite": if self.GetConfig("Database", "type")[1] == "sqlite":
return f'sqlite:///{os.path.join(CONFIGURATION_PATH, "db", f"{database}.db")}' cn = f'sqlite:///{os.path.join(CONFIGURATION_PATH, "db", f"{database}.db")}'
elif self.GetConfig("Database", "type")[1] == "postgresql": elif self.GetConfig("Database", "type")[1] == "postgresql":
return f'postgresql+psycopg2://{self.GetConfig("Database", "username")[1]}:{self.GetConfig("Database", "password")[1]}@{self.GetConfig("Database", "host")[1]}/{database}' cn = f'postgresql+psycopg2://{self.GetConfig("Database", "username")[1]}:{self.GetConfig("Database", "password")[1]}@{self.GetConfig("Database", "host")[1]}/{database}'
return None
if not database_exists(cn):
create_database(cn)
return cn
def __createAPIKeyTable(self): def __createAPIKeyTable(self):
existingTable = sqlSelect("SELECT name FROM sqlite_master WHERE type='table' AND name = 'DashboardAPIKeys'").fetchall() self.apiKeyTable = db.Table('DashboardAPIKeys', self.dbMetadata,
if len(existingTable) == 0: db.Column("Key", db.String, nullable=False, primary_key=True),
sqlUpdate("CREATE TABLE DashboardAPIKeys (Key VARCHAR NOT NULL PRIMARY KEY, CreatedAt DATETIME NOT NULL DEFAULT (datetime('now', 'localtime')), ExpiredAt VARCHAR)") db.Column("CreatedAt",
(db.DATETIME if self.GetConfig('Database', 'type')[1] == 'sqlite' else db.TIMESTAMP),
server_default=db.func.now()
),
db.Column("ExpiredAt",
(db.DATETIME if self.GetConfig('Database', 'type')[1] == 'sqlite' else db.TIMESTAMP)
)
)
self.dbMetadata.create_all(self.engine)
# existingTable = sqlSelect("SELECT name FROM sqlite_master WHERE type='table' AND name = 'DashboardAPIKeys'").fetchall()
# if len(existingTable) == 0:
# sqlUpdate("CREATE TABLE DashboardAPIKeys (Key VARCHAR NOT NULL PRIMARY KEY, CreatedAt DATETIME NOT NULL DEFAULT (datetime('now', 'localtime')), ExpiredAt VARCHAR)")
def __getAPIKeys(self) -> list[DashboardAPIKey]: def __getAPIKeys(self) -> list[DashboardAPIKey]:
keys = sqlSelect("SELECT * FROM DashboardAPIKeys WHERE ExpiredAt IS NULL OR ExpiredAt > datetime('now', 'localtime') ORDER BY CreatedAt DESC").fetchall() # keys = sqlSelect("SELECT * FROM DashboardAPIKeys WHERE ExpiredAt IS NULL OR ExpiredAt > datetime('now', 'localtime') ORDER BY CreatedAt DESC").fetchall()
keys = self.db.execute(self.apiKeyTable.select().where(
db.or_(self.apiKeyTable.columns.ExpiredAt == None, self.apiKeyTable.columns.ExpiredAt > datetime.now())
)).fetchall()
fKeys = [] fKeys = []
for k in keys: for k in keys:
fKeys.append(DashboardAPIKey(k[0], k[1].strftime("%Y-%m-%d %H:%M:%S"), (k[2].strftime("%Y-%m-%d %H:%M:%S") if k[2] else None)))
fKeys.append(DashboardAPIKey(*k))
return fKeys return fKeys
def createAPIKeys(self, ExpiredAt = None): def createAPIKeys(self, ExpiredAt = None):
newKey = secrets.token_urlsafe(32) newKey = secrets.token_urlsafe(32)
sqlUpdate('INSERT INTO DashboardAPIKeys (Key, ExpiredAt) VALUES (?, ?)', (newKey, ExpiredAt,)) # sqlUpdate('INSERT INTO DashboardAPIKeys (Key, ExpiredAt) VALUES (?, ?)', (newKey, ExpiredAt,))
with self.engine.begin() as conn:
conn.execute(
self.apiKeyTable.insert().values({
"Key": newKey,
"ExpiredAt": ExpiredAt
})
)
self.DashboardAPIKeys = self.__getAPIKeys() self.DashboardAPIKeys = self.__getAPIKeys()
def deleteAPIKey(self, key): def deleteAPIKey(self, key):
sqlUpdate("UPDATE DashboardAPIKeys SET ExpiredAt = datetime('now', 'localtime') WHERE Key = ?", (key, )) # sqlUpdate("UPDATE DashboardAPIKeys SET ExpiredAt = datetime('now', 'localtime') WHERE Key = ?", (key, ))
with self.engine.begin() as conn:
conn.execute(
self.apiKeyTable.update().values({
"ExpiredAt": datetime.now(),
}).where(self.apiKeyTable.columns.Key == key)
)
self.DashboardAPIKeys = self.__getAPIKeys() self.DashboardAPIKeys = self.__getAPIKeys()
def __configValidation(self, section : str, key: str, value: Any) -> [bool, str]: def __configValidation(self, section : str, key: str, value: Any) -> [bool, str]:
@@ -2062,7 +2102,7 @@ def auth_req():
else: else:
DashboardConfig.APIAccessed = False DashboardConfig.APIAccessed = False
whiteList = [ whiteList = [
'/static/', 'validateAuthentication', 'authenticate', '/static/', 'validateAuthentication', 'authenticate', 'getDashboardConfiguration',
'getDashboardTheme', 'getDashboardVersion', 'sharePeer/get', 'isTotpEnabled', 'locale', 'getDashboardTheme', 'getDashboardVersion', 'sharePeer/get', 'isTotpEnabled', 'locale',
'/fileDownload' '/fileDownload'
] ]
@@ -2393,7 +2433,7 @@ def API_updateDashboardConfigurationItem():
valid, msg = DashboardConfig.SetConfig( valid, msg = DashboardConfig.SetConfig(
data["section"], data["key"], data['value']) data["section"], data["key"], data['value'])
if not valid: if not valid:
return ResponseObject(False, msg) return ResponseObject(False, msg, status_code=404)
if data['section'] == "Server": if data['section'] == "Server":
if data['key'] == 'wg_conf_path': if data['key'] == 'wg_conf_path':
WireguardConfigurations.clear() WireguardConfigurations.clear()
@@ -3126,19 +3166,16 @@ def peerInformationBackgroundThread():
time.sleep(10) time.sleep(10)
while True: while True:
with app.app_context(): with app.app_context():
try: for c in WireguardConfigurations.values():
curKeys = list(WireguardConfigurations.keys()) if c.getStatus():
for name in curKeys: try:
if name in WireguardConfigurations.keys() and WireguardConfigurations.get(name) is not None: c.getPeersTransfer()
c = WireguardConfigurations.get(name) c.getPeersLatestHandshake()
if c.getStatus(): c.getPeersEndpoint()
c.getPeersTransfer() c.getPeersList()
c.getPeersLatestHandshake() c.getRestrictedPeersList()
c.getPeersEndpoint() except Exception as e:
c.getPeersList() print(f"[WGDashboard] Background Thread #1 Error: {str(e)}", flush=True)
c.getRestrictedPeersList()
except Exception as e:
print(f"[WGDashboard] Background Thread #1 Error: {str(e)}", flush=True)
time.sleep(10) time.sleep(10)
def peerJobScheduleBackgroundThread(): def peerJobScheduleBackgroundThread():