mirror of
https://github.com/donaldzou/WGDashboard.git
synced 2026-05-03 09:56:20 +00:00
Support for all keys ever used in AmneziaWG has been added. Setting an empty value to a key will prevent it from being used or rendered in the configuration. This ensures flexible support for all versions and specific configurations.
399 lines
16 KiB
Python
399 lines
16 KiB
Python
"""
|
|
Peer
|
|
"""
|
|
import base64
|
|
import datetime
|
|
import json
|
|
import os, subprocess, uuid, random, re
|
|
from datetime import timedelta
|
|
|
|
import jinja2
|
|
import sqlalchemy as db
|
|
from .PeerJob import PeerJob
|
|
from flask import current_app
|
|
from .PeerShareLink import PeerShareLink
|
|
from .Utilities import GenerateWireguardPublicKey, CheckAddress, ValidateDNSAddress
|
|
|
|
|
|
class Peer:
|
|
def __init__(self, tableData, configuration):
    """
    Build a peer from one row of peer data and attach it to its configuration.

    :param tableData: Mapping holding the peer columns. Every key read here
        is looked up with ``tableData[key]`` (so a missing key propagates the
        mapping's lookup error), except ``notes`` which falls back to ``""``.
    :param configuration: Owning configuration object; stored unchanged and
        used by getJobs() and getShareLink(), which are both invoked at the
        end of construction.
    """
    self.configuration = configuration

    # Columns copied verbatim, in the order the attributes are created.
    for column in (
        "id", "private_key", "DNS", "endpoint_allowed_ip", "name",
        "total_receive", "total_sent", "total_data", "endpoint", "status",
        "latest_handshake", "allowed_ip", "cumu_receive", "cumu_sent",
        "cumu_data", "mtu", "keepalive",
    ):
        setattr(self, column, tableData[column])

    # The only optional column.
    self.notes = tableData.get("notes", "")

    for column in ("remote_endpoint", "preshared_key"):
        setattr(self, column, tableData[column])

    # Start empty, then let the two loaders fill them from the configuration.
    self.jobs: list[PeerJob] = []
    self.ShareLink: list[PeerShareLink] = []
    self.getJobs()
    self.getShareLink()
|
|
|
|
def toJson(self):
    """
    Return the peer's attributes as a dictionary.

    The returned object is the instance's own attribute dictionary, not a
    copy, so changes made to it are changes to the peer.
    """
    return vars(self)
|
|
|
|
def __repr__(self):
    """Return the string form of the dictionary produced by toJson()."""
    attributes = self.toJson()
    return str(attributes)
