mirror of
https://github.com/donaldzou/WGDashboard.git
synced 2025-10-03 15:56:17 +00:00
Update dashboard.py
- Updated `DashboardConfig` class to use SqlAlchemy, tested with SQLite and Postgresql
This commit is contained in:
105
src/dashboard.py
105
src/dashboard.py
@@ -25,7 +25,10 @@ from modules.PeerJob import PeerJob
|
|||||||
from modules.SystemStatus import SystemStatus
|
from modules.SystemStatus import SystemStatus
|
||||||
SystemStatus = SystemStatus()
|
SystemStatus = SystemStatus()
|
||||||
|
|
||||||
DASHBOARD_VERSION = 'v4.2.5'
|
from sqlalchemy_utils import database_exists, create_database
|
||||||
|
import sqlalchemy as db
|
||||||
|
|
||||||
|
DASHBOARD_VERSION = 'v4.2.3'
|
||||||
|
|
||||||
CONFIGURATION_PATH = os.getenv('CONFIGURATION_PATH', '.')
|
CONFIGURATION_PATH = os.getenv('CONFIGURATION_PATH', '.')
|
||||||
DB_PATH = os.path.join(CONFIGURATION_PATH, 'db')
|
DB_PATH = os.path.join(CONFIGURATION_PATH, 'db')
|
||||||
@@ -180,7 +183,6 @@ class PeerJobs:
|
|||||||
|
|
||||||
def runJob(self):
|
def runJob(self):
|
||||||
needToDelete = []
|
needToDelete = []
|
||||||
self.__getJobs()
|
|
||||||
for job in self.Jobs:
|
for job in self.Jobs:
|
||||||
c = WireguardConfigurations.get(job.Configuration)
|
c = WireguardConfigurations.get(job.Configuration)
|
||||||
if c is not None:
|
if c is not None:
|
||||||
@@ -433,6 +435,7 @@ class WireguardConfiguration:
|
|||||||
original = [l.rstrip("\n") for l in f.readlines()]
|
original = [l.rstrip("\n") for l in f.readlines()]
|
||||||
try:
|
try:
|
||||||
start = original.index("[Interface]")
|
start = original.index("[Interface]")
|
||||||
|
|
||||||
# Clean
|
# Clean
|
||||||
for i in range(start, len(original)):
|
for i in range(start, len(original)):
|
||||||
if original[i] == "[Peer]":
|
if original[i] == "[Peer]":
|
||||||
@@ -459,10 +462,7 @@ class WireguardConfiguration:
|
|||||||
setattr(self, key, StringToBoolean(value))
|
setattr(self, key, StringToBoolean(value))
|
||||||
else:
|
else:
|
||||||
if len(getattr(self, key)) > 0:
|
if len(getattr(self, key)) > 0:
|
||||||
if key not in ["PostUp", "PostDown", "PreUp", "PreDown"]:
|
|
||||||
setattr(self, key, f"{getattr(self, key)}, {value}")
|
setattr(self, key, f"{getattr(self, key)}, {value}")
|
||||||
else:
|
|
||||||
setattr(self, key, f"{getattr(self, key)}; {value}")
|
|
||||||
else:
|
else:
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@@ -1212,15 +1212,15 @@ AmneziaWG Configuration
|
|||||||
"""
|
"""
|
||||||
class AmneziaWireguardConfiguration(WireguardConfiguration):
|
class AmneziaWireguardConfiguration(WireguardConfiguration):
|
||||||
def __init__(self, name: str = None, data: dict = None, backup: dict = None, startup: bool = False):
|
def __init__(self, name: str = None, data: dict = None, backup: dict = None, startup: bool = False):
|
||||||
self.Jc = ""
|
self.Jc = 0
|
||||||
self.Jmin = ""
|
self.Jmin = 0
|
||||||
self.Jmax = ""
|
self.Jmax = 0
|
||||||
self.S1 = ""
|
self.S1 = 0
|
||||||
self.S2 = ""
|
self.S2 = 0
|
||||||
self.H1 = ""
|
self.H1 = 1
|
||||||
self.H2 = ""
|
self.H2 = 2
|
||||||
self.H3 = ""
|
self.H3 = 3
|
||||||
self.H4 = ""
|
self.H4 = 4
|
||||||
|
|
||||||
super().__init__(name, data, backup, startup, wg=False)
|
super().__init__(name, data, backup, startup, wg=False)
|
||||||
|
|
||||||
@@ -1245,7 +1245,6 @@ class AmneziaWireguardConfiguration(WireguardConfiguration):
|
|||||||
},
|
},
|
||||||
"ConnectedPeers": len(list(filter(lambda x: x.status == "running", self.Peers))),
|
"ConnectedPeers": len(list(filter(lambda x: x.status == "running", self.Peers))),
|
||||||
"TotalPeers": len(self.Peers),
|
"TotalPeers": len(self.Peers),
|
||||||
"Table": self.Table,
|
|
||||||
"Protocol": self.Protocol,
|
"Protocol": self.Protocol,
|
||||||
"Jc": self.Jc,
|
"Jc": self.Jc,
|
||||||
"Jmin": self.Jmin,
|
"Jmin": self.Jmin,
|
||||||
@@ -1765,7 +1764,7 @@ class DashboardConfig:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
if not os.path.exists(DASHBOARD_CONF):
|
if not os.path.exists(DASHBOARD_CONF):
|
||||||
open(DASHBOARD_CONF, "x")
|
open(DASHBOARD_CONF, "x")
|
||||||
self.__config = configparser.RawConfigParser(strict=False)
|
self.__config = configparser.ConfigParser(strict=False)
|
||||||
self.__config.read_file(open(DASHBOARD_CONF, "r+"))
|
self.__config.read_file(open(DASHBOARD_CONF, "r+"))
|
||||||
self.hiddenAttribute = ["totp_key", "auth_req"]
|
self.hiddenAttribute = ["totp_key", "auth_req"]
|
||||||
self.__default = {
|
self.__default = {
|
||||||
@@ -1828,39 +1827,80 @@ class DashboardConfig:
|
|||||||
exist, currentData = self.GetConfig(section, key)
|
exist, currentData = self.GetConfig(section, key)
|
||||||
if not exist:
|
if not exist:
|
||||||
self.SetConfig(section, key, value, True)
|
self.SetConfig(section, key, value, True)
|
||||||
|
|
||||||
|
self.engine = db.create_engine(self.getConnectionString('wgdashboard'))
|
||||||
|
self.db = self.engine.connect()
|
||||||
|
self.dbMetadata = db.MetaData()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
self.__createAPIKeyTable()
|
self.__createAPIKeyTable()
|
||||||
self.DashboardAPIKeys = self.__getAPIKeys()
|
self.DashboardAPIKeys = self.__getAPIKeys()
|
||||||
self.APIAccessed = False
|
self.APIAccessed = False
|
||||||
self.SetConfig("Server", "version", DASHBOARD_VERSION)
|
self.SetConfig("Server", "version", DASHBOARD_VERSION)
|
||||||
|
|
||||||
def getConnectionString(self, database) -> str or None:
|
def getConnectionString(self, database) -> str or None:
|
||||||
|
cn = None
|
||||||
if self.GetConfig("Database", "type")[1] == "sqlite":
|
if self.GetConfig("Database", "type")[1] == "sqlite":
|
||||||
return f'sqlite:///{os.path.join(CONFIGURATION_PATH, "db", f"{database}.db")}'
|
cn = f'sqlite:///{os.path.join(CONFIGURATION_PATH, "db", f"{database}.db")}'
|
||||||
elif self.GetConfig("Database", "type")[1] == "postgresql":
|
elif self.GetConfig("Database", "type")[1] == "postgresql":
|
||||||
return f'postgresql+psycopg2://{self.GetConfig("Database", "username")[1]}:{self.GetConfig("Database", "password")[1]}@{self.GetConfig("Database", "host")[1]}/{database}'
|
cn = f'postgresql+psycopg2://{self.GetConfig("Database", "username")[1]}:{self.GetConfig("Database", "password")[1]}@{self.GetConfig("Database", "host")[1]}/{database}'
|
||||||
return None
|
|
||||||
|
if not database_exists(cn):
|
||||||
|
create_database(cn)
|
||||||
|
|
||||||
|
return cn
|
||||||
|
|
||||||
def __createAPIKeyTable(self):
|
def __createAPIKeyTable(self):
|
||||||
existingTable = sqlSelect("SELECT name FROM sqlite_master WHERE type='table' AND name = 'DashboardAPIKeys'").fetchall()
|
self.apiKeyTable = db.Table('DashboardAPIKeys', self.dbMetadata,
|
||||||
if len(existingTable) == 0:
|
db.Column("Key", db.String, nullable=False, primary_key=True),
|
||||||
sqlUpdate("CREATE TABLE DashboardAPIKeys (Key VARCHAR NOT NULL PRIMARY KEY, CreatedAt DATETIME NOT NULL DEFAULT (datetime('now', 'localtime')), ExpiredAt VARCHAR)")
|
db.Column("CreatedAt",
|
||||||
|
(db.DATETIME if self.GetConfig('Database', 'type')[1] == 'sqlite' else db.TIMESTAMP),
|
||||||
|
server_default=db.func.now()
|
||||||
|
),
|
||||||
|
db.Column("ExpiredAt",
|
||||||
|
(db.DATETIME if self.GetConfig('Database', 'type')[1] == 'sqlite' else db.TIMESTAMP)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.dbMetadata.create_all(self.engine)
|
||||||
|
|
||||||
|
# existingTable = sqlSelect("SELECT name FROM sqlite_master WHERE type='table' AND name = 'DashboardAPIKeys'").fetchall()
|
||||||
|
# if len(existingTable) == 0:
|
||||||
|
# sqlUpdate("CREATE TABLE DashboardAPIKeys (Key VARCHAR NOT NULL PRIMARY KEY, CreatedAt DATETIME NOT NULL DEFAULT (datetime('now', 'localtime')), ExpiredAt VARCHAR)")
|
||||||
|
|
||||||
def __getAPIKeys(self) -> list[DashboardAPIKey]:
|
def __getAPIKeys(self) -> list[DashboardAPIKey]:
|
||||||
keys = sqlSelect("SELECT * FROM DashboardAPIKeys WHERE ExpiredAt IS NULL OR ExpiredAt > datetime('now', 'localtime') ORDER BY CreatedAt DESC").fetchall()
|
# keys = sqlSelect("SELECT * FROM DashboardAPIKeys WHERE ExpiredAt IS NULL OR ExpiredAt > datetime('now', 'localtime') ORDER BY CreatedAt DESC").fetchall()
|
||||||
|
keys = self.db.execute(self.apiKeyTable.select().where(
|
||||||
|
db.or_(self.apiKeyTable.columns.ExpiredAt == None, self.apiKeyTable.columns.ExpiredAt > datetime.now())
|
||||||
|
)).fetchall()
|
||||||
fKeys = []
|
fKeys = []
|
||||||
for k in keys:
|
for k in keys:
|
||||||
|
fKeys.append(DashboardAPIKey(k[0], k[1].strftime("%Y-%m-%d %H:%M:%S"), (k[2].strftime("%Y-%m-%d %H:%M:%S") if k[2] else None)))
|
||||||
fKeys.append(DashboardAPIKey(*k))
|
|
||||||
return fKeys
|
return fKeys
|
||||||
|
|
||||||
def createAPIKeys(self, ExpiredAt = None):
|
def createAPIKeys(self, ExpiredAt = None):
|
||||||
newKey = secrets.token_urlsafe(32)
|
newKey = secrets.token_urlsafe(32)
|
||||||
sqlUpdate('INSERT INTO DashboardAPIKeys (Key, ExpiredAt) VALUES (?, ?)', (newKey, ExpiredAt,))
|
# sqlUpdate('INSERT INTO DashboardAPIKeys (Key, ExpiredAt) VALUES (?, ?)', (newKey, ExpiredAt,))
|
||||||
|
with self.engine.begin() as conn:
|
||||||
|
conn.execute(
|
||||||
|
self.apiKeyTable.insert().values({
|
||||||
|
"Key": newKey,
|
||||||
|
"ExpiredAt": ExpiredAt
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
self.DashboardAPIKeys = self.__getAPIKeys()
|
self.DashboardAPIKeys = self.__getAPIKeys()
|
||||||
|
|
||||||
def deleteAPIKey(self, key):
|
def deleteAPIKey(self, key):
|
||||||
sqlUpdate("UPDATE DashboardAPIKeys SET ExpiredAt = datetime('now', 'localtime') WHERE Key = ?", (key, ))
|
# sqlUpdate("UPDATE DashboardAPIKeys SET ExpiredAt = datetime('now', 'localtime') WHERE Key = ?", (key, ))
|
||||||
|
with self.engine.begin() as conn:
|
||||||
|
conn.execute(
|
||||||
|
self.apiKeyTable.update().values({
|
||||||
|
"ExpiredAt": datetime.now(),
|
||||||
|
}).where(self.apiKeyTable.columns.Key == key)
|
||||||
|
)
|
||||||
|
|
||||||
self.DashboardAPIKeys = self.__getAPIKeys()
|
self.DashboardAPIKeys = self.__getAPIKeys()
|
||||||
|
|
||||||
def __configValidation(self, section : str, key: str, value: Any) -> [bool, str]:
|
def __configValidation(self, section : str, key: str, value: Any) -> [bool, str]:
|
||||||
@@ -2062,7 +2102,7 @@ def auth_req():
|
|||||||
else:
|
else:
|
||||||
DashboardConfig.APIAccessed = False
|
DashboardConfig.APIAccessed = False
|
||||||
whiteList = [
|
whiteList = [
|
||||||
'/static/', 'validateAuthentication', 'authenticate',
|
'/static/', 'validateAuthentication', 'authenticate', 'getDashboardConfiguration',
|
||||||
'getDashboardTheme', 'getDashboardVersion', 'sharePeer/get', 'isTotpEnabled', 'locale',
|
'getDashboardTheme', 'getDashboardVersion', 'sharePeer/get', 'isTotpEnabled', 'locale',
|
||||||
'/fileDownload'
|
'/fileDownload'
|
||||||
]
|
]
|
||||||
@@ -2393,7 +2433,7 @@ def API_updateDashboardConfigurationItem():
|
|||||||
valid, msg = DashboardConfig.SetConfig(
|
valid, msg = DashboardConfig.SetConfig(
|
||||||
data["section"], data["key"], data['value'])
|
data["section"], data["key"], data['value'])
|
||||||
if not valid:
|
if not valid:
|
||||||
return ResponseObject(False, msg)
|
return ResponseObject(False, msg, status_code=404)
|
||||||
if data['section'] == "Server":
|
if data['section'] == "Server":
|
||||||
if data['key'] == 'wg_conf_path':
|
if data['key'] == 'wg_conf_path':
|
||||||
WireguardConfigurations.clear()
|
WireguardConfigurations.clear()
|
||||||
@@ -3126,12 +3166,9 @@ def peerInformationBackgroundThread():
|
|||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
while True:
|
while True:
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
try:
|
for c in WireguardConfigurations.values():
|
||||||
curKeys = list(WireguardConfigurations.keys())
|
|
||||||
for name in curKeys:
|
|
||||||
if name in WireguardConfigurations.keys() and WireguardConfigurations.get(name) is not None:
|
|
||||||
c = WireguardConfigurations.get(name)
|
|
||||||
if c.getStatus():
|
if c.getStatus():
|
||||||
|
try:
|
||||||
c.getPeersTransfer()
|
c.getPeersTransfer()
|
||||||
c.getPeersLatestHandshake()
|
c.getPeersLatestHandshake()
|
||||||
c.getPeersEndpoint()
|
c.getPeersEndpoint()
|
||||||
|
Reference in New Issue
Block a user