Files
WGDashboard/src/modules/Peer.py
2026-04-02 11:34:50 +02:00

395 lines
16 KiB
Python

"""
Peer
"""
import base64
import datetime
import json
import os
import random
import re
import subprocess
import tempfile
import uuid
from datetime import timedelta

import jinja2
import sqlalchemy as db
from flask import current_app

from .PeerJob import PeerJob
from .PeerShareLink import PeerShareLink
from .Utilities import GenerateWireguardPublicKey, CheckAddress, ValidateDNSAddress
class Peer:
def __init__(self, tableData, configuration):
self.configuration = configuration
self.id = tableData["id"]
self.private_key = tableData["private_key"]
self.DNS = tableData["DNS"]
self.endpoint_allowed_ip = tableData["endpoint_allowed_ip"]
self.name = tableData["name"]
self.total_receive = tableData["total_receive"]
self.total_sent = tableData["total_sent"]
self.total_data = tableData["total_data"]
self.endpoint = tableData["endpoint"]
self.status = tableData["status"]
self.latest_handshake = tableData["latest_handshake"]
self.allowed_ip = tableData["allowed_ip"]
self.cumu_receive = tableData["cumu_receive"]
self.cumu_sent = tableData["cumu_sent"]
self.cumu_data = tableData["cumu_data"]
self.mtu = tableData["mtu"]
self.keepalive = tableData["keepalive"]
self.notes = tableData.get("notes", "")
self.remote_endpoint = tableData["remote_endpoint"]
self.preshared_key = tableData["preshared_key"]
self.jobs: list[PeerJob] = []
self.ShareLink: list[PeerShareLink] = []
self.getJobs()
self.getShareLink()
def toJson(self):
# self.getJobs()
# self.getShareLink()
return self.__dict__
def __repr__(self):
return str(self.toJson())
def updatePeer(self, name: str,
private_key: str,
preshared_key: str,
dns_addresses: str,
allowed_ip: str,
endpoint_allowed_ip: str,
mtu: int,
keepalive: int,
notes: str
) -> tuple[bool, str | None]:
if not self.configuration.getStatus():
self.configuration.toggleConfiguration()
# Before we do any compute, let us check if the given endpoint allowed ip is valid at all
if not CheckAddress(endpoint_allowed_ip):
return False, f"Endpoint Allowed IPs format is incorrect"
peers = []
for peer in self.configuration.getPeersList():
# Make sure to exclude your own data when updating since its not really relevant
if peer.id == self.id:
continue
peers.append(peer)
used_allowed_ips = []
for peer in peers:
ips = peer.allowed_ip.split(',')
ips = [ip.strip() for ip in ips]
used_allowed_ips.append(ips)
if allowed_ip in used_allowed_ips:
return False, "Allowed IP already taken by another peer"
if not ValidateDNSAddress(dns_addresses):
return False, f"DNS IP-Address or FQDN is incorrect"
if isinstance(mtu, str):
mtu = 0
if isinstance(keepalive, str):
keepalive = 0
if mtu not in range(0, 1461):
return False, "MTU format is not correct"
if keepalive < 0:
return False, "Persistent Keepalive format is not correct"
if len(private_key) > 0:
pubKey = GenerateWireguardPublicKey(private_key)
if not pubKey[0] or pubKey[1] != self.id:
return False, "Private key does not match with the public key"
try:
rand = random.Random()
uid = str(uuid.UUID(int=rand.getrandbits(128), version=4))
psk_exist = len(preshared_key) > 0
if psk_exist:
with open(uid, "w+") as f:
f.write(preshared_key)
newAllowedIPs = allowed_ip.replace(" ", "")
if not CheckAddress(newAllowedIPs):
return False, "Allowed IPs entry format is incorrect"
command = [self.configuration.Protocol, "set", self.configuration.Name, "peer", self.id, "allowed-ips", newAllowedIPs, "preshared-key", uid if psk_exist else "/dev/null"]
updateAllowedIp = subprocess.check_output(command, stderr=subprocess.STDOUT)
if psk_exist: os.remove(uid)
if len(updateAllowedIp.decode().strip("\n")) != 0:
current_app.logger.error("Update peer failed when updating Allowed IPs")
return False, "Internal server error"
command = [f"{self.configuration.Protocol}-quick", "save", self.configuration.Name]
saveConfig = subprocess.check_output(command, stderr=subprocess.STDOUT)
if f"wg showconf {self.configuration.Name}" not in saveConfig.decode().strip('\n'):
current_app.logger.error("Update peer failed when saving the configuration")
return False, "Internal server error"
with self.configuration.engine.begin() as conn:
conn.execute(
self.configuration.peersTable.update().values({
"name": name,
"private_key": private_key,
"DNS": dns_addresses,
"endpoint_allowed_ip": endpoint_allowed_ip,
"mtu": mtu,
"keepalive": keepalive,
"notes": notes,
"preshared_key": preshared_key
}).where(
self.configuration.peersTable.c.id == self.id
)
)
return True, None
except subprocess.CalledProcessError as exc:
current_app.logger.error(f"Subprocess call failed:\n{exc.output.decode("UTF-8")}")
return False, "Internal server error"
def downloadPeer(self) -> dict[str, str]:
final = {
"fileName": "",
"file": ""
}
filename = self.name
if len(filename) == 0:
filename = "UntitledPeer"
filename = "".join(filename.split(' '))
# use previous filtering code if code below is insufficient or faulty
filename = re.sub(r'[.,/?<>\\:*|"]', '', filename).rstrip(". ") # remove special characters
reserved_pattern = r"^(CON|PRN|AUX|NUL|COM[1-9]|LPT[1-9])(\..*)?$" # match com1-9, lpt1-9, con, nul, prn, aux, nul
if re.match(reserved_pattern, filename, re.IGNORECASE):
filename = f"file_{filename}" # prepend "file_" if it matches
for i in filename:
if re.match("^[a-zA-Z0-9_=+.-]$", i):
final["fileName"] += i
interfaceSection = {
"PrivateKey": self.private_key,
"Address": self.allowed_ip,
"MTU": (
self.configuration.configurationInfo.OverridePeerSettings.MTU
if self.configuration.configurationInfo.OverridePeerSettings.MTU else self.mtu
),
"DNS": (
self.configuration.configurationInfo.OverridePeerSettings.DNS
if self.configuration.configurationInfo.OverridePeerSettings.DNS else self.DNS
)
}
if self.configuration.Protocol == "awg":
interfaceSection.update({
"Jc": self.configuration.Jc,
"Jmin": self.configuration.Jmin,
"Jmax": self.configuration.Jmax,
"S1": self.configuration.S1,
"S2": self.configuration.S2,
"S3": self.configuration.S3,
"S4": self.configuration.S4,
"H1": self.configuration.H1,
"H2": self.configuration.H2,
"H3": self.configuration.H3,
"H4": self.configuration.H4,
"I1": self.configuration.I1,
"I2": self.configuration.I2,
"I3": self.configuration.I3,
"I4": self.configuration.I4,
"I5": self.configuration.I5
})
peerSection = {
"PublicKey": self.configuration.PublicKey,
"AllowedIPs": (
self.configuration.configurationInfo.OverridePeerSettings.EndpointAllowedIPs
if self.configuration.configurationInfo.OverridePeerSettings.EndpointAllowedIPs else self.endpoint_allowed_ip
),
"Endpoint": f'{(self.configuration.configurationInfo.OverridePeerSettings.PeerRemoteEndpoint if self.configuration.configurationInfo.OverridePeerSettings.PeerRemoteEndpoint else self.configuration.DashboardConfig.GetConfig("Peers", "remote_endpoint")[1])}:{(self.configuration.configurationInfo.OverridePeerSettings.ListenPort if self.configuration.configurationInfo.OverridePeerSettings.ListenPort else self.configuration.ListenPort)}',
"PersistentKeepalive": (
self.configuration.configurationInfo.OverridePeerSettings.PersistentKeepalive
if self.configuration.configurationInfo.OverridePeerSettings.PersistentKeepalive
else self.keepalive
),
"PresharedKey": self.preshared_key
}
combine = [interfaceSection.items(), peerSection.items()]
for s in range(len(combine)):
if s == 0:
final["file"] += "[Interface]\n"
else:
final["file"] += "\n[Peer]\n"
for (key, val) in combine[s]:
if val is not None and ((type(val) is str and len(val) > 0) or (type(val) is int and val > 0)):
final["file"] += f"{key} = {val}\n"
final["file"] = jinja2.Template(final["file"]).render(configuration=self.configuration)
if self.configuration.Protocol == "awg":
final["amneziaVPN"] = json.dumps({
"containers": [{
"awg": {
"isThirdPartyConfig": True,
"last_config": final['file'],
"port": self.configuration.ListenPort,
"transport_proto": "udp"
},
"container": "amnezia-awg"
}],
"defaultContainer": "amnezia-awg",
"description": self.name,
"hostName": (
self.configuration.configurationInfo.OverridePeerSettings.PeerRemoteEndpoint
if self.configuration.configurationInfo.OverridePeerSettings.PeerRemoteEndpoint
else self.configuration.DashboardConfig.GetConfig("Peers", "remote_endpoint")[1])
})
return final
def getJobs(self):
self.jobs = self.configuration.AllPeerJobs.searchJob(self.configuration.Name, self.id)
def getShareLink(self):
self.ShareLink = self.configuration.AllPeerShareLinks.getLink(self.configuration.Name, self.id)
def resetDataUsage(self, mode: str):
try:
with self.configuration.engine.begin() as conn:
if mode == "total":
conn.execute(
self.configuration.peersTable.update().values({
"total_data": 0,
"cumu_data": 0,
"total_receive": 0,
"cumu_receive": 0,
"total_sent": 0,
"cumu_sent": 0
}).where(
self.configuration.peersTable.c.id == self.id
)
)
self.total_data = 0
self.total_receive = 0
self.total_sent = 0
self.cumu_data = 0
self.cumu_sent = 0
self.cumu_receive = 0
elif mode == "receive":
conn.execute(
self.configuration.peersTable.update().values({
"total_receive": 0,
"cumu_receive": 0,
}).where(
self.configuration.peersTable.c.id == self.id
)
)
self.cumu_receive = 0
self.total_receive = 0
elif mode == "sent":
conn.execute(
self.configuration.peersTable.update().values({
"total_sent": 0,
"cumu_sent": 0
}).where(
self.configuration.peersTable.c.id == self.id
)
)
self.cumu_sent = 0
self.total_sent = 0
else:
return False
except Exception as e:
print(e)
return False
return True
def getEndpoints(self):
result = []
with self.configuration.engine.connect() as conn:
result = conn.execute(
db.select(
self.configuration.peersHistoryEndpointTable.c.endpoint
).group_by(
self.configuration.peersHistoryEndpointTable.c.endpoint
).where(
self.configuration.peersHistoryEndpointTable.c.id == self.id
)
).mappings().fetchall()
return list(result)
def getTraffics(self, interval: int = 30, startDate: datetime.datetime = None, endDate: datetime.datetime = None):
if startDate is None and endDate is None:
endDate = datetime.datetime.now()
startDate = endDate - timedelta(minutes=interval)
else:
endDate = endDate.replace(hour=23, minute=59, second=59, microsecond=999999)
startDate = startDate.replace(hour=0, minute=0, second=0, microsecond=0)
with self.configuration.engine.connect() as conn:
result = conn.execute(
db.select(
self.configuration.peersTransferTable.c.cumu_data,
self.configuration.peersTransferTable.c.total_data,
self.configuration.peersTransferTable.c.cumu_receive,
self.configuration.peersTransferTable.c.total_receive,
self.configuration.peersTransferTable.c.cumu_sent,
self.configuration.peersTransferTable.c.total_sent,
self.configuration.peersTransferTable.c.time
).where(
db.and_(
self.configuration.peersTransferTable.c.id == self.id,
self.configuration.peersTransferTable.c.time <= endDate,
self.configuration.peersTransferTable.c.time >= startDate,
)
).order_by(
self.configuration.peersTransferTable.c.time
)
).mappings().fetchall()
return list(result)
def getSessions(self, startDate: datetime.datetime = None, endDate: datetime.datetime = None):
if endDate is None:
endDate = datetime.datetime.now()
if startDate is None:
startDate = endDate
endDate = endDate.replace(hour=23, minute=59, second=59, microsecond=999999)
startDate = startDate.replace(hour=0, minute=0, second=0, microsecond=0)
with self.configuration.engine.connect() as conn:
result = conn.execute(
db.select(
self.configuration.peersTransferTable.c.time
).where(
db.and_(
self.configuration.peersTransferTable.c.id == self.id,
self.configuration.peersTransferTable.c.time <= endDate,
self.configuration.peersTransferTable.c.time >= startDate,
)
).order_by(
self.configuration.peersTransferTable.c.time
)
).fetchall()
time = list(map(lambda x : x[0], result))
return time
def __duration(self, t1: datetime.datetime, t2: datetime.datetime):
delta = t1 - t2
hours, remainder = divmod(delta.total_seconds(), 3600)
minutes, seconds = divmod(remainder, 60)
return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}"