From bf28983229d582455e5965398c998d253f90e0f8 Mon Sep 17 00:00:00 2001 From: Daan Selen Date: Fri, 19 Sep 2025 20:55:47 +0200 Subject: [PATCH] feat: make the prefix configurable --- src/modules/ConnectionString.py | 57 ++++++++++++------- src/modules/DashboardClients.py | 16 +++--- src/modules/DashboardClientsPeerAssignment.py | 8 +-- src/modules/DashboardClientsTOTP.py | 6 +- src/modules/DashboardConfig.py | 23 ++------ src/modules/DashboardLogger.py | 6 +- src/modules/DashboardWebHooks.py | 6 +- src/modules/NewConfigurationTemplates.py | 4 +- src/modules/PeerJobLogger.py | 4 +- src/modules/PeerJobs.py | 4 +- src/modules/PeerShareLinks.py | 4 +- src/modules/WireguardConfiguration.py | 4 +- 12 files changed, 73 insertions(+), 69 deletions(-) diff --git a/src/modules/ConnectionString.py b/src/modules/ConnectionString.py index 77f69644..128bbb1e 100644 --- a/src/modules/ConnectionString.py +++ b/src/modules/ConnectionString.py @@ -1,25 +1,44 @@ +# ConnectionString.py import configparser import os from sqlalchemy_utils import database_exists, create_database -from flask import current_app -def ConnectionString(database) -> str: - parser = configparser.ConfigParser(strict=False) - parser.read_file(open('wg-dashboard.ini', "r+")) - sqlitePath = os.path.join("db") - if not os.path.isdir(sqlitePath): - os.mkdir(sqlitePath) - if parser.get("Database", "type") == "postgresql": - cn = f'postgresql+psycopg://{parser.get("Database", "username")}:{parser.get("Database", "password")}@{parser.get("Database", "host")}/{database}' - elif parser.get("Database", "type") == "mysql": - cn = f'mysql+pymysql://{parser.get("Database", "username")}:{parser.get("Database", "password")}@{parser.get("Database", "host")}/{database}' +# Read and parse the INI file once at startup +parser = configparser.ConfigParser(strict=False) +parser.read("wg-dashboard.ini") + +# Ensure SQLite folder exists +SQLITE_PATH = "db" +os.makedirs(SQLITE_PATH, exist_ok=True) + +DEFAULT_DB = "wgdashboard" +DEFAULT_LOG_DB = "wgdashboard_log" +DEFAULT_JOB_DB = "wgdashboard_log" + +def ConnectionString(database_name: str) -> str: + """ + Returns a SQLAlchemy-compatible connection string for the chosen database. + Creates the database if it doesn't exist. + """ + db_type = parser.get("Database", "type") + db_prefix = parser.get("Database", "prefix") + database_name = f"{db_prefix}_{database_name}" + + if db_type == "postgresql": + username = parser.get("Database", "username") + password = parser.get("Database", "password") + host = parser.get("Database", "host") + cn = f"postgresql+psycopg://{username}:{password}@{host}/{database_name}" + elif db_type == "mysql": + username = parser.get("Database", "username") + password = parser.get("Database", "password") + host = parser.get("Database", "host") + cn = f"mysql+pymysql://{username}:{password}@{host}/{database_name}" else: - cn = f'sqlite:///{os.path.join(sqlitePath, f"{database}.db")}' - try: - if not database_exists(cn): - create_database(cn) - except Exception as e: - current_app.logger.error("Database error. Terminating...", e) - exit(1) - + cn = f"sqlite:///{os.path.join(SQLITE_PATH, f'{database_name}.db')}" + + # Ensure database exists + if not database_exists(cn): + create_database(cn) + return cn \ No newline at end of file diff --git a/src/modules/DashboardClients.py b/src/modules/DashboardClients.py index 231141b1..96020f01 100644 --- a/src/modules/DashboardClients.py +++ b/src/modules/DashboardClients.py @@ -8,7 +8,7 @@ import pyotp import sqlalchemy as db import requests -from .ConnectionString import ConnectionString +from .ConnectionString import ConnectionString, DEFAULT_DB from .DashboardClientsPeerAssignment import DashboardClientsPeerAssignment from .DashboardClientsTOTP import DashboardClientsTOTP from .DashboardOIDC import DashboardOIDC @@ -20,7 +20,7 @@ from flask import session class DashboardClients: def __init__(self, wireguardConfigurations): self.logger = DashboardLogger() - self.engine = db.create_engine(ConnectionString("wgdashboard")) + self.engine = db.create_engine(ConnectionString(DEFAULT_DB)) self.metadata = db.MetaData() self.OIDC = DashboardOIDC("Client") @@ -32,10 +32,10 @@ class DashboardClients: db.Column('TotpKey', db.String(500)), db.Column('TotpKeyVerified', db.Integer), db.Column('CreatedDate', - (db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP), + (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP), server_default=db.func.now()), db.Column('DeletedDate', - (db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP)), + (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP)), extend_existing=True, ) @@ -46,10 +46,10 @@ class DashboardClients: db.Column('ProviderIssuer', db.String(500), nullable=False, index=True), db.Column('ProviderSubject', db.String(500), nullable=False, index=True), db.Column('CreatedDate', - (db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP), + (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP), server_default=db.func.now()), db.Column('DeletedDate', - (db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP)), + (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP)), extend_existing=True, ) @@ -65,10 +65,10 @@ class DashboardClients: db.Column('ResetToken', db.String(255), nullable=False, primary_key=True), db.Column('ClientID', db.String(255), nullable=False), db.Column('CreatedDate', - (db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP), + (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP), server_default=db.func.now()), db.Column('ExpiryDate', - (db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP)), + (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP)), extend_existing=True ) diff --git a/src/modules/DashboardClientsPeerAssignment.py b/src/modules/DashboardClientsPeerAssignment.py index 80722d06..e437bc92 100644 --- a/src/modules/DashboardClientsPeerAssignment.py +++ b/src/modules/DashboardClientsPeerAssignment.py @@ -1,7 +1,7 @@ import datetime import uuid -from .ConnectionString import ConnectionString +from .ConnectionString import ConnectionString, DEFAULT_DB from .DashboardLogger import DashboardLogger import sqlalchemy as db from .WireguardConfiguration import WireguardConfiguration @@ -31,7 +31,7 @@ class Assignment: class DashboardClientsPeerAssignment: def __init__(self, wireguardConfigurations: dict[str, WireguardConfiguration]): self.logger = DashboardLogger() - self.engine = db.create_engine(ConnectionString("wgdashboard")) + self.engine = db.create_engine(ConnectionString(DEFAULT_DB)) self.metadata = db.MetaData() self.wireguardConfigurations = wireguardConfigurations self.dashboardClientsPeerAssignmentTable = db.Table( @@ -41,10 +41,10 @@ class DashboardClientsPeerAssignment: db.Column('ConfigurationName', db.String(255)), db.Column('PeerID', db.String(500)), db.Column('AssignedDate', - (db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP), + (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP), server_default=db.func.now()), db.Column('UnassignedDate', - (db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP)), + (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP)), extend_existing=True ) self.metadata.create_all(self.engine) diff --git a/src/modules/DashboardClientsTOTP.py b/src/modules/DashboardClientsTOTP.py index e3830fb5..486b08bc 100644 --- a/src/modules/DashboardClientsTOTP.py +++ b/src/modules/DashboardClientsTOTP.py @@ -3,19 +3,19 @@ import hashlib import uuid import sqlalchemy as db -from .ConnectionString import ConnectionString +from .ConnectionString import ConnectionString, DEFAULT_DB class DashboardClientsTOTP: def __init__(self): - self.engine = db.create_engine(ConnectionString("wgdashboard")) + self.engine = db.create_engine(ConnectionString(DEFAULT_DB)) self.metadata = db.MetaData() self.dashboardClientsTOTPTable = db.Table( 'DashboardClientsTOTPTokens', self.metadata, db.Column("Token", db.String(500), primary_key=True, index=True), db.Column("ClientID", db.String(500), index=True), db.Column( - "ExpireTime", (db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP) + "ExpireTime", (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP) ) ) self.metadata.create_all(self.engine) diff --git a/src/modules/DashboardConfig.py b/src/modules/DashboardConfig.py index 2aeedc5e..2ae05da5 100644 --- a/src/modules/DashboardConfig.py +++ b/src/modules/DashboardConfig.py @@ -7,7 +7,7 @@ import sqlalchemy as db from datetime import datetime from typing import Any from flask import current_app -from .ConnectionString import ConnectionString +from .ConnectionString import ConnectionString, DEFAULT_DB from .Utilities import ( GetRemoteEndpoint, ValidateDNSAddress ) @@ -65,7 +65,8 @@ class DashboardConfig: "host": "", "port": "", "username": "", - "password": "" + "password": "", + "prefix": "" }, "Email":{ "server": "", @@ -95,29 +96,13 @@ class DashboardConfig: if not exist: self.SetConfig(section, key, value, True) - self.engine = db.create_engine(ConnectionString('wgdashboard')) + self.engine = db.create_engine(ConnectionString(DEFAULT_DB)) self.dbMetadata = db.MetaData() self.__createAPIKeyTable() self.DashboardAPIKeys = self.__getAPIKeys() self.APIAccessed = False self.SetConfig("Server", "version", DashboardConfig.DashboardVersion) - def getConnectionString(self, database) -> str or None: - sqlitePath = os.path.join(DashboardConfig.ConfigurationPath, "db") - - if not os.path.isdir(sqlitePath): - os.mkdir(sqlitePath) - - if self.GetConfig("Database", "type")[1] == "postgresql": - cn = f'postgresql+psycopg2://{self.GetConfig("Database", "username")[1]}:{self.GetConfig("Database", "password")[1]}@{self.GetConfig("Database", "host")[1]}/{database}' - elif self.GetConfig("Database", "type")[1] == "mysql": - cn = f'mysql+mysqldb://{self.GetConfig("Database", "username")[1]}:{self.GetConfig("Database", "password")[1]}@{self.GetConfig("Database", "host")[1]}/{database}' - else: - cn = f'sqlite:///{os.path.join(sqlitePath, f"{database}.db")}' - if not database_exists(cn): - create_database(cn) - return cn - def __createAPIKeyTable(self): self.apiKeyTable = db.Table('DashboardAPIKeys', self.dbMetadata, db.Column("Key", db.String(255), nullable=False, primary_key=True), diff --git a/src/modules/DashboardLogger.py b/src/modules/DashboardLogger.py index 9b4e1f24..18602587 100644 --- a/src/modules/DashboardLogger.py +++ b/src/modules/DashboardLogger.py @@ -4,18 +4,18 @@ Dashboard Logger Class import uuid import sqlalchemy as db from flask import current_app -from .ConnectionString import ConnectionString +from .ConnectionString import ConnectionString, DEFAULT_DB, DEFAULT_LOG_DB class DashboardLogger: def __init__(self): - self.engine = db.create_engine(ConnectionString("wgdashboard_log")) + self.engine = db.create_engine(ConnectionString(DEFAULT_LOG_DB)) self.metadata = db.MetaData() self.dashboardLoggerTable = db.Table('DashboardLog', self.metadata, db.Column('LogID', db.String(255), nullable=False, primary_key=True), db.Column('LogDate', - (db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP), + (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP), server_default=db.func.now()), db.Column('URL', db.String(255)), db.Column('IP', db.String(255)), diff --git a/src/modules/DashboardWebHooks.py b/src/modules/DashboardWebHooks.py index a598444b..78004e2e 100644 --- a/src/modules/DashboardWebHooks.py +++ b/src/modules/DashboardWebHooks.py @@ -8,7 +8,7 @@ from datetime import datetime, timedelta import requests from pydantic import BaseModel, field_serializer import sqlalchemy as db -from .ConnectionString import ConnectionString +from .ConnectionString import ConnectionString, DEFAULT_DB from flask import current_app WebHookActions = ['peer_created', 'peer_deleted', 'peer_updated'] @@ -40,7 +40,7 @@ class WebHookSessionLogs(BaseModel): class DashboardWebHooks: def __init__(self, DashboardConfig): - self.engine = db.create_engine(ConnectionString("wgdashboard")) + self.engine = db.create_engine(ConnectionString(DEFAULT_DB)) self.metadata = db.MetaData() self.webHooksTable = db.Table( 'DashboardWebHooks', self.metadata, @@ -201,7 +201,7 @@ class DashboardWebHooks: class WebHookSession: def __init__(self, webHook: WebHook, data: dict[str, str]): - self.engine = db.create_engine(ConnectionString("wgdashboard")) + self.engine = db.create_engine(ConnectionString(DEFAULT_DB)) self.metadata = db.MetaData() self.webHookSessionsTable = db.Table('DashboardWebHookSessions', self.metadata, autoload_with=self.engine) self.webHook = webHook diff --git a/src/modules/NewConfigurationTemplates.py b/src/modules/NewConfigurationTemplates.py index 9c4511a4..26a00576 100644 --- a/src/modules/NewConfigurationTemplates.py +++ b/src/modules/NewConfigurationTemplates.py @@ -2,7 +2,7 @@ import uuid from pydantic import BaseModel, field_serializer import sqlalchemy as db -from .ConnectionString import ConnectionString +from .ConnectionString import ConnectionString, DEFAULT_DB class NewConfigurationTemplate(BaseModel): @@ -14,7 +14,7 @@ class NewConfigurationTemplate(BaseModel): class NewConfigurationTemplates: def __init__(self): - self.engine = db.create_engine(ConnectionString("wgdashboard")) + self.engine = db.create_engine(ConnectionString(DEFAULT_DB)) self.metadata = db.MetaData() self.templatesTable = db.Table( 'NewConfigurationTemplates', self.metadata, diff --git a/src/modules/PeerJobLogger.py b/src/modules/PeerJobLogger.py index b047f3c4..a82e5f26 100644 --- a/src/modules/PeerJobLogger.py +++ b/src/modules/PeerJobLogger.py @@ -4,12 +4,12 @@ Peer Job Logger import uuid import sqlalchemy as db from flask import current_app -from .ConnectionString import ConnectionString +from .ConnectionString import ConnectionString, DEFAULT_LOG_DB from .Log import Log class PeerJobLogger: def __init__(self, AllPeerJobs, DashboardConfig): - self.engine = db.create_engine(ConnectionString("wgdashboard_log")) + self.engine = db.create_engine(ConnectionString(DEFAULT_LOG_DB)) self.metadata = db.MetaData() self.jobLogTable = db.Table('JobLog', self.metadata, db.Column('LogID', db.String(255), nullable=False, primary_key=True), diff --git a/src/modules/PeerJobs.py b/src/modules/PeerJobs.py index 5d069f66..a4fa8ede 100644 --- a/src/modules/PeerJobs.py +++ b/src/modules/PeerJobs.py @@ -1,7 +1,7 @@ """ Peer Jobs """ -from .ConnectionString import ConnectionString +from .ConnectionString import ConnectionString, DEFAULT_JOB_DB from .PeerJob import PeerJob from .PeerJobLogger import PeerJobLogger import sqlalchemy as db @@ -11,7 +11,7 @@ from flask import current_app class PeerJobs: def __init__(self, DashboardConfig, WireguardConfigurations): self.Jobs: list[PeerJob] = [] - self.engine = db.create_engine(ConnectionString('wgdashboard_job')) + self.engine = db.create_engine(ConnectionString(DEFAULT_JOB_DB)) self.metadata = db.MetaData() self.peerJobTable = db.Table('PeerJobs', self.metadata, db.Column('JobID', db.String(255), nullable=False, primary_key=True), diff --git a/src/modules/PeerShareLinks.py b/src/modules/PeerShareLinks.py index 206e2fd0..bfc96f02 100644 --- a/src/modules/PeerShareLinks.py +++ b/src/modules/PeerShareLinks.py @@ -1,4 +1,4 @@ -from .ConnectionString import ConnectionString +from .ConnectionString import ConnectionString, DEFAULT_DB from .PeerShareLink import PeerShareLink import sqlalchemy as db from datetime import datetime @@ -10,7 +10,7 @@ Peer Share Links class PeerShareLinks: def __init__(self, DashboardConfig, WireguardConfigurations): self.Links: list[PeerShareLink] = [] - self.engine = db.create_engine(ConnectionString("wgdashboard")) + self.engine = db.create_engine(ConnectionString(DEFAULT_DB)) self.metadata = db.MetaData() self.peerShareLinksTable = db.Table( 'PeerShareLinks', self.metadata, diff --git a/src/modules/WireguardConfiguration.py b/src/modules/WireguardConfiguration.py index b3b04099..05c5de6e 100644 --- a/src/modules/WireguardConfiguration.py +++ b/src/modules/WireguardConfiguration.py @@ -10,7 +10,7 @@ from datetime import datetime, timedelta from itertools import islice from flask import current_app -from .ConnectionString import ConnectionString +from .ConnectionString import ConnectionString, DEFAULT_DB from .DashboardConfig import DashboardConfig from .Peer import Peer from .PeerJobs import PeerJobs @@ -64,7 +64,7 @@ class WireguardConfiguration: self.AllPeerShareLinks = AllPeerShareLinks self.DashboardWebHooks = DashboardWebHooks self.configPath = os.path.join(self.__getProtocolPath(), f'{self.Name}.conf') - self.engine: sqlalchemy.Engine = sqlalchemy.create_engine(ConnectionString("wgdashboard")) + self.engine: sqlalchemy.Engine = sqlalchemy.create_engine(ConnectionString(DEFAULT_DB)) self.metadata: sqlalchemy.MetaData = sqlalchemy.MetaData() self.dbType = self.DashboardConfig.GetConfig("Database", "type")[1]