From c051ab56b46db890a40ec567ae520ab4cb49eea2 Mon Sep 17 00:00:00 2001 From: Donald Zou Date: Thu, 8 May 2025 17:27:49 +0800 Subject: [PATCH] Updated PeerShareLink to use SQLAlchemy --- src/dashboard.py | 107 ++++++++++++++++++++++----------- src/modules/DashboardLogger.py | 1 - src/modules/PeerJobLogger.py | 15 ++--- 3 files changed, 78 insertions(+), 45 deletions(-) diff --git a/src/dashboard.py b/src/dashboard.py index f0eca71..176ff13 100644 --- a/src/dashboard.py +++ b/src/dashboard.py @@ -237,47 +237,56 @@ class PeerJobs: Peer Share Link """ class PeerShareLink: - def __init__(self, ShareID:str, Configuration: str, Peer: str, ExpireDate: datetime, ShareDate: datetime): + def __init__(self, ShareID:str, Configuration: str, Peer: str, ExpireDate: datetime, SharedDate: datetime): self.ShareID = ShareID self.Peer = Peer self.Configuration = Configuration - self.ShareDate = ShareDate + self.SharedDate = SharedDate self.ExpireDate = ExpireDate - def toJson(self): return { "ShareID": self.ShareID, "Peer": self.Peer, "Configuration": self.Configuration, - "ExpireDate": self.ExpireDate + "ExpireDate": self.ExpireDate.strftime("%Y-%m-%d %H:%M:%S"), + "SharedDate": self.SharedDate.strftime("%Y-%m-%d %H:%M:%S"), } """ Peer Share Links """ class PeerShareLinks: - def __init__(self): + def __init__(self, DashboardConfig): self.Links: list[PeerShareLink] = [] - existingTables = sqlSelect("SELECT name FROM sqlite_master WHERE type='table' and name = 'PeerShareLinks'").fetchall() - if len(existingTables) == 0: - sqlUpdate( - """ - CREATE TABLE PeerShareLinks ( - ShareID VARCHAR NOT NULL PRIMARY KEY, Configuration VARCHAR NOT NULL, Peer VARCHAR NOT NULL, - ExpireDate DATETIME, - SharedDate DATETIME DEFAULT (datetime('now', 'localtime')) - ) - """ - ) + self.engine = db.create_engine(DashboardConfig.getConnectionString("wgdashboard")) + self.metadata = db.MetaData() + self.peerShareLinksTable = db.Table( + 'PeerShareLinks', self.metadata, + db.Column('ShareID', db.String, nullable=False, primary_key=True), + db.Column('Configuration', db.String, nullable=False), + db.Column('Peer', db.String, nullable=False), + db.Column('ExpireDate', (db.DATETIME if DashboardConfig.GetConfig("Database", "type")[1] == 'sqlite' else db.TIMESTAMP)), + db.Column('SharedDate', (db.DATETIME if DashboardConfig.GetConfig("Database", "type")[1] == 'sqlite' else db.TIMESTAMP), + server_default=db.func.now()), + ) + self.metadata.create_all(self.engine) self.__getSharedLinks() def __getSharedLinks(self): self.Links.clear() - allLinks = sqlSelect("SELECT * FROM PeerShareLinks WHERE ExpireDate IS NULL OR ExpireDate > datetime('now', 'localtime')").fetchall() - for link in allLinks: - self.Links.append(PeerShareLink(*link)) - + with self.engine.connect() as conn: + allLinks = conn.execute( + self.peerShareLinksTable.select().where( + db.or_(self.peerShareLinksTable.columns.ExpireDate == None, self.peerShareLinksTable.columns.ExpireDate > datetime.now()) + ) + ).mappings().fetchall() + for link in allLinks: + self.Links.append(PeerShareLink(**link)) + + + def getLink(self, Configuration: str, Peer: str) -> list[PeerShareLink]: + self.__getSharedLinks() return list(filter(lambda x : x.Configuration == Configuration and x.Peer == Peer, self.Links)) def getLinkByID(self, ShareID: str) -> list[PeerShareLink]: @@ -287,16 +296,40 @@ class PeerShareLinks: def addLink(self, Configuration: str, Peer: str, ExpireDate: datetime = None) -> tuple[bool, str]: try: newShareID = str(uuid.uuid4()) - if len(self.getLink(Configuration, Peer)) > 0: - sqlUpdate("UPDATE PeerShareLinks SET ExpireDate = datetime('now', 'localtime') WHERE Configuration = ? AND Peer = ?", (Configuration, Peer, )) - sqlUpdate("INSERT INTO PeerShareLinks (ShareID, Configuration, Peer, ExpireDate) VALUES (?, ?, ?, ?)", (newShareID, Configuration, Peer, ExpireDate, )) + with self.engine.begin() as conn: + if len(self.getLink(Configuration, Peer)) > 0: + conn.execute( + self.peerShareLinksTable.update().values( + { + "ExpireDate": datetime.now() + } + ).where(db.and_(self.peerShareLinksTable.columns.Configuration == Configuration, self.peerShareLinksTable.columns.Peer == Peer)) + ) + + conn.execute( + self.peerShareLinksTable.insert().values( + { + "ShareID": newShareID, + "Configuration": Configuration, + "Peer": Peer, + "ExpireDate": ExpireDate + } + ) + ) self.__getSharedLinks() except Exception as e: return False, str(e) return True, newShareID def updateLinkExpireDate(self, ShareID, ExpireDate: datetime = None) -> tuple[bool, str]: - sqlUpdate("UPDATE PeerShareLinks SET ExpireDate = ? WHERE ShareID = ?;", (ExpireDate, ShareID, )) + with self.engine.begin() as conn: + conn.execute( + self.peerShareLinksTable.update().values( + { + "ExpireDate": ExpireDate + } + ).where(db.and_(self.peerShareLinksTable.columns.ShareID == ShareID)) + ) self.__getSharedLinks() return True, "" @@ -1829,7 +1862,6 @@ class DashboardConfig: self.SetConfig(section, key, value, True) self.engine = db.create_engine(self.getConnectionString('wgdashboard')) - self.db = self.engine.connect() self.dbMetadata = db.MetaData() @@ -1871,13 +1903,18 @@ class DashboardConfig: def __getAPIKeys(self) -> list[DashboardAPIKey]: # keys = sqlSelect("SELECT * FROM DashboardAPIKeys WHERE ExpiredAt IS NULL OR ExpiredAt > datetime('now', 'localtime') ORDER BY CreatedAt DESC").fetchall() - keys = self.db.execute(self.apiKeyTable.select().where( - db.or_(self.apiKeyTable.columns.ExpiredAt == None, self.apiKeyTable.columns.ExpiredAt > datetime.now()) - )).fetchall() - fKeys = [] - for k in keys: - fKeys.append(DashboardAPIKey(k[0], k[1].strftime("%Y-%m-%d %H:%M:%S"), (k[2].strftime("%Y-%m-%d %H:%M:%S") if k[2] else None))) - return fKeys + try: + with self.engine.connect() as conn: + keys = conn.execute(self.apiKeyTable.select().where( + db.or_(self.apiKeyTable.columns.ExpiredAt == None, self.apiKeyTable.columns.ExpiredAt > datetime.now()) + )).fetchall() + fKeys = [] + for k in keys: + fKeys.append(DashboardAPIKey(k[0], k[1].strftime("%Y-%m-%d %H:%M:%S"), (k[2].strftime("%Y-%m-%d %H:%M:%S") if k[2] else None))) + return fKeys + except Exception as e: + print("") + return [] def createAPIKeys(self, ExpiredAt = None): newKey = secrets.token_urlsafe(32) @@ -2554,7 +2591,7 @@ def API_sharePeer_create(): "This peer is already sharing. Please view data for shared link.", data=activeLink[0] ) - status, message = AllPeerShareLinks.addLink(Configuration, Peer, ExpireDate) + status, message = AllPeerShareLinks.addLink(Configuration, Peer, datetime.strptime(ExpireDate, "%Y-%m-%d %H:%M:%S")) if not status: return ResponseObject(status, message) return ResponseObject(data=AllPeerShareLinks.getLinkByID(message)) @@ -2571,7 +2608,7 @@ def API_sharePeer_update(): if len(AllPeerShareLinks.getLinkByID(ShareID)) == 0: return ResponseObject(False, "ShareID does not exist") - status, message = AllPeerShareLinks.updateLinkExpireDate(ShareID, ExpireDate) + status, message = AllPeerShareLinks.updateLinkExpireDate(ShareID, datetime.strptime(ExpireDate, "%Y-%m-%d %H:%M:%S")) if not status: return ResponseObject(status, message) return ResponseObject(data=AllPeerShareLinks.getLinkByID(ShareID)) @@ -3231,7 +3268,7 @@ def InitWireguardConfigurationsList(startup: bool = False): except WireguardConfigurations.InvalidConfigurationFileException as e: print(f"{i} have an invalid configuration file.") -AllPeerShareLinks: PeerShareLinks = PeerShareLinks() +AllPeerShareLinks: PeerShareLinks = PeerShareLinks(DashboardConfig) AllPeerJobs: PeerJobs = PeerJobs() JobLogger: PeerJobLogger = PeerJobLogger(CONFIGURATION_PATH, AllPeerJobs, DashboardConfig) DashboardLogger: DashboardLogger = DashboardLogger(CONFIGURATION_PATH, DashboardConfig) diff --git a/src/modules/DashboardLogger.py b/src/modules/DashboardLogger.py index 59ea664..d619343 100644 --- a/src/modules/DashboardLogger.py +++ b/src/modules/DashboardLogger.py @@ -11,7 +11,6 @@ class DashboardLogger: self.engine = db.create_engine(DashboardConfig.getConnectionString("wgdashboard_log")) if not database_exists(self.engine.url): create_database(self.engine.url) - self.loggerdb = self.engine.connect() self.metadata = db.MetaData() self.dashboardLoggerTable = db.Table('DashboardLog', self.metadata, db.Column('LogID', db.VARCHAR, nullable=False, primary_key=True), diff --git a/src/modules/PeerJobLogger.py b/src/modules/PeerJobLogger.py index 991f204..9ec035c 100644 --- a/src/modules/PeerJobLogger.py +++ b/src/modules/PeerJobLogger.py @@ -9,11 +9,7 @@ from sqlalchemy_utils import database_exists, create_database class PeerJobLogger: def __init__(self, CONFIGURATION_PATH, AllPeerJobs, DashboardConfig): - self.engine = db.create_engine(DashboardConfig.getConnectionString("wgdashboard_log")) - if not database_exists(self.engine.url): - create_database(self.engine.url) - - self.loggerdb = self.engine.connect() + self.engine = db.create_engine(DashboardConfig.getConnectionString("wgdashboard_log")) self.metadata = db.MetaData() self.jobLogTable = db.Table('JobLog', self.metadata, db.Column('LogID', db.String, nullable=False, primary_key=True), @@ -52,10 +48,11 @@ class PeerJobLogger: stmt = self.jobLogTable.select().where(self.jobLogTable.columns.JobID.in_( allJobsID )) - table = self.loggerdb.execute(stmt).fetchall() - for l in table: - logs.append( - Log(l.LogID, l.JobID, l.LogDate.strftime("%Y-%m-%d %H:%M:%S"), l.Status, l.Message)) + with self.engine.connect() as conn: + table = conn.execute(stmt).fetchall() + for l in table: + logs.append( + Log(l.LogID, l.JobID, l.LogDate.strftime("%Y-%m-%d %H:%M:%S"), l.Status, l.Message)) except Exception as e: print(e) return logs