Files
WGDashboard/src/modules/Peer.py

395 lines
16 KiB
Python
Raw Normal View History

"""
Peer
"""
2025-09-16 07:57:30 +08:00
import base64
2025-08-29 16:55:33 +08:00
import datetime
2025-09-16 07:57:30 +08:00
import json
import os, subprocess, uuid, random, re
2025-08-29 16:55:33 +08:00
from datetime import timedelta
2025-08-17 18:58:28 +08:00
import jinja2
2025-08-29 16:55:33 +08:00
import sqlalchemy as db
from .PeerJob import PeerJob
from flask import current_app
from .PeerShareLink import PeerShareLink
2026-02-07 03:03:35 +01:00
from .Utilities import GenerateWireguardPublicKey, CheckAddress, ValidateDNSAddress
class Peer:
def __init__(self, tableData, configuration):
self.configuration = configuration
self.id = tableData["id"]
self.private_key = tableData["private_key"]
self.DNS = tableData["DNS"]
self.endpoint_allowed_ip = tableData["endpoint_allowed_ip"]
self.name = tableData["name"]
self.total_receive = tableData["total_receive"]
self.total_sent = tableData["total_sent"]
self.total_data = tableData["total_data"]
self.endpoint = tableData["endpoint"]
self.status = tableData["status"]
self.latest_handshake = tableData["latest_handshake"]
self.allowed_ip = tableData["allowed_ip"]
self.cumu_receive = tableData["cumu_receive"]
self.cumu_sent = tableData["cumu_sent"]
self.cumu_data = tableData["cumu_data"]
self.mtu = tableData["mtu"]
self.keepalive = tableData["keepalive"]
2026-02-07 03:03:35 +01:00
self.notes = tableData.get("notes", "")
self.remote_endpoint = tableData["remote_endpoint"]
self.preshared_key = tableData["preshared_key"]
self.jobs: list[PeerJob] = []
self.ShareLink: list[PeerShareLink] = []
self.getJobs()
self.getShareLink()
def toJson(self):
# self.getJobs()
# self.getShareLink()
return self.__dict__
def __repr__(self):
return str(self.toJson())
2026-02-07 03:03:35 +01:00
def updatePeer(self, name: str,
private_key: str,
preshared_key: str,
2026-02-07 03:03:35 +01:00
dns_addresses: str,
allowed_ip: str,
endpoint_allowed_ip: str,
mtu: int,
keepalive: int,
notes: str
) -> tuple[bool, str | None]:
if not self.configuration.getStatus():
self.configuration.toggleConfiguration()
2026-02-07 03:03:35 +01:00
# Before we do any compute, let us check if the given endpoint allowed ip is valid at all
if not CheckAddress(endpoint_allowed_ip):
return False, f"Endpoint Allowed IPs format is incorrect"
peers = []
for peer in self.configuration.getPeersList():
# Make sure to exclude your own data when updating since its not really relevant
2026-04-02 17:34:02 +08:00
if peer.id == self.id:
2026-02-07 03:03:35 +01:00
continue
peers.append(peer)
used_allowed_ips = []
for peer in peers:
ips = peer.allowed_ip.split(',')
ips = [ip.strip() for ip in ips]
used_allowed_ips.append(ips)
2026-02-07 03:03:35 +01:00
if allowed_ip in used_allowed_ips:
return False, "Allowed IP already taken by another peer"
2026-04-02 11:34:18 +02:00
2026-02-07 03:03:35 +01:00
if not ValidateDNSAddress(dns_addresses):
return False, f"DNS IP-Address or FQDN is incorrect"
2026-04-02 11:34:18 +02:00
2026-02-07 03:03:35 +01:00
if isinstance(mtu, str):
mtu = 0
2026-02-07 03:03:35 +01:00
if isinstance(keepalive, str):
keepalive = 0
2026-04-02 11:34:18 +02:00
2026-02-07 03:03:35 +01:00
if mtu not in range(0, 1461):
return False, "MTU format is not correct"
2026-04-02 11:34:18 +02:00
if keepalive < 0:
return False, "Persistent Keepalive format is not correct"
2026-02-07 03:03:35 +01:00
if len(private_key) > 0:
pubKey = GenerateWireguardPublicKey(private_key)
if not pubKey[0] or pubKey[1] != self.id:
return False, "Private key does not match with the public key"
2026-02-07 03:03:35 +01:00
try:
2026-02-07 03:03:35 +01:00
rand = random.Random()
uid = str(uuid.UUID(int=rand.getrandbits(128), version=4))
psk_exist = len(preshared_key) > 0
2026-02-07 03:03:35 +01:00
if psk_exist:
with open(uid, "w+") as f:
f.write(preshared_key)
2026-02-07 03:03:35 +01:00
newAllowedIPs = allowed_ip.replace(" ", "")
if not CheckAddress(newAllowedIPs):
return False, "Allowed IPs entry format is incorrect"
2026-04-02 11:34:18 +02:00
command = [self.configuration.Protocol, "set", self.configuration.Name, "peer", self.id, "allowed-ips", newAllowedIPs, "preshared-key", uid if psk_exist else "/dev/null"]
updateAllowedIp = subprocess.check_output(command, stderr=subprocess.STDOUT)
2026-02-07 03:03:35 +01:00
if psk_exist: os.remove(uid)
if len(updateAllowedIp.decode().strip("\n")) != 0:
current_app.logger.error("Update peer failed when updating Allowed IPs")
return False, "Internal server error"
command = [f"{self.configuration.Protocol}-quick", "save", self.configuration.Name]
saveConfig = subprocess.check_output(command, stderr=subprocess.STDOUT)
if f"wg showconf {self.configuration.Name}" not in saveConfig.decode().strip('\n'):
current_app.logger.error("Update peer failed when saving the configuration")
return False, "Internal server error"
with self.configuration.engine.begin() as conn:
conn.execute(
self.configuration.peersTable.update().values({
"name": name,
"private_key": private_key,
"DNS": dns_addresses,
"endpoint_allowed_ip": endpoint_allowed_ip,
"mtu": mtu,
"keepalive": keepalive,
2026-02-07 03:03:35 +01:00
"notes": notes,
"preshared_key": preshared_key
}).where(
self.configuration.peersTable.c.id == self.id
)
)
return True, None
except subprocess.CalledProcessError as exc:
current_app.logger.error(f"Subprocess call failed:\n{exc.output.decode("UTF-8")}")
return False, "Internal server error"
def downloadPeer(self) -> dict[str, str]:
2025-09-16 07:57:30 +08:00
final = {
"fileName": "",
"file": ""
}
filename = self.name
if len(filename) == 0:
filename = "UntitledPeer"
filename = "".join(filename.split(' '))
# use previous filtering code if code below is insufficient or faulty
filename = re.sub(r'[.,/?<>\\:*|"]', '', filename).rstrip(". ") # remove special characters
reserved_pattern = r"^(CON|PRN|AUX|NUL|COM[1-9]|LPT[1-9])(\..*)?$" # match com1-9, lpt1-9, con, nul, prn, aux, nul
2026-04-02 11:34:18 +02:00
if re.match(reserved_pattern, filename, re.IGNORECASE):
filename = f"file_{filename}" # prepend "file_" if it matches
for i in filename:
if re.match("^[a-zA-Z0-9_=+.-]$", i):
2025-09-16 07:57:30 +08:00
final["fileName"] += i
2026-04-02 11:34:18 +02:00
interfaceSection = {
"PrivateKey": self.private_key,
"Address": self.allowed_ip,
2025-08-17 18:58:28 +08:00
"MTU": (
self.configuration.configurationInfo.OverridePeerSettings.MTU
if self.configuration.configurationInfo.OverridePeerSettings.MTU else self.mtu
),
2025-08-16 16:54:31 +08:00
"DNS": (
self.configuration.configurationInfo.OverridePeerSettings.DNS
if self.configuration.configurationInfo.OverridePeerSettings.DNS else self.DNS
)
}
2026-04-02 11:34:18 +02:00
2025-09-16 07:46:25 +08:00
if self.configuration.Protocol == "awg":
interfaceSection.update({
"Jc": self.configuration.Jc,
"Jmin": self.configuration.Jmin,
"Jmax": self.configuration.Jmax,
"S1": self.configuration.S1,
"S2": self.configuration.S2,
"S3": self.configuration.S3,
"S4": self.configuration.S4,
2025-09-16 07:46:25 +08:00
"H1": self.configuration.H1,
"H2": self.configuration.H2,
"H3": self.configuration.H3,
"H4": self.configuration.H4,
"I1": self.configuration.I1,
"I2": self.configuration.I2,
"I3": self.configuration.I3,
"I4": self.configuration.I4,
"I5": self.configuration.I5
2025-09-16 07:46:25 +08:00
})
2026-04-02 11:34:18 +02:00
peerSection = {
"PublicKey": self.configuration.PublicKey,
2025-08-17 18:58:28 +08:00
"AllowedIPs": (
self.configuration.configurationInfo.OverridePeerSettings.EndpointAllowedIPs
if self.configuration.configurationInfo.OverridePeerSettings.EndpointAllowedIPs else self.endpoint_allowed_ip
),
"Endpoint": f'{(self.configuration.configurationInfo.OverridePeerSettings.PeerRemoteEndpoint if self.configuration.configurationInfo.OverridePeerSettings.PeerRemoteEndpoint else self.configuration.DashboardConfig.GetConfig("Peers", "remote_endpoint")[1])}:{(self.configuration.configurationInfo.OverridePeerSettings.ListenPort if self.configuration.configurationInfo.OverridePeerSettings.ListenPort else self.configuration.ListenPort)}',
"PersistentKeepalive": (
self.configuration.configurationInfo.OverridePeerSettings.PersistentKeepalive
if self.configuration.configurationInfo.OverridePeerSettings.PersistentKeepalive
else self.keepalive
),
"PresharedKey": self.preshared_key
}
combine = [interfaceSection.items(), peerSection.items()]
for s in range(len(combine)):
if s == 0:
2025-09-16 07:57:30 +08:00
final["file"] += "[Interface]\n"
else:
2025-09-16 07:57:30 +08:00
final["file"] += "\n[Peer]\n"
for (key, val) in combine[s]:
if val is not None and ((type(val) is str and len(val) > 0) or (type(val) is int and val > 0)):
2025-09-16 07:57:30 +08:00
final["file"] += f"{key} = {val}\n"
2026-04-02 11:34:18 +02:00
2025-09-17 16:35:57 +08:00
final["file"] = jinja2.Template(final["file"]).render(configuration=self.configuration)
2025-09-16 07:57:30 +08:00
if self.configuration.Protocol == "awg":
final["amneziaVPN"] = json.dumps({
"containers": [{
"awg": {
"isThirdPartyConfig": True,
"last_config": final['file'],
"port": self.configuration.ListenPort,
"transport_proto": "udp"
},
"container": "amnezia-awg"
}],
"defaultContainer": "amnezia-awg",
"description": self.name,
"hostName": (
self.configuration.configurationInfo.OverridePeerSettings.PeerRemoteEndpoint
if self.configuration.configurationInfo.OverridePeerSettings.PeerRemoteEndpoint
else self.configuration.DashboardConfig.GetConfig("Peers", "remote_endpoint")[1])
})
return final
def getJobs(self):
self.jobs = self.configuration.AllPeerJobs.searchJob(self.configuration.Name, self.id)
def getShareLink(self):
self.ShareLink = self.configuration.AllPeerShareLinks.getLink(self.configuration.Name, self.id)
def resetDataUsage(self, mode: str):
try:
with self.configuration.engine.begin() as conn:
if mode == "total":
conn.execute(
self.configuration.peersTable.update().values({
"total_data": 0,
"cumu_data": 0,
"total_receive": 0,
"cumu_receive": 0,
"total_sent": 0,
"cumu_sent": 0
}).where(
self.configuration.peersTable.c.id == self.id
)
)
self.total_data = 0
self.total_receive = 0
self.total_sent = 0
self.cumu_data = 0
self.cumu_sent = 0
self.cumu_receive = 0
elif mode == "receive":
conn.execute(
self.configuration.peersTable.update().values({
"total_receive": 0,
"cumu_receive": 0,
}).where(
self.configuration.peersTable.c.id == self.id
)
)
self.cumu_receive = 0
self.total_receive = 0
elif mode == "sent":
conn.execute(
self.configuration.peersTable.update().values({
"total_sent": 0,
"cumu_sent": 0
}).where(
self.configuration.peersTable.c.id == self.id
)
)
self.cumu_sent = 0
self.total_sent = 0
else:
return False
except Exception as e:
print(e)
return False
2025-08-29 16:55:33 +08:00
return True
2025-09-07 17:04:22 +08:00
def getEndpoints(self):
result = []
with self.configuration.engine.connect() as conn:
result = conn.execute(
db.select(
self.configuration.peersHistoryEndpointTable.c.endpoint
).group_by(
self.configuration.peersHistoryEndpointTable.c.endpoint
).where(
self.configuration.peersHistoryEndpointTable.c.id == self.id
)
).mappings().fetchall()
return list(result)
2025-09-01 17:16:03 +08:00
def getTraffics(self, interval: int = 30, startDate: datetime.datetime = None, endDate: datetime.datetime = None):
if startDate is None and endDate is None:
endDate = datetime.datetime.now()
startDate = endDate - timedelta(minutes=interval)
else:
endDate = endDate.replace(hour=23, minute=59, second=59, microsecond=999999)
startDate = startDate.replace(hour=0, minute=0, second=0, microsecond=0)
with self.configuration.engine.connect() as conn:
result = conn.execute(
db.select(
self.configuration.peersTransferTable.c.cumu_data,
self.configuration.peersTransferTable.c.total_data,
self.configuration.peersTransferTable.c.cumu_receive,
self.configuration.peersTransferTable.c.total_receive,
self.configuration.peersTransferTable.c.cumu_sent,
self.configuration.peersTransferTable.c.total_sent,
self.configuration.peersTransferTable.c.time
).where(
db.and_(
self.configuration.peersTransferTable.c.id == self.id,
self.configuration.peersTransferTable.c.time <= endDate,
self.configuration.peersTransferTable.c.time >= startDate,
)
).order_by(
self.configuration.peersTransferTable.c.time
)
).mappings().fetchall()
return list(result)
2025-08-29 16:55:33 +08:00
def getSessions(self, startDate: datetime.datetime = None, endDate: datetime.datetime = None):
if endDate is None:
endDate = datetime.datetime.now()
if startDate is None:
2025-09-01 17:16:03 +08:00
startDate = endDate
2025-08-29 16:55:33 +08:00
endDate = endDate.replace(hour=23, minute=59, second=59, microsecond=999999)
startDate = startDate.replace(hour=0, minute=0, second=0, microsecond=0)
2025-09-01 17:16:03 +08:00
2025-08-29 16:55:33 +08:00
with self.configuration.engine.connect() as conn:
result = conn.execute(
db.select(
self.configuration.peersTransferTable.c.time
).where(
db.and_(
self.configuration.peersTransferTable.c.id == self.id,
self.configuration.peersTransferTable.c.time <= endDate,
self.configuration.peersTransferTable.c.time >= startDate,
)
).order_by(
self.configuration.peersTransferTable.c.time
)
).fetchall()
2025-09-01 17:16:03 +08:00
time = list(map(lambda x : x[0], result))
return time
2025-08-29 16:55:33 +08:00
def __duration(self, t1: datetime.datetime, t2: datetime.datetime):
delta = t1 - t2
hours, remainder = divmod(delta.total_seconds(), 3600)
minutes, seconds = divmod(remainder, 60)
return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}"