diff --git a/src/dashboard.py b/src/dashboard.py index 0d70141..f0eca71 100644 --- a/src/dashboard.py +++ b/src/dashboard.py @@ -25,6 +25,9 @@ from modules.PeerJob import PeerJob from modules.SystemStatus import SystemStatus SystemStatus = SystemStatus() +from sqlalchemy_utils import database_exists, create_database +import sqlalchemy as db + DASHBOARD_VERSION = 'v4.2.3' CONFIGURATION_PATH = os.getenv('CONFIGURATION_PATH', '.') @@ -1824,39 +1827,80 @@ class DashboardConfig: exist, currentData = self.GetConfig(section, key) if not exist: self.SetConfig(section, key, value, True) + + self.engine = db.create_engine(self.getConnectionString('wgdashboard')) + self.db = self.engine.connect() + self.dbMetadata = db.MetaData() + + + + self.__createAPIKeyTable() self.DashboardAPIKeys = self.__getAPIKeys() self.APIAccessed = False self.SetConfig("Server", "version", DASHBOARD_VERSION) def getConnectionString(self, database) -> str or None: + cn = None if self.GetConfig("Database", "type")[1] == "sqlite": - return f'sqlite:///{os.path.join(CONFIGURATION_PATH, "db", f"{database}.db")}' + cn = f'sqlite:///{os.path.join(CONFIGURATION_PATH, "db", f"{database}.db")}' elif self.GetConfig("Database", "type")[1] == "postgresql": - return f'postgresql+psycopg2://{self.GetConfig("Database", "username")[1]}:{self.GetConfig("Database", "password")[1]}@{self.GetConfig("Database", "host")[1]}/{database}' - return None + cn = f'postgresql+psycopg2://{self.GetConfig("Database", "username")[1]}:{self.GetConfig("Database", "password")[1]}@{self.GetConfig("Database", "host")[1]}/{database}' + + if not database_exists(cn): + create_database(cn) + + return cn def __createAPIKeyTable(self): - existingTable = sqlSelect("SELECT name FROM sqlite_master WHERE type='table' AND name = 'DashboardAPIKeys'").fetchall() - if len(existingTable) == 0: - sqlUpdate("CREATE TABLE DashboardAPIKeys (Key VARCHAR NOT NULL PRIMARY KEY, CreatedAt DATETIME NOT NULL DEFAULT (datetime('now', 'localtime')), ExpiredAt VARCHAR)") + self.apiKeyTable = db.Table('DashboardAPIKeys', self.dbMetadata, + db.Column("Key", db.String, nullable=False, primary_key=True), + db.Column("CreatedAt", + (db.DATETIME if self.GetConfig('Database', 'type')[1] == 'sqlite' else db.TIMESTAMP), + server_default=db.func.now() + ), + db.Column("ExpiredAt", + (db.DATETIME if self.GetConfig('Database', 'type')[1] == 'sqlite' else db.TIMESTAMP) + ) + ) + self.dbMetadata.create_all(self.engine) + + # existingTable = sqlSelect("SELECT name FROM sqlite_master WHERE type='table' AND name = 'DashboardAPIKeys'").fetchall() + # if len(existingTable) == 0: + # sqlUpdate("CREATE TABLE DashboardAPIKeys (Key VARCHAR NOT NULL PRIMARY KEY, CreatedAt DATETIME NOT NULL DEFAULT (datetime('now', 'localtime')), ExpiredAt VARCHAR)") def __getAPIKeys(self) -> list[DashboardAPIKey]: - keys = sqlSelect("SELECT * FROM DashboardAPIKeys WHERE ExpiredAt IS NULL OR ExpiredAt > datetime('now', 'localtime') ORDER BY CreatedAt DESC").fetchall() + # keys = sqlSelect("SELECT * FROM DashboardAPIKeys WHERE ExpiredAt IS NULL OR ExpiredAt > datetime('now', 'localtime') ORDER BY CreatedAt DESC").fetchall() + keys = self.db.execute(self.apiKeyTable.select().where( + db.or_(self.apiKeyTable.columns.ExpiredAt == None, self.apiKeyTable.columns.ExpiredAt > datetime.now()) + )).fetchall() fKeys = [] for k in keys: - - fKeys.append(DashboardAPIKey(*k)) + fKeys.append(DashboardAPIKey(k[0], k[1].strftime("%Y-%m-%d %H:%M:%S"), (k[2].strftime("%Y-%m-%d %H:%M:%S") if k[2] else None))) return fKeys def createAPIKeys(self, ExpiredAt = None): newKey = secrets.token_urlsafe(32) - sqlUpdate('INSERT INTO DashboardAPIKeys (Key, ExpiredAt) VALUES (?, ?)', (newKey, ExpiredAt,)) + # sqlUpdate('INSERT INTO DashboardAPIKeys (Key, ExpiredAt) VALUES (?, ?)', (newKey, ExpiredAt,)) + with self.engine.begin() as conn: + conn.execute( + self.apiKeyTable.insert().values({ + "Key": newKey, + "ExpiredAt": ExpiredAt + }) + ) self.DashboardAPIKeys = self.__getAPIKeys() def deleteAPIKey(self, key): - sqlUpdate("UPDATE DashboardAPIKeys SET ExpiredAt = datetime('now', 'localtime') WHERE Key = ?", (key, )) + # sqlUpdate("UPDATE DashboardAPIKeys SET ExpiredAt = datetime('now', 'localtime') WHERE Key = ?", (key, )) + with self.engine.begin() as conn: + conn.execute( + self.apiKeyTable.update().values({ + "ExpiredAt": datetime.now(), + }).where(self.apiKeyTable.columns.Key == key) + ) + self.DashboardAPIKeys = self.__getAPIKeys() def __configValidation(self, section : str, key: str, value: Any) -> [bool, str]: