From aa66a5ffb2f354c35199b934a9ffef61c63392c0 Mon Sep 17 00:00:00 2001 From: Donald Zou Date: Thu, 3 Jul 2025 19:20:01 +0800 Subject: [PATCH] OIDC should be good to go --- src/client.py | 18 +++--- src/modules/DashboardClients.py | 23 +++++++- src/modules/DashboardOIDC.py | 80 +++++++++++++++------------ src/static/client/src/main.js | 11 +++- src/static/client/src/views/index.vue | 30 ++++++++-- 5 files changed, 109 insertions(+), 53 deletions(-) diff --git a/src/client.py b/src/client.py index 8b50021..a22d4c6 100644 --- a/src/client.py +++ b/src/client.py @@ -21,7 +21,7 @@ def ResponseObject(status=True, message=None, data=None, status_code = 200) -> F def login_required(f): @wraps(f) def func(*args, **kwargs): - if session.get("Email") is None or session.get("totpVerified") is None or not session.get("totpVerified") or session.get("role") != "client": + if session.get("Email") is None or session.get("TotpVerified") is None or not session.get("TotpVerified") or session.get("Role") != "client": return ResponseObject(False, "Unauthorized access.", data=None, status_code=401) return f(*args, **kwargs) return func @@ -60,8 +60,8 @@ def createClientBlueprint(wireguardConfigurations: dict[WireguardConfiguration], return ResponseObject(status, oidcData) session['Email'] = oidcData.get('email') - session['role'] = 'client' - session['totpVerified'] = True + session['Role'] = 'client' + session['TotpVerified'] = True return ResponseObject() @@ -71,15 +71,15 @@ def createClientBlueprint(wireguardConfigurations: dict[WireguardConfiguration], status, msg = DashboardClients.SignIn(**data) if status: session['Email'] = data.get('Email') - session['role'] = 'client' - session['totpVerified'] = False + session['Role'] = 'client' + session['TotpVerified'] = False return ResponseObject(status, msg) @client.get(f'{prefix}/api/signout') def ClientAPI_SignOut(): - session['Email'] = None - session['role'] = None - session['totpVerified'] = None + if session.get("SignInMethod") == "OIDC": + DashboardClients.SignOut_OIDC() + session.clear() return ResponseObject(True) @client.get(f'{prefix}/api/signin/totp') @@ -102,7 +102,7 @@ def createClientBlueprint(wireguardConfigurations: dict[WireguardConfiguration], if status: if session.get('Email') is None: return ResponseObject(False, "Sign in status is invalid", status_code=401) - session['totpVerified'] = True + session['TotpVerified'] = True return ResponseObject(True, data={ "Email": session.get('Email'), "Profile": DashboardClients.GetClientProfile(session.get("ClientID")) diff --git a/src/modules/DashboardClients.py b/src/modules/DashboardClients.py index 9cec3ea..9ffee86 100644 --- a/src/modules/DashboardClients.py +++ b/src/modules/DashboardClients.py @@ -4,6 +4,7 @@ import uuid import bcrypt import pyotp import sqlalchemy as db +import requests from .ConnectionString import ConnectionString from .DashboardClientsPeerAssignment import DashboardClientsPeerAssignment @@ -133,19 +134,34 @@ class DashboardClients: return True, newClientUUID return False, "User already signed up" + def SignOut_OIDC(self): + sessionPayload = session.get('OIDCPayload') + status, oidc_config = self.OIDC.GetProviderConfiguration(session.get('SignInPayload').get("Provider")) + signOut = requests.get( + oidc_config.get("end_session_endpoint"), + params={ + 'id_token_hint': session.get('SignInPayload').get("Payload").get('sid') + } + ) + return True + def SignIn_OIDC(self, **kwargs): status, data = self.OIDC.VerifyToken(**kwargs) if not status: - return False, "Sign in failed" + return False, "Sign in failed. Reason: " + data existingClient = self.SignIn_OIDC_UserExistence(data) if not existingClient: status, newClientUUID = self.SignUp_OIDC(data) session['ClientID'] = newClientUUID else: session['ClientID'] = existingClient.get("ClientID") - + session['SignInMethod'] = 'OIDC' + session['SignInPayload'] = { + "Provider": kwargs.get('provider'), + "Payload": data + } return True, data - + def SignIn(self, Email, Password) -> tuple[bool, str]: if not all([Email, Password]): return False, "Please fill in all fields" @@ -153,6 +169,7 @@ class DashboardClients: if existingClient: checkPwd = self.SignIn_ValidatePassword(Email, Password) if checkPwd: + session['SignInMethod'] = 'local' session['Email'] = Email session['ClientID'] = existingClient.get("ClientID") return True, self.DashboardClientsTOTP.GenerateToken(existingClient.get("ClientID")) diff --git a/src/modules/DashboardOIDC.py b/src/modules/DashboardOIDC.py index 4a8f9e2..f07c2bc 100644 --- a/src/modules/DashboardOIDC.py +++ b/src/modules/DashboardOIDC.py @@ -10,6 +10,7 @@ class DashboardOIDC: ConfigurationFilePath = os.path.join(ConfigurationPath, 'wg-dashboard-oidc-providers.json') def __init__(self): self.providers: dict[str, dict] = {} + self.provider_secret: dict[str, str] = {} self.__default = { 'Provider': { 'client_id': '', @@ -26,52 +27,36 @@ class DashboardOIDC: self.ReadFile() def GetProviders(self): - providers = {} - for k in self.providers.keys(): - if all([self.providers[k]['client_id'], self.providers[k]['client_secret'], self.providers[k]['issuer']]): - try: - print("Requesting " + f"{self.providers[k]['issuer'].strip('/')}/.well-known/openid-configuration") - oidc_config = requests.get( - f"{self.providers[k]['issuer'].strip('/')}/.well-known/openid-configuration", - verify=certifi.where() - ).json() - providers[k] = { - 'client_id': self.providers[k]['client_id'], - 'issuer': self.providers[k]['issuer'].strip('/') - } - except Exception as e: - current_app.logger.error("Failed to request OIDC config for this provider: " + self.providers[k]['issuer'].strip('/'), exc_info=e) - - return providers + return self.providers def VerifyToken(self, provider, code, redirect_uri): try: if not all([provider, code, redirect_uri]): - return False, "" + return False, "Please provide all parameters" if provider not in self.providers.keys(): return False, "Provider does not exist" - provider = self.providers.get(provider) - oidc_config = requests.get( - f"{provider.get('issuer').strip('/')}/.well-known/openid-configuration", - verify=certifi.where() - - ).json() + secrete = self.provider_secret.get(provider) + oidc_config_status, oidc_config = self.GetProviderConfiguration(provider) + provider_info = self.providers.get(provider) + data = { "grant_type": "authorization_code", "code": code, "redirect_uri": redirect_uri, - "client_id": provider.get('client_id'), - "client_secret": provider.get('client_secret') + "client_id": provider_info.get('client_id'), + "client_secret": secrete } try: tokens = requests.post(oidc_config.get('token_endpoint'), data=data).json() if not all([tokens.get('access_token'), tokens.get('id_token')]): + print(oidc_config.get('token_endpoint'), data) return False, tokens.get('error_description', None) except Exception as e: + print(str(e)) return False, str(e) access_token = tokens.get('access_token') @@ -84,31 +69,58 @@ class DashboardOIDC: kid = headers["kid"] key = next(k for k in jwks["keys"] if k["kid"] == kid) - - print(key) - + payload = jwt.decode( id_token, key, algorithms=[key["alg"]], - audience=provider.get('client_id'), + audience=provider_info.get('client_id'), issuer=issuer, access_token=access_token ) - + print(payload) return True, payload except Exception as e: current_app.logger.error('Read OIDC file failed. Reason: ' + str(e), provider, code, redirect_uri) return False, str(e) - + + def GetProviderConfiguration(self, provider_name): + if not all([provider_name]): + return False, None + provider = self.providers.get(provider_name) + try: + oidc_config = requests.get( + f"{provider.get('issuer').strip('/')}/.well-known/openid-configuration", + verify=certifi.where() + ).json() + except Exception as e: + current_app.logger.error("Failed to get OpenID Configuration of " + provider.get('issuer'), exc_info=e) + return False, None + return True, oidc_config def ReadFile(self): decoder = json.JSONDecoder() try: - self.providers = decoder.decode( + providers = decoder.decode( open(DashboardOIDC.ConfigurationFilePath, 'r').read() ) - print(self.providers) + for k in providers.keys(): + if all([providers[k]['client_id'], providers[k]['client_secret'], providers[k]['issuer']]): + try: + print("Requesting " + f"{providers[k]['issuer'].strip('/')}/.well-known/openid-configuration") + oidc_config = requests.get( + f"{providers[k]['issuer'].strip('/')}/.well-known/openid-configuration", + timeout=3, + verify=certifi.where() + ).json() + self.providers[k] = { + 'client_id': providers[k]['client_id'], + 'issuer': providers[k]['issuer'].strip('/'), + 'openid_configuration': oidc_config + } + self.provider_secret[k] = providers[k]['client_secret'] + except Exception as e: + current_app.logger.error("Failed to request OIDC config for this provider: " + providers[k]['issuer'].strip('/'), exc_info=e) except Exception as e: current_app.logger.error('Read OIDC file failed. Reason: ' + str(e)) return False \ No newline at end of file diff --git a/src/static/client/src/main.js b/src/static/client/src/main.js index 57cb194..7dced1f 100644 --- a/src/static/client/src/main.js +++ b/src/static/client/src/main.js @@ -21,13 +21,22 @@ const initApp = () => { app.mount("#app") } +function removeSearchString() { + let url = new URL(window.location.href); + url.search = ''; // Remove all query parameters + history.replaceState({}, document.title, url.toString()); +} + if (state && code){ axiosPost("/api/signin/oidc", { provider: state, code: code, redirect_uri: window.location.protocol + '//' + window.location.host + window.location.pathname }).then(data => { - window.location.search = '' + let url = new URL(window.location.href); + url.search = ''; + history.replaceState({}, document.title, url.toString()); + initApp() if (!data.status){ const store = clientStore() diff --git a/src/static/client/src/views/index.vue b/src/static/client/src/views/index.vue index 06df5a2..f2711f1 100644 --- a/src/static/client/src/views/index.vue +++ b/src/static/client/src/views/index.vue @@ -1,9 +1,10 @@