feat: make the prefix configurable

This commit is contained in:
Daan Selen
2025-09-19 20:55:47 +02:00
parent cf77610a56
commit bf28983229
12 changed files with 73 additions and 69 deletions

View File

@@ -1,25 +1,44 @@
# ConnectionString.py
import configparser import configparser
import os import os
from sqlalchemy_utils import database_exists, create_database from sqlalchemy_utils import database_exists, create_database
from flask import current_app
def ConnectionString(database) -> str: # Read and parse the INI file once at startup
parser = configparser.ConfigParser(strict=False) parser = configparser.ConfigParser(strict=False)
parser.read_file(open('wg-dashboard.ini', "r+")) parser.read("wg-dashboard.ini")
sqlitePath = os.path.join("db")
if not os.path.isdir(sqlitePath): # Ensure SQLite folder exists
os.mkdir(sqlitePath) SQLITE_PATH = "db"
if parser.get("Database", "type") == "postgresql": os.makedirs(SQLITE_PATH, exist_ok=True)
cn = f'postgresql+psycopg://{parser.get("Database", "username")}:{parser.get("Database", "password")}@{parser.get("Database", "host")}/{database}'
elif parser.get("Database", "type") == "mysql": DEFAULT_DB = "wgdashboard"
cn = f'mysql+pymysql://{parser.get("Database", "username")}:{parser.get("Database", "password")}@{parser.get("Database", "host")}/{database}' DEFAULT_LOG_DB = "wgdashboard_log"
DEFAULT_JOB_DB = "wgdashboard_log"
def ConnectionString(database_name: str) -> str:
"""
Returns a SQLAlchemy-compatible connection string for the chosen database.
Creates the database if it doesn't exist.
"""
db_type = parser.get("Database", "type")
db_prefix = parser.get("Database", "prefix")
database_name = f"{db_prefix}_{database_name}"
if db_type == "postgresql":
username = parser.get("Database", "username")
password = parser.get("Database", "password")
host = parser.get("Database", "host")
cn = f"postgresql+psycopg://{username}:{password}@{host}/{database_name}"
elif db_type == "mysql":
username = parser.get("Database", "username")
password = parser.get("Database", "password")
host = parser.get("Database", "host")
cn = f"mysql+pymysql://{username}:{password}@{host}/{database_name}"
else: else:
cn = f'sqlite:///{os.path.join(sqlitePath, f"{database}.db")}' cn = f"sqlite:///{os.path.join(SQLITE_PATH, f'{database_name}.db')}"
try:
# Ensure database exists
if not database_exists(cn): if not database_exists(cn):
create_database(cn) create_database(cn)
except Exception as e:
current_app.logger.error("Database error. Terminating...", e)
exit(1)
return cn return cn

View File

