mirror of
https://github.com/donaldzou/WGDashboard.git
synced 2025-10-04 00:06:18 +00:00
OIDC is ready? I think?
This commit is contained in:
@@ -8,6 +8,7 @@ import sqlalchemy as db
|
||||
from .ConnectionString import ConnectionString
|
||||
from .DashboardClientsPeerAssignment import DashboardClientsPeerAssignment
|
||||
from .DashboardClientsTOTP import DashboardClientsTOTP
|
||||
from .DashboardOIDC import DashboardOIDC
|
||||
from .Utilities import ValidatePasswordStrength
|
||||
from .DashboardLogger import DashboardLogger
|
||||
from flask import session
|
||||
@@ -18,6 +19,7 @@ class DashboardClients:
|
||||
self.logger = DashboardLogger()
|
||||
self.engine = db.create_engine(ConnectionString("wgdashboard"))
|
||||
self.metadata = db.MetaData()
|
||||
self.OIDC = DashboardOIDC()
|
||||
|
||||
self.dashboardClientsTable = db.Table(
|
||||
'DashboardClients', self.metadata,
|
||||
@@ -33,6 +35,18 @@ class DashboardClients:
|
||||
(db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP)),
|
||||
extend_existing=True,
|
||||
)
|
||||
|
||||
self.dashboardOIDCClientsTable = db.Table(
|
||||
'DashboardOIDCClients', self.metadata,
|
||||
db.Column('ClientID', db.String(255), nullable=False, primary_key=True),
|
||||
db.Column('Email', db.String(255), nullable=False, index=True),
|
||||
db.Column('ProviderIssuer', db.String(500), nullable=False, index=True),
|
||||
db.Column('ProviderSubject', db.String(500), nullable=False, index=True),
|
||||
db.Column('CreatedDate',
|
||||
(db.DATETIME if 'sqlite:///' in ConnectionString("wgdashboard") else db.TIMESTAMP),
|
||||
server_default=db.func.now()),
|
||||
extend_existing=True,
|
||||
)
|
||||
|
||||
self.dashboardClientsInfoTable = db.Table(
|
||||
'DashboardClientsInfo', self.metadata,
|
||||
@@ -85,7 +99,53 @@ class DashboardClients:
|
||||
)
|
||||
).mappings().fetchone()
|
||||
return existingClient
|
||||
|
||||
def SignIn_OIDC_UserExistence(self, data: dict[str, str]):
|
||||
with self.engine.connect() as conn:
|
||||
existingClient = conn.execute(
|
||||
self.dashboardOIDCClientsTable.select().where(
|
||||
db.and_(
|
||||
self.dashboardOIDCClientsTable.c.ProviderIssuer == data.get('iss'),
|
||||
self.dashboardOIDCClientsTable.c.ProviderSubject == data.get('sub'),
|
||||
)
|
||||
)
|
||||
).mappings().fetchone()
|
||||
return existingClient
|
||||
|
||||
def SignUp_OIDC(self, data: dict[str, str]) -> tuple[bool, str] | tuple[bool, None]:
|
||||
if not self.SignIn_OIDC_UserExistence(data):
|
||||
with self.engine.begin() as conn:
|
||||
newClientUUID = str(uuid.uuid4())
|
||||
conn.execute(
|
||||
self.dashboardOIDCClientsTable.insert().values({
|
||||
"ClientID": newClientUUID,
|
||||
"Email": data.get('email', ''),
|
||||
"ProviderIssuer": data.get('iss', ''),
|
||||
"ProviderSubject": data.get('sub', '')
|
||||
})
|
||||
)
|
||||
conn.execute(
|
||||
self.dashboardClientsInfoTable.insert().values({
|
||||
"ClientID": newClientUUID
|
||||
})
|
||||
)
|
||||
self.logger.log(Message=f"User {data.get('email', '')} from {data.get('iss', '')} signed up")
|
||||
return True, newClientUUID
|
||||
return False, "User already signed up"
|
||||
|
||||
def SignIn_OIDC(self, **kwargs):
|
||||
status, data = self.OIDC.VerifyToken(**kwargs)
|
||||
if not status:
|
||||
return False, "Sign in failed"
|
||||
existingClient = self.SignIn_OIDC_UserExistence(data)
|
||||
if not existingClient:
|
||||
status, newClientUUID = self.SignUp_OIDC(data)
|
||||
session['ClientID'] = newClientUUID
|
||||
else:
|
||||
session['ClientID'] = existingClient.get("ClientID")
|
||||
|
||||
return True, data
|
||||
|
||||
def SignIn(self, Email, Password) -> tuple[bool, str]:
|
||||
if not all([Email, Password]):
|
||||
return False, "Please fill in all fields"
|
||||
@@ -149,7 +209,9 @@ class DashboardClients:
|
||||
"ClientID": newClientUUID,
|
||||
"Email": Email,
|
||||
"Password": bcrypt.hashpw(encodePassword, bcrypt.gensalt()).decode("utf-8"),
|
||||
"TotpKey": totpKey
|
||||
"TotpKey": totpKey,
|
||||
"AuthType": "local",
|
||||
"AuthSrc": "local"
|
||||
})
|
||||
)
|
||||
conn.execute(
|
||||
|
@@ -23,6 +23,17 @@ class DashboardOIDC:
|
||||
f.write(encoder.encode(self.__default))
|
||||
|
||||
self.ReadFile()
|
||||
|
||||
def GetProviders(self):
|
||||
providers = {}
|
||||
for k in self.providers.keys():
|
||||
if all([self.providers[k]['client_id'], self.providers[k]['client_secret'], self.providers[k]['issuer']]):
|
||||
providers[k] = {
|
||||
'client_id': self.providers[k]['client_id'],
|
||||
'issuer': self.providers[k]['issuer'].strip('/')
|
||||
}
|
||||
|
||||
return providers
|
||||
|
||||
def VerifyToken(self, provider, code, redirect_uri):
|
||||
if not all([provider, code, redirect_uri]):
|
||||
@@ -32,7 +43,7 @@ class DashboardOIDC:
|
||||
return False, "Provider does not exist"
|
||||
|
||||
provider = self.providers.get(provider)
|
||||
oidc_config = requests.get(f"{provider.get('issuer')}.well-known/openid-configuration").json()
|
||||
oidc_config = requests.get(f"{provider.get('issuer').strip('/')}/.well-known/openid-configuration").json()
|
||||
|
||||
data = {
|
||||
"grant_type": "authorization_code",
|
||||
@@ -42,30 +53,24 @@ class DashboardOIDC:
|
||||
"client_secret": provider.get('client_secret')
|
||||
}
|
||||
|
||||
tokens = requests.post(oidc_config.get('token_endpoint'), data=data).json()
|
||||
|
||||
|
||||
try:
|
||||
tokens = requests.post(oidc_config.get('token_endpoint'), data=data).json()
|
||||
if not all([tokens.get('access_token'), tokens.get('id_token')]):
|
||||
return False, tokens.get('error_description', None)
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
|
||||
|
||||
id_token = tokens.get('id_token')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
jwks_uri = oidc_config.get("jwks_uri")
|
||||
issuer = oidc_config.get("issuer")
|
||||
jwks = requests.get(jwks_uri).json()
|
||||
|
||||
from jose.utils import base64url_decode
|
||||
from jose.backends.cryptography_backend import CryptographyRSAKey
|
||||
|
||||
# Choose the right key based on `kid` in token header
|
||||
headers = jwt.get_unverified_header(id_token)
|
||||
kid = headers["kid"]
|
||||
|
||||
# Find the key with the correct `kid`
|
||||
key = next(k for k in jwks["keys"] if k["kid"] == kid)
|
||||
|
||||
# Use the key to verify token
|
||||
payload = jwt.decode(
|
||||
id_token,
|
||||
key,
|
||||
@@ -74,7 +79,7 @@ class DashboardOIDC:
|
||||
issuer=issuer
|
||||
)
|
||||
|
||||
print(payload) # This contains the user's claims
|
||||
return True, payload
|
||||
|
||||
|
||||
def ReadFile(self):
|
||||
|
Reference in New Issue
Block a user