mirror of
https://github.com/donaldzou/WGDashboard.git
synced 2026-04-20 03:36:17 +00:00
refac: some WGDashboard code
This commit is contained in:
@@ -7,14 +7,10 @@ import sqlalchemy as db
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from flask import current_app
|
||||
from .ConnectionString import ConnectionString
|
||||
from .Utilities import (
|
||||
GetRemoteEndpoint, ValidateDNSAddress
|
||||
)
|
||||
from .DatabaseConnection import ConnectionString
|
||||
from .Utilities import (GetRemoteEndpoint, ValidateDNSAddress)
|
||||
from .DashboardAPIKey import DashboardAPIKey
|
||||
|
||||
|
||||
|
||||
class DashboardConfig:
|
||||
DashboardVersion = 'v4.3.2'
|
||||
ConfigurationPath = os.getenv('CONFIGURATION_PATH', '.')
|
||||
@@ -104,6 +100,52 @@ class DashboardConfig:
|
||||
self.APIAccessed = False
|
||||
self.SetConfig("Server", "version", DashboardConfig.DashboardVersion)
|
||||
|
||||
def EnsureDatabaseIntegrity(self, wireguardConfigurations):
|
||||
expected_columns = {
|
||||
'id': db.String(255),
|
||||
'private_key': db.String(255),
|
||||
'DNS': db.Text,
|
||||
'endpoint_allowed_ip': db.Text,
|
||||
'name': db.Text,
|
||||
'total_receive': db.Float,
|
||||
'total_sent': db.Float,
|
||||
'total_data': db.Float,
|
||||
'endpoint': db.String(255),
|
||||
'status': db.String(255),
|
||||
'latest_handshake': db.String(255),
|
||||
'allowed_ip': db.String(255),
|
||||
'cumu_receive': db.Float,
|
||||
'cumu_sent': db.Float,
|
||||
'cumu_data': db.Float,
|
||||
'mtu': db.Integer,
|
||||
'keepalive': db.Integer,
|
||||
'notes': db.Text,
|
||||
'remote_endpoint': db.String(255),
|
||||
'preshared_key': db.String(255)
|
||||
}
|
||||
|
||||
inspector = db.inspect(self.engine)
|
||||
|
||||
with self.engine.begin() as conn:
|
||||
for cfg_name, cfg_obj in wireguardConfigurations.items():
|
||||
tables_to_check = [
|
||||
cfg_name,
|
||||
f'{cfg_name}_restrict_access',
|
||||
f'{cfg_name}_deleted'
|
||||
]
|
||||
|
||||
for table_name in tables_to_check:
|
||||
if not inspector.has_table(table_name):
|
||||
continue
|
||||
|
||||
existing_columns = [c['name'] for c in inspector.get_columns(table_name)]
|
||||
|
||||
for col_name, col_type in expected_columns.items():
|
||||
if col_name not in existing_columns:
|
||||
type_str = col_type().compile(dialect=self.engine.dialect)
|
||||
current_app.logger.info(f"Adding missing column '{col_name}' to table '{table_name}'")
|
||||
conn.execute(db.text(f'ALTER TABLE "{table_name}" ADD COLUMN "{col_name}" {type_str}'))
|
||||
|
||||
def getConnectionString(self, database) -> str or None:
|
||||
sqlitePath = os.path.join(DashboardConfig.ConfigurationPath, "db")
|
||||
|
||||
@@ -118,7 +160,7 @@ class DashboardConfig:
|
||||
cn = f'sqlite:///{os.path.join(sqlitePath, f"{database}.db")}'
|
||||
if not database_exists(cn):
|
||||
create_database(cn)
|
||||
return cn
|
||||
return cn
|
||||
|
||||
def __createAPIKeyTable(self):
|
||||
self.apiKeyTable = db.Table('DashboardAPIKeys', self.dbMetadata,
|
||||
|
||||
Reference in New Issue
Block a user