@@ -8,7 +8,7 @@ import pyotp
import sqlalchemy as db import sqlalchemy as db
import requests import requests
from .ConnectionString import ConnectionString from .ConnectionString import ConnectionString, DEFAULT_DB
from .DashboardClientsPeerAssignment import DashboardClientsPeerAssignment from .DashboardClientsPeerAssignment import DashboardClientsPeerAssignment
from .DashboardClientsTOTP import DashboardClientsTOTP from .DashboardClientsTOTP import DashboardClientsTOTP
from .DashboardOIDC import DashboardOIDC from .DashboardOIDC import DashboardOIDC
@@ -20,7 +20,7 @@ from flask import session
class DashboardClients: class DashboardClients:
def __init__(self, wireguardConfigurations): def __init__(self, wireguardConfigurations):
self.logger = DashboardLogger() self.logger = DashboardLogger()
self.engine = db.create_engine(ConnectionString("wgdashboard")) self.engine = db.create_engine(ConnectionString(DEFAULT_DB))
self.metadata = db.MetaData() self.metadata = db.MetaData()
self.OIDC = DashboardOIDC("Client") self.OIDC = DashboardOIDC("Client")
@@ -32,10 +32,10 @@ class DashboardClients:
db.Column('TotpKey', db.String(500)), db.Column('TotpKey', db.String(500)),
db.Column('TotpKeyVerified', db.Integer), db.Column('TotpKeyVerified', db.Integer),
db.Column('CreatedDate', db.Column('CreatedDate',
(db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP), (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP),
server_default=db.func.now()), server_default=db.func.now()),
db.Column('DeletedDate', db.Column('DeletedDate',
(db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP)), (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP)),
extend_existing=True, extend_existing=True,
) )
@@ -46,10 +46,10 @@ class DashboardClients:
db.Column('ProviderIssuer', db.String(500), nullable=False, index=True), db.Column('ProviderIssuer', db.String(500), nullable=False, index=True),
db.Column('ProviderSubject', db.String(500), nullable=False, index=True), db.Column('ProviderSubject', db.String(500), nullable=False, index=True),
db.Column('CreatedDate', db.Column('CreatedDate',
(db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP), (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP),
server_default=db.func.now()), server_default=db.func.now()),
db.Column('DeletedDate', db.Column('DeletedDate',
(db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP)), (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP)),
extend_existing=True, extend_existing=True,
) )
@@ -65,10 +65,10 @@ class DashboardClients:
db.Column('ResetToken', db.String(255), nullable=False, primary_key=True), db.Column('ResetToken', db.String(255), nullable=False, primary_key=True),
db.Column('ClientID', db.String(255), nullable=False), db.Column('ClientID', db.String(255), nullable=False),
db.Column('CreatedDate', db.Column('CreatedDate',
(db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP), (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP),
server_default=db.func.now()), server_default=db.func.now()),
db.Column('ExpiryDate', db.Column('ExpiryDate',
(db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP)), (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP)),
extend_existing=True extend_existing=True
) )

View File

@@ -1,7 +1,7 @@
import datetime import datetime
import uuid import uuid
from .ConnectionString import ConnectionString from .ConnectionString import ConnectionString, DEFAULT_DB
from .DashboardLogger import DashboardLogger from .DashboardLogger import DashboardLogger
import sqlalchemy as db import sqlalchemy as db
from .WireguardConfiguration import WireguardConfiguration from .WireguardConfiguration import WireguardConfiguration
@@ -31,7 +31,7 @@ class Assignment:
class DashboardClientsPeerAssignment: class DashboardClientsPeerAssignment:
def __init__(self, wireguardConfigurations: dict[str, WireguardConfiguration]): def __init__(self, wireguardConfigurations: dict[str, WireguardConfiguration]):
self.logger = DashboardLogger() self.logger = DashboardLogger()
self.engine = db.create_engine(ConnectionString("wgdashboard")) self.engine = db.create_engine(ConnectionString(DEFAULT_DB))
self.metadata = db.MetaData() self.metadata = db.MetaData()
self.wireguardConfigurations = wireguardConfigurations self.wireguardConfigurations = wireguardConfigurations
self.dashboardClientsPeerAssignmentTable = db.Table( self.dashboardClientsPeerAssignmentTable = db.Table(
@@ -41,10 +41,10 @@ class DashboardClientsPeerAssignment:
db.Column('ConfigurationName', db.String(255)), db.Column('ConfigurationName', db.String(255)),
db.Column('PeerID', db.String(500)), db.Column('PeerID', db.String(500)),
db.Column('AssignedDate', db.Column('AssignedDate',
(db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP), (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP),
server_default=db.func.now()), server_default=db.func.now()),
db.Column('UnassignedDate', db.Column('UnassignedDate',
(db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP)), (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP)),
extend_existing=True extend_existing=True
) )
self.metadata.create_all(self.engine) self.metadata.create_all(self.engine)

View File