|
|
|
|
def updatePeer(self, name: str,
|
|
private_key: str,
|
|
preshared_key: str,
|
|
dns_addresses: str,
|
|
allowed_ip: str,
|
|
endpoint_allowed_ip: str,
|
|
mtu: int,
|
|
keepalive: int,
|
|
notes: str
|
|
) -> tuple[bool, str | None]:
|
|
|
|
if not self.configuration.getStatus():
|
|
self.configuration.toggleConfiguration()
|
|
|
|
# Before we do any compute, let us check if the given endpoint allowed ip is valid at all
|
|
if not CheckAddress(endpoint_allowed_ip):
|
|
return False, f"Endpoint Allowed IPs format is incorrect"
|
|
|
|
peers = []
|
|
for peer in self.configuration.getPeersList():
|
|
# Make sure to exclude your own data when updating since its not really relevant
|
|
if peer.id == self.id:
|
|
continue
|
|
peers.append(peer)
|
|
|
|
used_allowed_ips = []
|
|
for peer in peers:
|
|
ips = peer.allowed_ip.split(',')
|
|
ips = [ip.strip() for ip in ips]
|
|
used_allowed_ips.append(ips)
|
|
|
|
if allowed_ip in used_allowed_ips:
|
|
return False, "Allowed IP already taken by another peer"
|
|
|
|
if not ValidateDNSAddress(dns_addresses):
|
|
return False, f"DNS IP-Address or FQDN is incorrect"
|
|
|
|
if isinstance(mtu, str):
|
|
mtu = 0
|
|
|
|
if isinstance(keepalive, str):
|
|
keepalive = 0
|
|
|
|
if mtu not in range(0, 1461):
|
|
return False, "MTU format is not correct"
|
|
|
|
if keepalive < 0:
|
|
return False, "Persistent Keepalive format is not correct"
|
|
|
|
if len(private_key) > 0:
|
|
pubKey = GenerateWireguardPublicKey(private_key)
|
|
if not pubKey[0] or pubKey[1] != self.id:
|
|
return False, "Private key does not match with the public key"
|
|
|
|
try:
|
|
rand = random.Random()
|
|
uid = str(uuid.UUID(int=rand.getrandbits(128), version=4))
|
|
psk_exist = len(preshared_key) > 0
|
|
|
|
if psk_exist:
|
|
with open(uid, "w+") as f:
|
|
f.write(preshared_key)
|
|
|
|
newAllowedIPs = allowed_ip.replace(" ", "")
|
|
if not CheckAddress(newAllowedIPs):
|
|
return False, "Allowed IPs entry format is incorrect"
|
|
|
|
command = [self.configuration.Protocol, "set", self.configuration.Name, "peer", self.id, "allowed-ips", newAllowedIPs, "preshared-key", uid if psk_exist else "/dev/null"]
|
|
updateAllowedIp = subprocess.check_output(command, stderr=subprocess.STDOUT)
|
|
|
|
if psk_exist: os.remove(uid)
|
|
|
|
if len(updateAllowedIp.decode().strip("\n")) != 0:
|
|
current_app.logger.error("Update peer failed when updating Allowed IPs")
|
|
return False, "Internal server error"
|
|
|
|
command = [f"{self.configuration.Protocol}-quick", "save", self.configuration.Name]
|
|
saveConfig = subprocess.check_output(command, stderr=subprocess.STDOUT)
|
|
|
|
if f"wg showconf {self.configuration.Name}" not in saveConfig.decode().strip('\n'):
|
|
current_app.logger.error("Update peer failed when saving the configuration")
|
|
return False, "Internal server error"
|
|
|
|
with self.configuration.engine.begin() as conn:
|
|
conn.execute(
|
|
self.configuration.peersTable.update().values({
|
|
"name": name,
|
|
"private_key": private_key,
|
|
"DNS": dns_addresses,
|
|
"endpoint_allowed_ip": endpoint_allowed_ip,
|
|
"mtu": mtu,
|
|
"keepalive": keepalive,
|
|
"notes": notes,
|
|
"preshared_key": preshared_key
|
|
}).where(
|
|
self.configuration.peersTable.c.id == self.id
|
|
)
|
|
)
|
|
return True, None
|
|
except subprocess.CalledProcessError as exc:
|
|
current_app.logger.error(f"Subprocess call failed:\n{exc.output.decode('UTF-8')}")
|
|
return False, "Internal server error"
|
|
|
|
def downloadPeer(self) -> dict[str, str]:
    """
    Build the downloadable client configuration for this peer.

    Values set in the configuration's ``OverridePeerSettings`` take
    precedence over the peer's own values. Entries whose value is not a
    non-empty string or a positive integer are left out of the file.

    :return: dict with ``fileName`` (sanitised peer name, never empty) and
        ``file`` (the ``[Interface]`` / ``[Peer]`` text after being rendered
        through jinja2 with ``configuration`` in scope). When the protocol
        is ``awg`` an ``amneziaVPN`` key holding a JSON document is added.
    """
    configuration = self.configuration
    override = configuration.configurationInfo.OverridePeerSettings

    # ---- file name ----
    filename = self.name or "UntitledPeer"
    filename = "".join(filename.split(' '))
    filename = re.sub(r'[.,/?<>\\:*|"]', '', filename).rstrip(". ")  # remove special characters

    # match com1-9, lpt1-9, con, nul, prn, aux
    reserved_pattern = r"^(CON|PRN|AUX|NUL|COM[1-9]|LPT[1-9])(\..*)?$"
    if re.match(reserved_pattern, filename, re.IGNORECASE):
        filename = f"file_{filename}"  # prepend "file_" if it matches

    # Keep only the allowed characters; a name made solely of other
    # characters (e.g. non-ASCII) must not produce an empty file name.
    filename = re.sub(r"[^a-zA-Z0-9_=+.-]", "", filename)
    if not filename:
        filename = "UntitledPeer"

    # ---- sections ----
    interfaceSection = {
        "PrivateKey": self.private_key,
        "Address": self.allowed_ip,
        "MTU": override.MTU or self.mtu,
        "DNS": override.DNS or self.DNS,
    }

    if configuration.Protocol == "awg":
        for key in ("Jc", "Jmin", "Jmax",
                    "S1", "S2", "S3", "S4",
                    "H1", "H2", "H3", "H4",
                    "I1", "I2", "I3", "I4", "I5",
                    "J1", "J2", "J3", "Itime"):
            interfaceSection[key] = getattr(configuration, key)

    # The dashboard-wide remote endpoint is only looked up when no override is set.
    remote_host = (
        override.PeerRemoteEndpoint
        or configuration.DashboardConfig.GetConfig("Peers", "remote_endpoint")[1]
    )
    listen_port = override.ListenPort or configuration.ListenPort

    peerSection = {
        "PublicKey": configuration.PublicKey,
        "AllowedIPs": override.EndpointAllowedIPs or self.endpoint_allowed_ip,
        "Endpoint": f"{remote_host}:{listen_port}",
        "PersistentKeepalive": override.PersistentKeepalive or self.keepalive,
        "PresharedKey": self.preshared_key,
    }

    def renderable(val) -> bool:
        # Exact type checks on purpose: booleans, None and other types are skipped.
        return (type(val) is str and len(val) > 0) or (type(val) is int and val > 0)

    parts = []
    for header, section in (("[Interface]\n", interfaceSection), ("\n[Peer]\n", peerSection)):
        parts.append(header)
        parts.extend(f"{key} = {val}\n" for key, val in section.items() if renderable(val))

    # NOTE(review): the assembled text, including stored peer fields, is
    # evaluated as a Jinja2 template - confirm every field is trusted.
    final = {
        "fileName": filename,
        "file": jinja2.Template("".join(parts)).render(configuration=configuration),
    }

    if configuration.Protocol == "awg":
        final["amneziaVPN"] = json.dumps({
            "containers": [{
                "awg": {
                    "isThirdPartyConfig": True,
                    "last_config": final['file'],
                    "port": configuration.ListenPort,
                    "transport_proto": "udp"
                },
                "container": "amnezia-awg"
            }],
            "defaultContainer": "amnezia-awg",
            "description": self.name,
            "hostName": remote_host
        })
    return final
