From 922d8eab58b0132e3a85b38ded46d1c1023bfb0e Mon Sep 17 00:00:00 2001 From: Donald Zou Date: Wed, 7 May 2025 18:40:24 +0800 Subject: [PATCH] Update dashboard.py - Updated `DashboardConfig` class to use SqlAlchemy, tested with SQLite and Postgresql --- src/dashboard.py | 125 ++++++++++++++++++++++++++++++----------------- 1 file changed, 81 insertions(+), 44 deletions(-) diff --git a/src/dashboard.py b/src/dashboard.py index 88e3f769..f0eca710 100644 --- a/src/dashboard.py +++ b/src/dashboard.py @@ -25,7 +25,10 @@ from modules.PeerJob import PeerJob from modules.SystemStatus import SystemStatus SystemStatus = SystemStatus() -DASHBOARD_VERSION = 'v4.2.5' +from sqlalchemy_utils import database_exists, create_database +import sqlalchemy as db + +DASHBOARD_VERSION = 'v4.2.3' CONFIGURATION_PATH = os.getenv('CONFIGURATION_PATH', '.') DB_PATH = os.path.join(CONFIGURATION_PATH, 'db') @@ -180,7 +183,6 @@ class PeerJobs: def runJob(self): needToDelete = [] - self.__getJobs() for job in self.Jobs: c = WireguardConfigurations.get(job.Configuration) if c is not None: @@ -433,6 +435,7 @@ class WireguardConfiguration: original = [l.rstrip("\n") for l in f.readlines()] try: start = original.index("[Interface]") + # Clean for i in range(start, len(original)): if original[i] == "[Peer]": @@ -445,7 +448,7 @@ class WireguardConfiguration: setattr(self, key, False) else: setattr(self, key, "") - + # Set for i in range(start, len(original)): if original[i] == "[Peer]": @@ -459,10 +462,7 @@ class WireguardConfiguration: setattr(self, key, StringToBoolean(value)) else: if len(getattr(self, key)) > 0: - if key not in ["PostUp", "PostDown", "PreUp", "PreDown"]: - setattr(self, key, f"{getattr(self, key)}, {value}") - else: - setattr(self, key, f"{getattr(self, key)}; {value}") + setattr(self, key, f"{getattr(self, key)}, {value}") else: setattr(self, key, value) except ValueError as e: @@ -1212,15 +1212,15 @@ AmneziaWG Configuration """ class AmneziaWireguardConfiguration(WireguardConfiguration): def __init__(self, name: str = None, data: dict = None, backup: dict = None, startup: bool = False): - self.Jc = "" - self.Jmin = "" - self.Jmax = "" - self.S1 = "" - self.S2 = "" - self.H1 = "" - self.H2 = "" - self.H3 = "" - self.H4 = "" + self.Jc = 0 + self.Jmin = 0 + self.Jmax = 0 + self.S1 = 0 + self.S2 = 0 + self.H1 = 1 + self.H2 = 2 + self.H3 = 3 + self.H4 = 4 super().__init__(name, data, backup, startup, wg=False) @@ -1245,7 +1245,6 @@ class AmneziaWireguardConfiguration(WireguardConfiguration): }, "ConnectedPeers": len(list(filter(lambda x: x.status == "running", self.Peers))), "TotalPeers": len(self.Peers), - "Table": self.Table, "Protocol": self.Protocol, "Jc": self.Jc, "Jmin": self.Jmin, @@ -1765,7 +1764,7 @@ class DashboardConfig: def __init__(self): if not os.path.exists(DASHBOARD_CONF): open(DASHBOARD_CONF, "x") - self.__config = configparser.RawConfigParser(strict=False) + self.__config = configparser.ConfigParser(strict=False) self.__config.read_file(open(DASHBOARD_CONF, "r+")) self.hiddenAttribute = ["totp_key", "auth_req"] self.__default = { @@ -1828,39 +1827,80 @@ class DashboardConfig: exist, currentData = self.GetConfig(section, key) if not exist: self.SetConfig(section, key, value, True) + + self.engine = db.create_engine(self.getConnectionString('wgdashboard')) + self.db = self.engine.connect() + self.dbMetadata = db.MetaData() + + + + self.__createAPIKeyTable() self.DashboardAPIKeys = self.__getAPIKeys() self.APIAccessed = False self.SetConfig("Server", "version", DASHBOARD_VERSION) def getConnectionString(self, database) -> str or None: + cn = None if self.GetConfig("Database", "type")[1] == "sqlite": - return f'sqlite:///{os.path.join(CONFIGURATION_PATH, "db", f"{database}.db")}' + cn = f'sqlite:///{os.path.join(CONFIGURATION_PATH, "db", f"{database}.db")}' elif self.GetConfig("Database", "type")[1] == "postgresql": - return f'postgresql+psycopg2://{self.GetConfig("Database", "username")[1]}:{self.GetConfig("Database", "password")[1]}@{self.GetConfig("Database", "host")[1]}/{database}' - return None + cn = f'postgresql+psycopg2://{self.GetConfig("Database", "username")[1]}:{self.GetConfig("Database", "password")[1]}@{self.GetConfig("Database", "host")[1]}/{database}' + + if not database_exists(cn): + create_database(cn) + + return cn def __createAPIKeyTable(self): - existingTable = sqlSelect("SELECT name FROM sqlite_master WHERE type='table' AND name = 'DashboardAPIKeys'").fetchall() - if len(existingTable) == 0: - sqlUpdate("CREATE TABLE DashboardAPIKeys (Key VARCHAR NOT NULL PRIMARY KEY, CreatedAt DATETIME NOT NULL DEFAULT (datetime('now', 'localtime')), ExpiredAt VARCHAR)") + self.apiKeyTable = db.Table('DashboardAPIKeys', self.dbMetadata, + db.Column("Key", db.String, nullable=False, primary_key=True), + db.Column("CreatedAt", + (db.DATETIME if self.GetConfig('Database', 'type')[1] == 'sqlite' else db.TIMESTAMP), + server_default=db.func.now() + ), + db.Column("ExpiredAt", + (db.DATETIME if self.GetConfig('Database', 'type')[1] == 'sqlite' else db.TIMESTAMP) + ) + ) + self.dbMetadata.create_all(self.engine) + + # existingTable = sqlSelect("SELECT name FROM sqlite_master WHERE type='table' AND name = 'DashboardAPIKeys'").fetchall() + # if len(existingTable) == 0: + # sqlUpdate("CREATE TABLE DashboardAPIKeys (Key VARCHAR NOT NULL PRIMARY KEY, CreatedAt DATETIME NOT NULL DEFAULT (datetime('now', 'localtime')), ExpiredAt VARCHAR)") def __getAPIKeys(self) -> list[DashboardAPIKey]: - keys = sqlSelect("SELECT * FROM DashboardAPIKeys WHERE ExpiredAt IS NULL OR ExpiredAt > datetime('now', 'localtime') ORDER BY CreatedAt DESC").fetchall() + # keys = sqlSelect("SELECT * FROM DashboardAPIKeys WHERE ExpiredAt IS NULL OR ExpiredAt > datetime('now', 'localtime') ORDER BY CreatedAt DESC").fetchall() + keys = self.db.execute(self.apiKeyTable.select().where( + db.or_(self.apiKeyTable.columns.ExpiredAt == None, self.apiKeyTable.columns.ExpiredAt > datetime.now()) + )).fetchall() fKeys = [] for k in keys: - - fKeys.append(DashboardAPIKey(*k)) + fKeys.append(DashboardAPIKey(k[0], k[1].strftime("%Y-%m-%d %H:%M:%S"), (k[2].strftime("%Y-%m-%d %H:%M:%S") if k[2] else None))) return fKeys def createAPIKeys(self, ExpiredAt = None): newKey = secrets.token_urlsafe(32) - sqlUpdate('INSERT INTO DashboardAPIKeys (Key, ExpiredAt) VALUES (?, ?)', (newKey, ExpiredAt,)) + # sqlUpdate('INSERT INTO DashboardAPIKeys (Key, ExpiredAt) VALUES (?, ?)', (newKey, ExpiredAt,)) + with self.engine.begin() as conn: + conn.execute( + self.apiKeyTable.insert().values({ + "Key": newKey, + "ExpiredAt": ExpiredAt + }) + ) self.DashboardAPIKeys = self.__getAPIKeys() def deleteAPIKey(self, key): - sqlUpdate("UPDATE DashboardAPIKeys SET ExpiredAt = datetime('now', 'localtime') WHERE Key = ?", (key, )) + # sqlUpdate("UPDATE DashboardAPIKeys SET ExpiredAt = datetime('now', 'localtime') WHERE Key = ?", (key, )) + with self.engine.begin() as conn: + conn.execute( + self.apiKeyTable.update().values({ + "ExpiredAt": datetime.now(), + }).where(self.apiKeyTable.columns.Key == key) + ) + self.DashboardAPIKeys = self.__getAPIKeys() def __configValidation(self, section : str, key: str, value: Any) -> [bool, str]: @@ -2062,7 +2102,7 @@ def auth_req(): else: DashboardConfig.APIAccessed = False whiteList = [ - '/static/', 'validateAuthentication', 'authenticate', + '/static/', 'validateAuthentication', 'authenticate', 'getDashboardConfiguration', 'getDashboardTheme', 'getDashboardVersion', 'sharePeer/get', 'isTotpEnabled', 'locale', '/fileDownload' ] @@ -2393,7 +2433,7 @@ def API_updateDashboardConfigurationItem(): valid, msg = DashboardConfig.SetConfig( data["section"], data["key"], data['value']) if not valid: - return ResponseObject(False, msg) + return ResponseObject(False, msg, status_code=404) if data['section'] == "Server": if data['key'] == 'wg_conf_path': WireguardConfigurations.clear() @@ -3126,19 +3166,16 @@ def peerInformationBackgroundThread(): time.sleep(10) while True: with app.app_context(): - try: - curKeys = list(WireguardConfigurations.keys()) - for name in curKeys: - if name in WireguardConfigurations.keys() and WireguardConfigurations.get(name) is not None: - c = WireguardConfigurations.get(name) - if c.getStatus(): - c.getPeersTransfer() - c.getPeersLatestHandshake() - c.getPeersEndpoint() - c.getPeersList() - c.getRestrictedPeersList() - except Exception as e: - print(f"[WGDashboard] Background Thread #1 Error: {str(e)}", flush=True) + for c in WireguardConfigurations.values(): + if c.getStatus(): + try: + c.getPeersTransfer() + c.getPeersLatestHandshake() + c.getPeersEndpoint() + c.getPeersList() + c.getRestrictedPeersList() + except Exception as e: + print(f"[WGDashboard] Background Thread #1 Error: {str(e)}", flush=True) time.sleep(10) def peerJobScheduleBackgroundThread():