OIDC is ready? I think?

This commit is contained in:
Donald Zou
2025-06-29 16:11:05 +08:00
parent 3d75f6bbbd
commit 299d84b16a
51 changed files with 353 additions and 146 deletions

View File

@@ -8,6 +8,7 @@ import sqlalchemy as db
from .ConnectionString import ConnectionString
from .DashboardClientsPeerAssignment import DashboardClientsPeerAssignment
from .DashboardClientsTOTP import DashboardClientsTOTP
from .DashboardOIDC import DashboardOIDC
from .Utilities import ValidatePasswordStrength
from .DashboardLogger import DashboardLogger
from flask import session
@@ -18,6 +19,7 @@ class DashboardClients:
self.logger = DashboardLogger()
self.engine = db.create_engine(ConnectionString("wgdashboard"))
self.metadata = db.MetaData()
self.OIDC = DashboardOIDC()
self.dashboardClientsTable = db.Table(
'DashboardClients', self.metadata,
@@ -33,6 +35,18 @@ class DashboardClients:
(db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP)),
extend_existing=True,
)
self.dashboardOIDCClientsTable = db.Table(
'DashboardOIDCClients', self.metadata,
db.Column('ClientID', db.String(255), nullable=False, primary_key=True),
db.Column('Email', db.String(255), nullable=False, index=True),
db.Column('ProviderIssuer', db.String(500), nullable=False, index=True),
db.Column('ProviderSubject', db.String(500), nullable=False, index=True),
db.Column('CreatedDate',
(db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP),
server_default=db.func.now()),
extend_existing=True,
)
self.dashboardClientsInfoTable = db.Table(
'DashboardClientsInfo', self.metadata,
@@ -85,7 +99,53 @@ class DashboardClients:
)
).mappings().fetchone()
return existingClient
def SignIn_OIDC_UserExistence(self, data: dict[str, str]):
with self.engine.connect() as conn:
existingClient = conn.execute(
self.dashboardOIDCClientsTable.select().where(
db.and_(
self.dashboardOIDCClientsTable.c.ProviderIssuer == data.get('iss'),
self.dashboardOIDCClientsTable.c.ProviderSubject == data.get('sub'),
)
)
).mappings().fetchone()
return existingClient
def SignUp_OIDC(self, data: dict[str, str]) -> tuple[bool, str] | tuple[bool, None]:
if not self.SignIn_OIDC_UserExistence(data):
with self.engine.begin() as conn:
newClientUUID = str(uuid.uuid4())
conn.execute(
self.dashboardOIDCClientsTable.insert().values({
"ClientID": newClientUUID,
"Email": data.get('email', ''),
"ProviderIssuer": data.get('iss', ''),
"ProviderSubject": data.get('sub', '')
})
)
conn.execute(
self.dashboardClientsInfoTable.insert().values({
"ClientID": newClientUUID
})
)
self.logger.log(Message=f"User {data.get('email', '')} from {data.get('iss', '')} signed up")
return True, newClientUUID
return False, "User already signed up"
def SignIn_OIDC(self, **kwargs):
status, data = self.OIDC.VerifyToken(**kwargs)
if not status:
return False, "Sign in failed"
existingClient = self.SignIn_OIDC_UserExistence(data)
if not existingClient:
status, newClientUUID = self.SignUp_OIDC(data)
session['ClientID'] = newClientUUID
else:
session['ClientID'] = existingClient.get("ClientID")
return True, data
def SignIn(self, Email, Password) -> tuple[bool, str]:
if not all([Email, Password]):
return False, "Please fill in all fields"
@@ -149,7 +209,9 @@ class DashboardClients:
"ClientID": newClientUUID,
"Email": Email,
"Password": bcrypt.hashpw(encodePassword, bcrypt.gensalt()).decode("utf-8"),
"TotpKey": totpKey
"TotpKey": totpKey,
"AuthType": "local",
"AuthSrc": "local"
})
)
conn.execute(

View File

@@ -23,6 +23,17 @@ class DashboardOIDC:
f.write(encoder.encode(self.__default))
self.ReadFile()
def GetProviders(self):
providers = {}
for k in self.providers.keys():
if all([self.providers[k]['client_id'], self.providers[k]['client_secret'], self.providers[k]['issuer']]):
providers[k] = {
'client_id': self.providers[k]['client_id'],
'issuer': self.providers[k]['issuer'].strip('/')
}
return providers
def VerifyToken(self, provider, code, redirect_uri):
if not all([provider, code, redirect_uri]):
@@ -32,7 +43,7 @@ class DashboardOIDC:
return False, "Provider does not exist"
provider = self.providers.get(provider)
oidc_config = requests.get(f"{provider.get('issuer')}.well-known/openid-configuration").json()
oidc_config = requests.get(f"{provider.get('issuer').strip('/')}/.well-known/openid-configuration").json()
data = {
"grant_type": "authorization_code",
@@ -42,30 +53,24 @@ class DashboardOIDC:
"client_secret": provider.get('client_secret')
}
tokens = requests.post(oidc_config.get('token_endpoint'), data=data).json()
try:
tokens = requests.post(oidc_config.get('token_endpoint'), data=data).json()
if not all([tokens.get('access_token'), tokens.get('id_token')]):
return False, tokens.get('error_description', None)
except Exception as e:
return False, str(e)
id_token = tokens.get('id_token')
jwks_uri = oidc_config.get("jwks_uri")
issuer = oidc_config.get("issuer")
jwks = requests.get(jwks_uri).json()
from jose.utils import base64url_decode
from jose.backends.cryptography_backend import CryptographyRSAKey
# Choose the right key based on `kid` in token header
headers = jwt.get_unverified_header(id_token)
kid = headers["kid"]
# Find the key with the correct `kid`
key = next(k for k in jwks["keys"] if k["kid"] == kid)
# Use the key to verify token
payload = jwt.decode(
id_token,
key,
@@ -74,7 +79,7 @@ class DashboardOIDC:
issuer=issuer
)
print(payload) # This contains the user's claims
return True, payload
def ReadFile(self):