|
|
|
|
def getJobs(self):
    """Reload ``self.jobs`` by searching the configuration's job store for this peer."""
    owner = self.configuration
    self.jobs = owner.AllPeerJobs.searchJob(owner.Name, self.id)
|
|
|
|
def getShareLink(self):
    """Reload ``self.ShareLink`` from the configuration's share-link store for this peer."""
    owner = self.configuration
    self.ShareLink = owner.AllPeerShareLinks.getLink(owner.Name, self.id)
|
|
|
|
def resetDataUsage(self, mode: str):
    """
    Zero this peer's usage counters in the peers table and on the instance.

    :param mode: ``"total"`` resets every counter, ``"receive"`` only the
        receive counters, ``"sent"`` only the sent counters.
    :return: True when the counters were reset; False for any other mode or
        when an exception occurred (the exception is printed).
    """
    counters_by_mode = {
        "total": ("total_data", "cumu_data", "total_receive",
                  "cumu_receive", "total_sent", "cumu_sent"),
        "receive": ("total_receive", "cumu_receive"),
        "sent": ("total_sent", "cumu_sent"),
    }
    try:
        with self.configuration.engine.begin() as conn:
            counters = counters_by_mode.get(mode)
            if counters is None:
                return False
            conn.execute(
                self.configuration.peersTable.update().values(
                    {column: 0 for column in counters}
                ).where(
                    self.configuration.peersTable.c.id == self.id
                )
            )
            # Mirror the database change on the in-memory peer.
            for column in counters:
                setattr(self, column, 0)
    except Exception as e:
        print(e)
        return False
    return True