@@ -3,19 +3,19 @@ import hashlib
import uuid import uuid
import sqlalchemy as db import sqlalchemy as db
from .ConnectionString import ConnectionString from .ConnectionString import ConnectionString, DEFAULT_DB
class DashboardClientsTOTP: class DashboardClientsTOTP:
def __init__(self): def __init__(self):
self.engine = db.create_engine(ConnectionString("wgdashboard")) self.engine = db.create_engine(ConnectionString(DEFAULT_DB))
self.metadata = db.MetaData() self.metadata = db.MetaData()
self.dashboardClientsTOTPTable = db.Table( self.dashboardClientsTOTPTable = db.Table(
'DashboardClientsTOTPTokens', self.metadata, 'DashboardClientsTOTPTokens', self.metadata,
db.Column("Token", db.String(500), primary_key=True, index=True), db.Column("Token", db.String(500), primary_key=True, index=True),
db.Column("ClientID", db.String(500), index=True), db.Column("ClientID", db.String(500), index=True),
db.Column( db.Column(
"ExpireTime", (db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP) "ExpireTime", (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP)
) )
) )
self.metadata.create_all(self.engine) self.metadata.create_all(self.engine)

View File

@@ -7,7 +7,7 @@ import sqlalchemy as db
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from flask import current_app from flask import current_app
from .ConnectionString import ConnectionString from .ConnectionString import ConnectionString, DEFAULT_DB
from .Utilities import ( from .Utilities import (
GetRemoteEndpoint, ValidateDNSAddress GetRemoteEndpoint, ValidateDNSAddress
) )
@@ -65,7 +65,8 @@ class DashboardConfig:
"host": "", "host": "",
"port": "", "port": "",
"username": "", "username": "",
"password": "" "password": "",
"prefix": ""
}, },
"Email":{ "Email":{
"server": "", "server": "",
@@ -95,29 +96,13 @@ class DashboardConfig:
if not exist: if not exist:
self.SetConfig(section, key, value, True) self.SetConfig(section, key, value, True)
self.engine = db.create_engine(ConnectionString('wgdashboard')) self.engine = db.create_engine(ConnectionString(DEFAULT_DB))
self.dbMetadata = db.MetaData() self.dbMetadata = db.MetaData()
self.__createAPIKeyTable() self.__createAPIKeyTable()
self.DashboardAPIKeys = self.__getAPIKeys() self.DashboardAPIKeys = self.__getAPIKeys()
self.APIAccessed = False self.APIAccessed = False
self.SetConfig("Server", "version", DashboardConfig.DashboardVersion) self.SetConfig("Server", "version", DashboardConfig.DashboardVersion)
def getConnectionString(self, database) -> str or None:
sqlitePath = os.path.join(DashboardConfig.ConfigurationPath, "db")
if not os.path.isdir(sqlitePath):
os.mkdir(sqlitePath)
if self.GetConfig("Database", "type")[1] == "postgresql":
cn = f'postgresql+psycopg2://{self.GetConfig("Database", "username")[1]}:{self.GetConfig("Database", "password")[1]}@{self.GetConfig("Database", "host")[1]}/{database}'
elif self.GetConfig("Database", "type")[1] == "mysql":
cn = f'mysql+mysqldb://{self.GetConfig("Database", "username")[1]}:{self.GetConfig("Database", "password")[1]}@{self.GetConfig("Database", "host")[1]}/{database}'
else:
cn = f'sqlite:///{os.path.join(sqlitePath, f"{database}.db")}'
if not database_exists(cn):
create_database(cn)
return cn
def __createAPIKeyTable(self): def __createAPIKeyTable(self):
self.apiKeyTable = db.Table('DashboardAPIKeys', self.dbMetadata, self.apiKeyTable = db.Table('DashboardAPIKeys', self.dbMetadata,
db.Column("Key", db.String(255), nullable=False, primary_key=True), db.Column("Key", db.String(255), nullable=False, primary_key=True),

View File

@@ -4,18 +4,18 @@ Dashboard Logger Class
import uuid import uuid
import sqlalchemy as db import sqlalchemy as db
from flask import current_app from flask import current_app
from .ConnectionString import ConnectionString from .ConnectionString import ConnectionString, DEFAULT_DB, DEFAULT_LOG_DB
class DashboardLogger: class DashboardLogger:
def __init__(self): def __init__(self):
self.engine = db.create_engine(ConnectionString("wgdashboard_log")) self.engine = db.create_engine(ConnectionString(DEFAULT_LOG_DB))
self.metadata = db.MetaData() self.metadata = db.MetaData()
self.dashboardLoggerTable = db.Table('DashboardLog', self.metadata, self.dashboardLoggerTable = db.Table('DashboardLog', self.metadata,
db.Column('LogID', db.String(255), nullable=False, primary_key=True), db.Column('LogID', db.String(255), nullable=False, primary_key=True),
db.Column('LogDate', db.Column('LogDate',
(db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP), (db.DATETIME if 'sqlite:///' in ConnectionString(DEFAULT_DB) else db.TIMESTAMP),
server_default=db.func.now()), server_default=db.func.now()),
db.Column('URL', db.String(255)), db.Column('URL', db.String(255)),
db.Column('IP', db.String(255)), db.Column('IP', db.String(255)),

View File

@@ -8,7 +8,7 @@ from datetime import datetime, timedelta
import requests import requests
from pydantic import BaseModel, field_serializer from pydantic import BaseModel, field_serializer
import sqlalchemy as db import sqlalchemy as db
from .ConnectionString import ConnectionString from .ConnectionString import ConnectionString, DEFAULT_DB
from flask import current_app from flask import current_app
WebHookActions = ['peer_created', 'peer_deleted', 'peer_updated'] WebHookActions = ['peer_created', 'peer_deleted', 'peer_updated']
@@ -40,7 +40,7 @@ class WebHookSessionLogs(BaseModel):
class DashboardWebHooks: class DashboardWebHooks:
def __init__(self, DashboardConfig): def __init__(self, DashboardConfig):
self.engine = db.create_engine(ConnectionString("wgdashboard")) self.engine = db.create_engine(ConnectionString(DEFAULT_DB))
self.metadata = db.MetaData() self.metadata = db.MetaData()
self.webHooksTable = db.Table( self.webHooksTable = db.Table(
'DashboardWebHooks', self.metadata, 'DashboardWebHooks', self.metadata,
@@ -201,7 +201,7 @@ class DashboardWebHooks:
class WebHookSession: class WebHookSession:
def __init__(self, webHook: WebHook, data: dict[str, str]): def __init__(self, webHook: WebHook, data: dict[str, str]):
self.engine = db.create_engine(ConnectionString("wgdashboard")) self.engine = db.create_engine(ConnectionString(DEFAULT_DB))
self.metadata = db.MetaData() self.metadata = db.MetaData()
self.webHookSessionsTable = db.Table('DashboardWebHookSessions', self.metadata, autoload_with=self.engine) self.webHookSessionsTable = db.Table('DashboardWebHookSessions', self.metadata, autoload_with=self.engine)
self.webHook = webHook self.webHook = webHook

View File

@@ -2,7 +2,7 @@ import uuid
from pydantic import BaseModel, field_serializer from pydantic import BaseModel, field_serializer
import sqlalchemy as db import sqlalchemy as db
from .ConnectionString import ConnectionString from .ConnectionString import ConnectionString, DEFAULT_DB
class NewConfigurationTemplate(BaseModel): class NewConfigurationTemplate(BaseModel):
@@ -14,7 +14,7 @@ class NewConfigurationTemplate(BaseModel):
class NewConfigurationTemplates: class NewConfigurationTemplates:
def __init__(self): def __init__(self):
self.engine = db.create_engine(ConnectionString("wgdashboard")) self.engine = db.create_engine(ConnectionString(DEFAULT_DB))
self.metadata = db.MetaData() self.metadata = db.MetaData()
self.templatesTable = db.Table( self.templatesTable = db.Table(
'NewConfigurationTemplates', self.metadata, 'NewConfigurationTemplates', self.metadata,

View File

@@ -4,12 +4,12 @@ Peer Job Logger
import uuid import uuid
import sqlalchemy as db import sqlalchemy as db
from flask import current_app from flask import current_app
from .ConnectionString import ConnectionString from .ConnectionString import ConnectionString, DEFAULT_LOG_DB
from .Log import Log from .Log import Log
class PeerJobLogger: class PeerJobLogger:
def __init__(self, AllPeerJobs, DashboardConfig): def __init__(self, AllPeerJobs, DashboardConfig):
self.engine = db.create_engine(ConnectionString("wgdashboard_log")) self.engine = db.create_engine(ConnectionString(DEFAULT_LOG_DB))
self.metadata = db.MetaData() self.metadata = db.MetaData()
self.jobLogTable = db.Table('JobLog', self.metadata, self.jobLogTable = db.Table('JobLog', self.metadata,
db.Column('LogID', db.String(255), nullable=False, primary_key=True), db.Column('LogID', db.String(255), nullable=False, primary_key=True),

View File

@@ -1,7 +1,7 @@
""" """
Peer Jobs Peer Jobs
""" """
from .ConnectionString import ConnectionString from .ConnectionString import ConnectionString, DEFAULT_JOB_DB
from .PeerJob import PeerJob from .PeerJob import PeerJob
from .PeerJobLogger import PeerJobLogger from .PeerJobLogger import PeerJobLogger
import sqlalchemy as db import sqlalchemy as db
@@ -11,7 +11,7 @@ from flask import current_app
class PeerJobs: class PeerJobs:
def __init__(self, DashboardConfig, WireguardConfigurations): def __init__(self, DashboardConfig, WireguardConfigurations):
self.Jobs: list[PeerJob] = [] self.Jobs: list[PeerJob] = []
self.engine = db.create_engine(ConnectionString('wgdashboard_job')) self.engine = db.create_engine(ConnectionString(DEFAULT_JOB_DB))
self.metadata = db.MetaData() self.metadata = db.MetaData()
self.peerJobTable = db.Table('PeerJobs', self.metadata, self.peerJobTable = db.Table('PeerJobs', self.metadata,
db.Column('JobID', db.String(255), nullable=False, primary_key=True), db.Column('JobID', db.String(255), nullable=False, primary_key=True),

View File

@@ -1,4 +1,4 @@
from .ConnectionString import ConnectionString from .ConnectionString import ConnectionString, DEFAULT_DB
from .PeerShareLink import PeerShareLink from .PeerShareLink import PeerShareLink
import sqlalchemy as db import sqlalchemy as db
from datetime import datetime from datetime import datetime
@@ -10,7 +10,7 @@ Peer Share Links
class PeerShareLinks: class PeerShareLinks:
def __init__(self, DashboardConfig, WireguardConfigurations): def __init__(self, DashboardConfig, WireguardConfigurations):
self.Links: list[PeerShareLink] = [] self.Links: list[PeerShareLink] = []
self.engine = db.create_engine(ConnectionString("wgdashboard")) self.engine = db.create_engine(ConnectionString(DEFAULT_DB))
self.metadata = db.MetaData() self.metadata = db.MetaData()
self.peerShareLinksTable = db.Table( self.peerShareLinksTable = db.Table(
'PeerShareLinks', self.metadata, 'PeerShareLinks', self.metadata,

View File

@@ -10,7 +10,7 @@ from datetime import datetime, timedelta
from itertools import islice from itertools import islice
from flask import current_app from flask import current_app
from .ConnectionString import ConnectionString from .ConnectionString import ConnectionString, DEFAULT_DB
from .DashboardConfig import DashboardConfig from .DashboardConfig import DashboardConfig
from .Peer import Peer from .Peer import Peer
from .PeerJobs import PeerJobs from .PeerJobs import PeerJobs
@@ -64,7 +64,7 @@ class WireguardConfiguration:
self.AllPeerShareLinks = AllPeerShareLinks self.AllPeerShareLinks = AllPeerShareLinks
self.DashboardWebHooks = DashboardWebHooks self.DashboardWebHooks = DashboardWebHooks
self.configPath = os.path.join(self.__getProtocolPath(), f'{self.Name}.conf') self.configPath = os.path.join(self.__getProtocolPath(), f'{self.Name}.conf')
self.engine: sqlalchemy.Engine = sqlalchemy.create_engine(ConnectionString("wgdashboard")) self.engine: sqlalchemy.Engine = sqlalchemy.create_engine(ConnectionString(DEFAULT_DB))
self.metadata: sqlalchemy.MetaData = sqlalchemy.MetaData() self.metadata: sqlalchemy.MetaData = sqlalchemy.MetaData()
self.dbType = self.DashboardConfig.GetConfig("Database", "type")[1] self.dbType = self.DashboardConfig.GetConfig("Database", "type")[1]