|
|
|
|
def getEndpoints(self):
    """
    Return the endpoints recorded for this peer in the endpoint history
    table, grouped so each endpoint appears once.

    :return: list of the row mappings fetched (one ``endpoint`` column each).
    """
    history = self.configuration.peersHistoryEndpointTable
    with self.configuration.engine.connect() as conn:
        query = db.select(
            history.c.endpoint
        ).group_by(
            history.c.endpoint
        ).where(
            history.c.id == self.id
        )
        rows = conn.execute(query).mappings().fetchall()
    return list(rows)
|
|
|
|
def getTraffics(self, interval: int = 30, startDate: datetime.datetime = None, endDate: datetime.datetime = None):
    """
    Return this peer's transfer records within a time window, oldest first.

    :param interval: Window length in minutes, counted back from now. Only
        used when neither ``startDate`` nor ``endDate`` is given.
    :param startDate: First day of the window; expanded to 00:00:00.
        Defaults to the day of ``endDate`` when only ``endDate`` is given.
    :param endDate: Last day of the window; expanded to 23:59:59.999999.
        Defaults to today when only ``startDate`` is given.
    :return: list of row mappings with the cumu_/total_ data, receive and
        sent columns plus ``time``.
    """
    if startDate is None and endDate is None:
        endDate = datetime.datetime.now()
        startDate = endDate - timedelta(minutes=interval)
    else:
        # Only one bound supplied: fill in the missing one the same way
        # getSessions does instead of failing on None.replace(...).
        if endDate is None:
            endDate = datetime.datetime.now()
        if startDate is None:
            startDate = endDate
        endDate = endDate.replace(hour=23, minute=59, second=59, microsecond=999999)
        startDate = startDate.replace(hour=0, minute=0, second=0, microsecond=0)

    transfers = self.configuration.peersTransferTable
    with self.configuration.engine.connect() as conn:
        result = conn.execute(
            db.select(
                transfers.c.cumu_data,
                transfers.c.total_data,
                transfers.c.cumu_receive,
                transfers.c.total_receive,
                transfers.c.cumu_sent,
                transfers.c.total_sent,
                transfers.c.time
            ).where(
                db.and_(
                    transfers.c.id == self.id,
                    transfers.c.time <= endDate,
                    transfers.c.time >= startDate,
                )
            ).order_by(
                transfers.c.time
            )
        ).mappings().fetchall()
    return list(result)
|
|
|
|
|
|
def getSessions(self, startDate: datetime.datetime = None, endDate: datetime.datetime = None):
    """
    Return the timestamps of this peer's transfer records between two days,
    oldest first.

    :param startDate: First day of the window (expanded to 00:00:00);
        defaults to the day of ``endDate``.
    :param endDate: Last day of the window (expanded to 23:59:59.999999);
        defaults to now.
    :return: list holding the ``time`` value of every matching row.
    """
    window_end = datetime.datetime.now() if endDate is None else endDate
    window_start = window_end if startDate is None else startDate

    window_end = window_end.replace(hour=23, minute=59, second=59, microsecond=999999)
    window_start = window_start.replace(hour=0, minute=0, second=0, microsecond=0)

    transfers = self.configuration.peersTransferTable
    with self.configuration.engine.connect() as conn:
        rows = conn.execute(
            db.select(
                transfers.c.time
            ).where(
                db.and_(
                    transfers.c.id == self.id,
                    transfers.c.time <= window_end,
                    transfers.c.time >= window_start,
                )
            ).order_by(
                transfers.c.time
            )
        ).fetchall()
    # Each row carries a single column; unwrap it.
    return [row[0] for row in rows]
|
|
|
|
def __duration(self, t1: datetime.datetime, t2: datetime.datetime):
    """
    Format the time elapsed from ``t2`` to ``t1`` as ``HH:MM:SS``.

    Hours are not wrapped at 24, so a span longer than a day keeps
    accumulating in the hours field.
    """
    elapsed_seconds = (t1 - t2).total_seconds()
    hours, leftover = divmod(elapsed_seconds, 3600)
    minutes, seconds = divmod(leftover, 60)
    return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}"
|