mirror of
https://github.com/h44z/wg-portal.git
synced 2026-05-28 08:56:17 +00:00
feat: sanitize external identity provider user data (#681)
* feat: sanitize external user data * remove config option to disable Sanitization: sanitize_external_user_data * cleanup --------- Co-authored-by: Christoph Haas <christoph.h@sprinternet.at>
This commit is contained in:
@@ -113,4 +113,4 @@ backend:
|
|||||||
api_verify_tls: true
|
api_verify_tls: true
|
||||||
api_timeout: 30s
|
api_timeout: 30s
|
||||||
concurrency: 5
|
concurrency: 5
|
||||||
debug: false
|
debug: false
|
||||||
|
|||||||
@@ -8,6 +8,23 @@ To enable encryption, set the [`encryption_passphrase`](../configuration/overvie
|
|||||||
> :warning: Important: Once encryption is enabled, it cannot be disabled, and the passphrase cannot be changed!
|
> :warning: Important: Once encryption is enabled, it cannot be disabled, and the passphrase cannot be changed!
|
||||||
> Only new or updated records will be encrypted; existing data remains in plaintext until it’s next modified.
|
> Only new or updated records will be encrypted; existing data remains in plaintext until it’s next modified.
|
||||||
|
|
||||||
|
## External Identity Provider Data Sanitization
|
||||||
|
|
||||||
|
When users authenticate via LDAP, OIDC, or OAuth, WireGuard Portal sanitizes the field values received from the provider before storing them. This protects against several classes of attack that a compromised or misconfigured identity provider could introduce:
|
||||||
|
|
||||||
|
- **Unsafe control characters** — Unicode control and format characters, null bytes, and invalid UTF-8 bytes are stripped from external profile fields before they reach the Vue.js UI or email templates.
|
||||||
|
- **Email header injection** — carriage return and line feed characters in email fields are rejected entirely, and email fields must parse as plain email addresses.
|
||||||
|
- **Log injection** — unsafe control and format characters are stripped from all external profile fields and from sanitization log context.
|
||||||
|
- **Denial of service via oversized fields** — field lengths are capped (e.g., 256 runes for identifiers, 254 characters for email addresses).
|
||||||
|
- **Reserved identifier collision** — reserved user identifiers such as `"all"`, `"new"`, `"id"`, and internal system user identifiers are rejected.
|
||||||
|
- **Unsafe authorization groups** — OIDC/OAuth group claims are sanitized before group-based checks; groups changed by control/format stripping or truncation are dropped rather than repaired into allowed/admin matches.
|
||||||
|
|
||||||
|
Sanitization is always enabled and cannot be disabled.
|
||||||
|
|
||||||
|
When sanitization modifies or clears a field value, a `WARN` log entry is emitted with the provider name, provider type, and field name — but never the raw or sanitized value, to avoid leaking sensitive data into logs. This makes it straightforward to detect and investigate potentially malicious or misconfigured providers.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## UI and API Access
|
## UI and API Access
|
||||||
|
|
||||||
WireGuard Portal provides a web UI and a REST API for user interaction. It is important to secure these interfaces to prevent unauthorized access and data breaches.
|
WireGuard Portal provides a web UI and a REST API for user interaction. It is important to secure these interfaces to prevent unauthorized access and data breaches.
|
||||||
@@ -21,4 +38,4 @@ A detailed explanation is available in the [Reverse Proxy](../getting-started/re
|
|||||||
### Secure Authentication
|
### Secure Authentication
|
||||||
To prevent unauthorized access, WireGuard Portal supports integrating with secure authentication providers such as LDAP, OAuth2, or Passkeys, see [Authentication](./authentication.md) for more details.
|
To prevent unauthorized access, WireGuard Portal supports integrating with secure authentication providers such as LDAP, OAuth2, or Passkeys, see [Authentication](./authentication.md) for more details.
|
||||||
When possible, use centralized authentication and enforce multi-factor authentication (MFA) at the provider level for enhanced account security.
|
When possible, use centralized authentication and enforce multi-factor authentication (MFA) at the provider level for enhanced account security.
|
||||||
For local accounts, administrators should enforce strong password requirements.
|
For local accounts, administrators should enforce strong password requirements.
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -31,6 +31,7 @@ require (
|
|||||||
gorm.io/driver/postgres v1.6.0
|
gorm.io/driver/postgres v1.6.0
|
||||||
gorm.io/driver/sqlserver v1.6.3
|
gorm.io/driver/sqlserver v1.6.3
|
||||||
gorm.io/gorm v1.31.1
|
gorm.io/gorm v1.31.1
|
||||||
|
pgregory.net/rapid v1.2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -432,5 +432,7 @@ modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
|||||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||||
|
pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk=
|
||||||
|
pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=
|
||||||
sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs=
|
sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs=
|
||||||
sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4=
|
sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4=
|
||||||
|
|||||||
@@ -149,5 +149,9 @@ func (l LdapAuthenticator) ParseUserInfo(raw map[string]any) (*domain.Authentica
|
|||||||
AdminInfoAvailable: adminInfoAvailable,
|
AdminInfoAvailable: adminInfoAvailable,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := userInfo.Sanitize("ldap", l.cfg.ProviderName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return userInfo, nil
|
return userInfo, nil
|
||||||
}
|
}
|
||||||
|
|||||||
98
internal/app/auth/auth_ldap_sanitize_test.go
Normal file
98
internal/app/auth/auth_ldap_sanitize_test.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/h44z/wg-portal/internal/config"
|
||||||
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
|
"github.com/h44z/wg-portal/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// makeLdapAuthenticator creates a minimal LdapAuthenticator for testing ParseUserInfo.
|
||||||
|
func makeLdapAuthenticator() *LdapAuthenticator {
|
||||||
|
return &LdapAuthenticator{
|
||||||
|
cfg: &config.LdapProvider{
|
||||||
|
ProviderName: "test-ldap",
|
||||||
|
FieldMap: config.LdapFields{
|
||||||
|
BaseFields: config.BaseFields{
|
||||||
|
UserIdentifier: "uid",
|
||||||
|
Email: "mail",
|
||||||
|
Firstname: "givenName",
|
||||||
|
Lastname: "sn",
|
||||||
|
Phone: "telephoneNumber",
|
||||||
|
Department: "department",
|
||||||
|
},
|
||||||
|
GroupMembership: "", // no group membership check
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeRawLdapMap builds a minimal raw LDAP attribute map for ParseUserInfo.
|
||||||
|
func makeRawLdapMap(uid, mail, givenName, sn, phone, department string) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"uid": uid,
|
||||||
|
"mail": mail,
|
||||||
|
"givenName": givenName,
|
||||||
|
"sn": sn,
|
||||||
|
"telephoneNumber": phone,
|
||||||
|
"department": department,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: firstname contains \x00 → output firstname has no null byte,
|
||||||
|
// one WARN log entry with field: "firstname".
|
||||||
|
func TestLdapParseUserInfo_NullByteInFirstname(t *testing.T) {
|
||||||
|
auth := makeLdapAuthenticator()
|
||||||
|
raw := makeRawLdapMap("alice", "alice@example.com", "Ali\x00ce", "Smith", "", "")
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
info, err := auth.ParseUserInfo(raw)
|
||||||
|
records := restore()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotContains(t, info.Firstname, "\x00", "firstname should have null byte removed")
|
||||||
|
assert.Equal(t, "Alice", info.Firstname)
|
||||||
|
|
||||||
|
warnCount := testutil.CountWarnEntries(records)
|
||||||
|
assert.Equal(t, 1, warnCount, "expected exactly one WARN log entry")
|
||||||
|
|
||||||
|
rec, found := testutil.FindWarnWithField(records, "firstname")
|
||||||
|
assert.True(t, found, "expected WARN log entry with field=firstname")
|
||||||
|
if found {
|
||||||
|
assert.Equal(t, "WARN", rec["level"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: all fields clean → no WARN log entries emitted.
|
||||||
|
func TestLdapParseUserInfo_AllFieldsClean(t *testing.T) {
|
||||||
|
auth := makeLdapAuthenticator()
|
||||||
|
raw := makeRawLdapMap("alice", "alice@example.com", "Alice", "Smith", "+1 555-1234", "Engineering")
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
info, err := auth.ParseUserInfo(raw)
|
||||||
|
records := restore()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, domain.UserIdentifier("alice"), info.Identifier)
|
||||||
|
|
||||||
|
warnCount := testutil.CountWarnEntries(records)
|
||||||
|
assert.Equal(t, 0, warnCount, "expected no WARN log entries when all fields are clean")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: identifier is "all" → returns ErrInvalidData.
|
||||||
|
func TestLdapParseUserInfo_IdentifierAll(t *testing.T) {
|
||||||
|
auth := makeLdapAuthenticator()
|
||||||
|
raw := makeRawLdapMap("all", "all@example.com", "Alice", "Smith", "", "")
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
_, err := auth.ParseUserInfo(raw)
|
||||||
|
_ = restore()
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, domain.ErrInvalidData), "expected ErrInvalidData when identifier is 'all'")
|
||||||
|
}
|
||||||
@@ -155,5 +155,5 @@ func (p PlainOauthAuthenticator) GetUserInfo(
|
|||||||
|
|
||||||
// ParseUserInfo parses the user information from the raw data.
|
// ParseUserInfo parses the user information from the raw data.
|
||||||
func (p PlainOauthAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
|
func (p PlainOauthAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
|
||||||
return parseOauthUserInfo(p.userInfoMapping, p.userAdminMapping, raw)
|
return parseOauthUserInfo(p.userInfoMapping, p.userAdminMapping, raw, "oauth", p.name)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -194,5 +194,5 @@ func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token,
|
|||||||
|
|
||||||
// ParseUserInfo parses the user info.
|
// ParseUserInfo parses the user info.
|
||||||
func (o OidcAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
|
func (o OidcAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
|
||||||
return parseOauthUserInfo(o.userInfoMapping, o.userAdminMapping, raw)
|
return parseOauthUserInfo(o.userInfoMapping, o.userAdminMapping, raw, "oidc", o.name)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ func parseOauthUserInfo(
|
|||||||
mapping config.OauthFields,
|
mapping config.OauthFields,
|
||||||
adminMapping *config.OauthAdminMapping,
|
adminMapping *config.OauthAdminMapping,
|
||||||
raw map[string]any,
|
raw map[string]any,
|
||||||
|
providerType string,
|
||||||
|
providerName string,
|
||||||
) (*domain.AuthenticatorUserInfo, error) {
|
) (*domain.AuthenticatorUserInfo, error) {
|
||||||
var isAdmin bool
|
var isAdmin bool
|
||||||
var adminInfoAvailable bool
|
var adminInfoAvailable bool
|
||||||
@@ -27,18 +29,6 @@ func parseOauthUserInfo(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// next try to parse the user's groups
|
|
||||||
if !isAdmin && mapping.UserGroups != "" && adminMapping.AdminGroupRegex != "" {
|
|
||||||
adminInfoAvailable = true
|
|
||||||
re := adminMapping.GetAdminGroupRegex()
|
|
||||||
for _, group := range userGroups {
|
|
||||||
if re.MatchString(strings.TrimSpace(group)) {
|
|
||||||
isAdmin = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
userInfo := &domain.AuthenticatorUserInfo{
|
userInfo := &domain.AuthenticatorUserInfo{
|
||||||
Identifier: domain.UserIdentifier(internal.MapDefaultString(raw, mapping.UserIdentifier, "")),
|
Identifier: domain.UserIdentifier(internal.MapDefaultString(raw, mapping.UserIdentifier, "")),
|
||||||
Email: internal.MapDefaultString(raw, mapping.Email, ""),
|
Email: internal.MapDefaultString(raw, mapping.Email, ""),
|
||||||
@@ -51,6 +41,24 @@ func parseOauthUserInfo(
|
|||||||
AdminInfoAvailable: adminInfoAvailable,
|
AdminInfoAvailable: adminInfoAvailable,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := userInfo.Sanitize(providerType, providerName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// check admin group match after sanitization
|
||||||
|
if !isAdmin && mapping.UserGroups != "" && adminMapping.AdminGroupRegex != "" {
|
||||||
|
adminInfoAvailable = true
|
||||||
|
re := adminMapping.GetAdminGroupRegex()
|
||||||
|
for _, group := range userInfo.UserGroups {
|
||||||
|
if re.MatchString(group) {
|
||||||
|
isAdmin = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
userInfo.IsAdmin = isAdmin
|
||||||
|
userInfo.AdminInfoAvailable = adminInfoAvailable
|
||||||
|
}
|
||||||
|
|
||||||
return userInfo, nil
|
return userInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
148
internal/app/auth/oauth_common_sanitize_test.go
Normal file
148
internal/app/auth/oauth_common_sanitize_test.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/h44z/wg-portal/internal/config"
|
||||||
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
|
"github.com/h44z/wg-portal/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// makeOauthFieldMapping returns a minimal OauthFields mapping for testing.
|
||||||
|
func makeOauthFieldMapping() config.OauthFields {
|
||||||
|
return config.OauthFields{
|
||||||
|
BaseFields: config.BaseFields{
|
||||||
|
UserIdentifier: "sub",
|
||||||
|
Email: "email",
|
||||||
|
Firstname: "given_name",
|
||||||
|
Lastname: "family_name",
|
||||||
|
Phone: "phone",
|
||||||
|
Department: "department",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeOauthRaw builds a minimal raw OAuth user info map.
|
||||||
|
func makeOauthRaw(sub, email, givenName, familyName, phone, department string) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"sub": sub,
|
||||||
|
"email": email,
|
||||||
|
"given_name": givenName,
|
||||||
|
"family_name": familyName,
|
||||||
|
"phone": phone,
|
||||||
|
"department": department,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: email containing \r\n → output email is "",
|
||||||
|
// one WARN log entry with field: "email" and cleared indication.
|
||||||
|
func TestParseOauthUserInfo_CRLFInEmail(t *testing.T) {
|
||||||
|
mapping := makeOauthFieldMapping()
|
||||||
|
adminMapping := &config.OauthAdminMapping{}
|
||||||
|
raw := makeOauthRaw("user123", "user\r\n@example.com", "Alice", "Smith", "", "")
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider")
|
||||||
|
records := restore()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "", info.Email, "email should be cleared when it contains CR/LF")
|
||||||
|
|
||||||
|
warnCount := testutil.CountWarnEntries(records)
|
||||||
|
assert.Equal(t, 1, warnCount, "expected exactly one WARN log entry")
|
||||||
|
|
||||||
|
rec, found := testutil.FindWarnWithField(records, "email")
|
||||||
|
assert.True(t, found, "expected WARN log entry with field=email")
|
||||||
|
if found {
|
||||||
|
msg, _ := rec["msg"].(string)
|
||||||
|
assert.Contains(t, msg, "cleared", "expected 'cleared' in log message when email is cleared")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: two fields modified (email cleared, firstname truncated) →
|
||||||
|
// two separate WARN log entries.
|
||||||
|
func TestParseOauthUserInfo_TwoFieldsModified(t *testing.T) {
|
||||||
|
mapping := makeOauthFieldMapping()
|
||||||
|
adminMapping := &config.OauthAdminMapping{}
|
||||||
|
|
||||||
|
longFirstname := strings.Repeat("A", 200)
|
||||||
|
raw := makeOauthRaw("user123", "bad\r\nemail@example.com", longFirstname, "Smith", "", "")
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider")
|
||||||
|
records := restore()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "", info.Email, "email should be cleared")
|
||||||
|
assert.Equal(t, 128, len([]rune(info.Firstname)), "firstname should be truncated to 128 runes")
|
||||||
|
|
||||||
|
warnCount := testutil.CountWarnEntries(records)
|
||||||
|
assert.Equal(t, 2, warnCount, "expected exactly two WARN log entries (one per modified field)")
|
||||||
|
|
||||||
|
_, emailFound := testutil.FindWarnWithField(records, "email")
|
||||||
|
assert.True(t, emailFound, "expected WARN log entry with field=email")
|
||||||
|
|
||||||
|
_, firstnameFound := testutil.FindWarnWithField(records, "firstname")
|
||||||
|
assert.True(t, firstnameFound, "expected WARN log entry with field=firstname")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: identifier "all" → returns ErrInvalidData.
|
||||||
|
func TestParseOauthUserInfo_IdentifierAll(t *testing.T) {
|
||||||
|
mapping := makeOauthFieldMapping()
|
||||||
|
adminMapping := &config.OauthAdminMapping{}
|
||||||
|
raw := makeOauthRaw("all", "all@example.com", "Alice", "Smith", "", "")
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
_, err := parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider")
|
||||||
|
_ = restore()
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, domain.ErrInvalidData), "expected ErrInvalidData when identifier is 'all'")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOauthUserInfo_DropsModifiedGroupBeforeAdminMatch(t *testing.T) {
|
||||||
|
mapping := makeOauthFieldMapping()
|
||||||
|
mapping.UserGroups = "groups"
|
||||||
|
adminMapping := &config.OauthAdminMapping{
|
||||||
|
AdminGroupRegex: "^wgportal-admins$",
|
||||||
|
}
|
||||||
|
raw := makeOauthRaw("user123", "user@example.com", "Alice", "Smith", "", "")
|
||||||
|
raw["groups"] = []any{"wgportal-\u200badmins"}
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oidc", "test-provider")
|
||||||
|
records := restore()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, info)
|
||||||
|
assert.False(t, info.IsAdmin, "sanitization must not repair a modified group into an admin match")
|
||||||
|
assert.Empty(t, info.UserGroups)
|
||||||
|
|
||||||
|
rec, found := testutil.FindWarnWithField(records, "user_group")
|
||||||
|
assert.True(t, found, "expected WARN log entry with field=user_group")
|
||||||
|
if found {
|
||||||
|
assert.Equal(t, "oidc", rec["provider_type"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOauthUserInfo_AllowsWhitespaceOnlyGroupTrim(t *testing.T) {
|
||||||
|
mapping := makeOauthFieldMapping()
|
||||||
|
mapping.UserGroups = "groups"
|
||||||
|
adminMapping := &config.OauthAdminMapping{
|
||||||
|
AdminGroupRegex: "^wgportal-admins$",
|
||||||
|
}
|
||||||
|
raw := makeOauthRaw("user123", "user@example.com", "Alice", "Smith", "", "")
|
||||||
|
raw["groups"] = []any{" wgportal-admins "}
|
||||||
|
|
||||||
|
info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oidc", "test-provider")
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, info)
|
||||||
|
assert.True(t, info.IsAdmin)
|
||||||
|
assert.Equal(t, []string{"wgportal-admins"}, info.UserGroups)
|
||||||
|
}
|
||||||
@@ -43,7 +43,7 @@ func Test_parseOauthUserInfo_no_admin(t *testing.T) {
|
|||||||
})
|
})
|
||||||
adminMapping := &config.OauthAdminMapping{}
|
adminMapping := &config.OauthAdminMapping{}
|
||||||
|
|
||||||
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo)
|
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.False(t, info.IsAdmin)
|
assert.False(t, info.IsAdmin)
|
||||||
assert.Equal(t, info.Firstname, "Test User")
|
assert.Equal(t, info.Firstname, "Test User")
|
||||||
@@ -90,7 +90,7 @@ func Test_parseOauthUserInfo_admin_group(t *testing.T) {
|
|||||||
AdminGroupRegex: "^wgportal-admins@mydomain.net$",
|
AdminGroupRegex: "^wgportal-admins@mydomain.net$",
|
||||||
}
|
}
|
||||||
|
|
||||||
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo)
|
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.True(t, info.IsAdmin)
|
assert.True(t, info.IsAdmin)
|
||||||
assert.Equal(t, info.Firstname, "Test User")
|
assert.Equal(t, info.Firstname, "Test User")
|
||||||
@@ -132,7 +132,7 @@ func Test_parseOauthUserInfo_admin_value(t *testing.T) {
|
|||||||
})
|
})
|
||||||
adminMapping := &config.OauthAdminMapping{} // test with default regex
|
adminMapping := &config.OauthAdminMapping{} // test with default regex
|
||||||
|
|
||||||
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo)
|
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.True(t, info.IsAdmin)
|
assert.True(t, info.IsAdmin)
|
||||||
assert.Equal(t, info.Firstname, "Test User")
|
assert.Equal(t, info.Firstname, "Test User")
|
||||||
@@ -175,7 +175,7 @@ func Test_parseOauthUserInfo_admin_value_custom(t *testing.T) {
|
|||||||
AdminValueRegex: "^1$",
|
AdminValueRegex: "^1$",
|
||||||
}
|
}
|
||||||
|
|
||||||
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo)
|
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.True(t, info.IsAdmin)
|
assert.True(t, info.IsAdmin)
|
||||||
assert.Equal(t, info.Firstname, "Test User")
|
assert.Equal(t, info.Firstname, "Test User")
|
||||||
|
|||||||
90
internal/app/auth/sanitize_log_test.go
Normal file
90
internal/app/auth/sanitize_log_test.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"log/slog"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"pgregory.net/rapid"
|
||||||
|
|
||||||
|
"github.com/h44z/wg-portal/internal/config"
|
||||||
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
|
"github.com/h44z/wg-portal/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// captureWarnLogsInline redirects the default slog logger to a buffer, calls fn,
|
||||||
|
// restores the original logger, and returns the captured log records.
|
||||||
|
func captureWarnLogsInline(fn func()) []map[string]any {
|
||||||
|
original := slog.Default()
|
||||||
|
var buf bytes.Buffer
|
||||||
|
handler := slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelWarn})
|
||||||
|
slog.SetDefault(slog.New(handler))
|
||||||
|
|
||||||
|
fn()
|
||||||
|
|
||||||
|
slog.SetDefault(original)
|
||||||
|
|
||||||
|
var records []map[string]any
|
||||||
|
decoder := json.NewDecoder(&buf)
|
||||||
|
for decoder.More() {
|
||||||
|
var rec map[string]any
|
||||||
|
if err := decoder.Decode(&rec); err == nil {
|
||||||
|
records = append(records, rec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
|
||||||
|
// Property 7: Sanitization change logging completeness
|
||||||
|
func TestPropertySanitizationChangeLoggingCompleteness(t *testing.T) {
|
||||||
|
mapping := makeOauthFieldMapping()
|
||||||
|
adminMapping := &config.OauthAdminMapping{}
|
||||||
|
|
||||||
|
rapid.Check(t, func(t *rapid.T) {
|
||||||
|
sub := rapid.StringMatching(`[a-zA-Z0-9_@.-]{1,50}`).Draw(t, "sub")
|
||||||
|
email := rapid.String().Draw(t, "email")
|
||||||
|
firstname := rapid.String().Draw(t, "firstname")
|
||||||
|
lastname := rapid.String().Draw(t, "lastname")
|
||||||
|
phone := rapid.String().Draw(t, "phone")
|
||||||
|
department := rapid.String().Draw(t, "department")
|
||||||
|
|
||||||
|
if sub == "" {
|
||||||
|
sub = "testuser"
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := makeOauthRaw(sub, email, firstname, lastname, phone, department)
|
||||||
|
|
||||||
|
// Count how many fields will actually change after sanitization
|
||||||
|
expectedChanges := 0
|
||||||
|
if domain.SanitizeIdentifier(sub, 256) != sub {
|
||||||
|
expectedChanges++
|
||||||
|
}
|
||||||
|
if domain.SanitizeEmail(email, 254) != email {
|
||||||
|
expectedChanges++
|
||||||
|
}
|
||||||
|
if domain.SanitizeString(firstname, 128) != firstname {
|
||||||
|
expectedChanges++
|
||||||
|
}
|
||||||
|
if domain.SanitizeString(lastname, 128) != lastname {
|
||||||
|
expectedChanges++
|
||||||
|
}
|
||||||
|
if domain.SanitizePhone(phone, 50) != phone {
|
||||||
|
expectedChanges++
|
||||||
|
}
|
||||||
|
if domain.SanitizeString(department, 128) != department {
|
||||||
|
expectedChanges++
|
||||||
|
}
|
||||||
|
|
||||||
|
var records []map[string]any
|
||||||
|
records = captureWarnLogsInline(func() {
|
||||||
|
_, _ = parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider")
|
||||||
|
})
|
||||||
|
|
||||||
|
actualWarnCount := testutil.CountWarnEntries(records)
|
||||||
|
require.Equal(t, expectedChanges, actualWarnCount,
|
||||||
|
"number of WARN log entries (%d) must equal number of fields changed by sanitization (%d)",
|
||||||
|
actualWarnCount, expectedChanges)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -28,7 +28,7 @@ func convertRawLdapUser(
|
|||||||
|
|
||||||
uid := domain.UserIdentifier(internal.MapDefaultString(rawUser, fields.UserIdentifier, ""))
|
uid := domain.UserIdentifier(internal.MapDefaultString(rawUser, fields.UserIdentifier, ""))
|
||||||
|
|
||||||
return &domain.User{
|
user := &domain.User{
|
||||||
BaseModel: domain.BaseModel{
|
BaseModel: domain.BaseModel{
|
||||||
CreatedBy: domain.CtxSystemLdapSyncer,
|
CreatedBy: domain.CtxSystemLdapSyncer,
|
||||||
UpdatedBy: domain.CtxSystemLdapSyncer,
|
UpdatedBy: domain.CtxSystemLdapSyncer,
|
||||||
@@ -49,10 +49,16 @@ func convertRawLdapUser(
|
|||||||
Lastname: internal.MapDefaultString(rawUser, fields.Lastname, ""),
|
Lastname: internal.MapDefaultString(rawUser, fields.Lastname, ""),
|
||||||
Phone: internal.MapDefaultString(rawUser, fields.Phone, ""),
|
Phone: internal.MapDefaultString(rawUser, fields.Phone, ""),
|
||||||
Department: internal.MapDefaultString(rawUser, fields.Department, ""),
|
Department: internal.MapDefaultString(rawUser, fields.Department, ""),
|
||||||
Notes: "",
|
}
|
||||||
Password: "",
|
|
||||||
Disabled: nil,
|
if err := user.SanitizeExternalData("ldap", providerName); err != nil {
|
||||||
}, nil
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update authentication identifier after sanitization
|
||||||
|
user.Authentications[0].UserIdentifier = user.Identifier
|
||||||
|
|
||||||
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func userChangedInLdap(dbUser, ldapUser *domain.User) bool {
|
func userChangedInLdap(dbUser, ldapUser *domain.User) bool {
|
||||||
|
|||||||
136
internal/app/users/ldap_helper_sanitize_test.go
Normal file
136
internal/app/users/ldap_helper_sanitize_test.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package users
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-ldap/ldap/v3"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/h44z/wg-portal/internal/config"
|
||||||
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
|
"github.com/h44z/wg-portal/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// makeTestLdapFields returns a minimal LdapFields config for testing.
|
||||||
|
func makeTestLdapFields() *config.LdapFields {
|
||||||
|
return &config.LdapFields{
|
||||||
|
BaseFields: config.BaseFields{
|
||||||
|
UserIdentifier: "uid",
|
||||||
|
Email: "mail",
|
||||||
|
Firstname: "givenName",
|
||||||
|
Lastname: "sn",
|
||||||
|
Phone: "telephoneNumber",
|
||||||
|
Department: "department",
|
||||||
|
},
|
||||||
|
GroupMembership: "memberOf",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeTestAdminGroupDN returns a parsed DN for testing (a non-matching group).
|
||||||
|
func makeTestAdminGroupDN(t *testing.T) *ldap.DN {
|
||||||
|
t.Helper()
|
||||||
|
dn, err := ldap.ParseDN("cn=admins,dc=example,dc=com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
return dn
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeRawLdapUser builds a raw LDAP user map for convertRawLdapUser.
|
||||||
|
func makeRawLdapUser(uid, mail, givenName, sn, phone, department string) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"uid": uid,
|
||||||
|
"mail": mail,
|
||||||
|
"givenName": givenName,
|
||||||
|
"sn": sn,
|
||||||
|
"telephoneNumber": phone,
|
||||||
|
"department": department,
|
||||||
|
"memberOf": [][]byte{}, // no group memberships
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: identifier "all" → returns ErrInvalidData,
|
||||||
|
// one WARN log entry with field: "identifier" and cleared indication.
|
||||||
|
func TestConvertRawLdapUser_IdentifierAll(t *testing.T) {
|
||||||
|
fields := makeTestLdapFields()
|
||||||
|
adminGroupDN := makeTestAdminGroupDN(t)
|
||||||
|
raw := makeRawLdapUser("all", "all@example.com", "Alice", "Smith", "", "")
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
user, err := convertRawLdapUser("test-ldap", raw, fields, adminGroupDN)
|
||||||
|
records := restore()
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, domain.ErrInvalidData), "expected ErrInvalidData when identifier is 'all'")
|
||||||
|
assert.Nil(t, user)
|
||||||
|
|
||||||
|
warnCount := testutil.CountWarnEntries(records)
|
||||||
|
assert.Equal(t, 1, warnCount, "expected exactly one WARN log entry")
|
||||||
|
|
||||||
|
rec, found := testutil.FindWarnWithField(records, "identifier")
|
||||||
|
assert.True(t, found, "expected WARN log entry with field=identifier")
|
||||||
|
if found {
|
||||||
|
msg, _ := rec["msg"].(string)
|
||||||
|
assert.Contains(t, msg, "cleared", "expected 'cleared' in log message when identifier is cleared")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: firstname contains \x00 → output firstname has null byte removed,
|
||||||
|
// one WARN log entry with field: "firstname".
|
||||||
|
func TestConvertRawLdapUser_NullByteInFirstname(t *testing.T) {
|
||||||
|
fields := makeTestLdapFields()
|
||||||
|
adminGroupDN := makeTestAdminGroupDN(t)
|
||||||
|
raw := makeRawLdapUser("alice", "alice@example.com", "Ali\x00ce", "Smith", "", "")
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
user, err := convertRawLdapUser("test-ldap", raw, fields, adminGroupDN)
|
||||||
|
records := restore()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, user)
|
||||||
|
assert.NotContains(t, user.Firstname, "\x00", "firstname should have null byte removed")
|
||||||
|
assert.Equal(t, "Alice", user.Firstname)
|
||||||
|
|
||||||
|
warnCount := testutil.CountWarnEntries(records)
|
||||||
|
assert.Equal(t, 1, warnCount, "expected exactly one WARN log entry")
|
||||||
|
|
||||||
|
rec, found := testutil.FindWarnWithField(records, "firstname")
|
||||||
|
assert.True(t, found, "expected WARN log entry with field=firstname")
|
||||||
|
if found {
|
||||||
|
assert.Equal(t, "WARN", rec["level"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: all fields clean → no WARN log entries emitted.
|
||||||
|
func TestConvertRawLdapUser_AllFieldsClean(t *testing.T) {
|
||||||
|
fields := makeTestLdapFields()
|
||||||
|
adminGroupDN := makeTestAdminGroupDN(t)
|
||||||
|
raw := makeRawLdapUser("alice", "alice@example.com", "Alice", "Smith", "+1 555-1234", "Engineering")
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
user, err := convertRawLdapUser("test-ldap", raw, fields, adminGroupDN)
|
||||||
|
records := restore()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, user)
|
||||||
|
assert.Equal(t, domain.UserIdentifier("alice"), user.Identifier)
|
||||||
|
|
||||||
|
warnCount := testutil.CountWarnEntries(records)
|
||||||
|
assert.Equal(t, 0, warnCount, "expected no WARN log entries when all fields are clean")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLdapUserIdentifier_NormalizesSyncComparisons(t *testing.T) {
|
||||||
|
raw := map[string]any{"uid": " alice\x00 "}
|
||||||
|
|
||||||
|
got := ldapUserIdentifier(raw, "uid")
|
||||||
|
|
||||||
|
assert.Equal(t, domain.UserIdentifier("alice"), got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLdapUserIdentifier_RejectsReservedIdentifier(t *testing.T) {
|
||||||
|
raw := map[string]any{"uid": " all "}
|
||||||
|
|
||||||
|
got := ldapUserIdentifier(raw, "uid")
|
||||||
|
|
||||||
|
assert.Empty(t, got)
|
||||||
|
}
|
||||||
@@ -109,6 +109,11 @@ func (m Manager) updateLdapUsers(
|
|||||||
for _, rawUser := range rawUsers {
|
for _, rawUser := range rawUsers {
|
||||||
user, err := convertRawLdapUser(provider.ProviderName, rawUser, fields, adminGroupDN)
|
user, err := convertRawLdapUser(provider.ProviderName, rawUser, fields, adminGroupDN)
|
||||||
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
||||||
|
if errors.Is(err, domain.ErrInvalidData) {
|
||||||
|
slog.Warn("skipping LDAP user with invalid data after sanitization",
|
||||||
|
"raw-dn", rawUser["dn"], "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
return fmt.Errorf("failed to convert LDAP data for %v: %w", rawUser["dn"], err)
|
return fmt.Errorf("failed to convert LDAP data for %v: %w", rawUser["dn"], err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,7 +217,7 @@ func (m Manager) disableMissingLdapUsers(
|
|||||||
|
|
||||||
existsInLDAP := false
|
existsInLDAP := false
|
||||||
for _, rawUser := range rawUsers {
|
for _, rawUser := range rawUsers {
|
||||||
userId := domain.UserIdentifier(internal.MapDefaultString(rawUser, fields.UserIdentifier, ""))
|
userId := ldapUserIdentifier(rawUser, fields.UserIdentifier)
|
||||||
if user.Identifier == userId {
|
if user.Identifier == userId {
|
||||||
existsInLDAP = true
|
existsInLDAP = true
|
||||||
break
|
break
|
||||||
@@ -258,19 +263,19 @@ func (m Manager) updateInterfaceLdapFilters(
|
|||||||
|
|
||||||
// Combined filter: user must match the provider's base SyncFilter AND the interface's LdapGroupFilter
|
// Combined filter: user must match the provider's base SyncFilter AND the interface's LdapGroupFilter
|
||||||
combinedFilter := fmt.Sprintf("(&(%s)(%s))", provider.SyncFilter, groupFilter)
|
combinedFilter := fmt.Sprintf("(&(%s)(%s))", provider.SyncFilter, groupFilter)
|
||||||
|
|
||||||
rawUsers, err := internal.LdapFindAllUsers(conn, provider.BaseDN, combinedFilter, &provider.FieldMap)
|
rawUsers, err := internal.LdapFindAllUsers(conn, provider.BaseDN, combinedFilter, &provider.FieldMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to find users for interface filter",
|
slog.Error("failed to find users for interface filter",
|
||||||
"interface", ifaceId,
|
"interface", ifaceId,
|
||||||
"provider", provider.ProviderName,
|
"provider", provider.ProviderName,
|
||||||
"error", err)
|
"error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
matchedUserIds := make([]domain.UserIdentifier, 0, len(rawUsers))
|
matchedUserIds := make([]domain.UserIdentifier, 0, len(rawUsers))
|
||||||
for _, rawUser := range rawUsers {
|
for _, rawUser := range rawUsers {
|
||||||
userId := domain.UserIdentifier(internal.MapDefaultString(rawUser, provider.FieldMap.UserIdentifier, ""))
|
userId := ldapUserIdentifier(rawUser, provider.FieldMap.UserIdentifier)
|
||||||
if userId != "" {
|
if userId != "" {
|
||||||
matchedUserIds = append(matchedUserIds, userId)
|
matchedUserIds = append(matchedUserIds, userId)
|
||||||
}
|
}
|
||||||
@@ -285,17 +290,26 @@ func (m Manager) updateInterfaceLdapFilters(
|
|||||||
return i, nil
|
return i, nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to save interface ldap allowed users",
|
slog.Error("failed to save interface ldap allowed users",
|
||||||
"interface", ifaceId,
|
"interface", ifaceId,
|
||||||
"provider", provider.ProviderName,
|
"provider", provider.ProviderName,
|
||||||
"error", err)
|
"error", err)
|
||||||
} else {
|
} else {
|
||||||
slog.Debug("updated interface ldap allowed users",
|
slog.Debug("updated interface ldap allowed users",
|
||||||
"interface", ifaceId,
|
"interface", ifaceId,
|
||||||
"provider", provider.ProviderName,
|
"provider", provider.ProviderName,
|
||||||
"matched_count", len(matchedUserIds))
|
"matched_count", len(matchedUserIds))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ldapUserIdentifier(rawUser map[string]any, field string) domain.UserIdentifier {
|
||||||
|
identifier := internal.MapDefaultString(rawUser, field, "")
|
||||||
|
identifier = domain.SanitizeIdentifier(identifier, 256)
|
||||||
|
if identifier == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return domain.UserIdentifier(identifier)
|
||||||
|
}
|
||||||
|
|||||||
@@ -365,19 +365,7 @@ func (m Manager) validateCreation(ctx context.Context, new *domain.User) error {
|
|||||||
return fmt.Errorf("invalid user identifier: %w", domain.ErrInvalidData)
|
return fmt.Errorf("invalid user identifier: %w", domain.ErrInvalidData)
|
||||||
}
|
}
|
||||||
|
|
||||||
if new.Identifier == "all" { // the 'all' user identifier collides with the rest api routes
|
if domain.IsReservedUserIdentifier(new.Identifier) {
|
||||||
return fmt.Errorf("reserved user identifier: %w", domain.ErrInvalidData)
|
|
||||||
}
|
|
||||||
|
|
||||||
if new.Identifier == "new" { // the 'new' user identifier collides with the rest api routes
|
|
||||||
return fmt.Errorf("reserved user identifier: %w", domain.ErrInvalidData)
|
|
||||||
}
|
|
||||||
|
|
||||||
if new.Identifier == "id" { // the 'id' user identifier collides with the rest api routes
|
|
||||||
return fmt.Errorf("reserved user identifier: %w", domain.ErrInvalidData)
|
|
||||||
}
|
|
||||||
|
|
||||||
if new.Identifier == domain.CtxSystemAdminId || new.Identifier == domain.CtxUnknownUserId {
|
|
||||||
return fmt.Errorf("reserved user identifier: %w", domain.ErrInvalidData)
|
return fmt.Errorf("reserved user identifier: %w", domain.ErrInvalidData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,10 @@
|
|||||||
package domain
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
type LoginProvider string
|
type LoginProvider string
|
||||||
|
|
||||||
type LoginProviderInfo struct {
|
type LoginProviderInfo struct {
|
||||||
@@ -20,3 +25,52 @@ type AuthenticatorUserInfo struct {
|
|||||||
IsAdmin bool
|
IsAdmin bool
|
||||||
AdminInfoAvailable bool // true if the IsAdmin flag is valid
|
AdminInfoAvailable bool // true if the IsAdmin flag is valid
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sanitize sanitizes all external identity provider fields in place.
|
||||||
|
// Returns ErrInvalidData if the identifier becomes empty after sanitization.
|
||||||
|
func (u *AuthenticatorUserInfo) Sanitize(providerType, providerName string) error {
|
||||||
|
identifier := string(u.Identifier)
|
||||||
|
LogSanitizeChange(providerType, providerName, "identifier", identifier,
|
||||||
|
func() string { return SanitizeIdentifier(identifier, 256) }, &identifier)
|
||||||
|
u.Identifier = UserIdentifier(identifier)
|
||||||
|
|
||||||
|
email := u.Email
|
||||||
|
LogSanitizeChange(providerType, providerName, "email", email,
|
||||||
|
func() string { return SanitizeEmail(email, 254) }, &u.Email)
|
||||||
|
LogSanitizeChange(providerType, providerName, "firstname", u.Firstname,
|
||||||
|
func() string { return SanitizeString(u.Firstname, 128) }, &u.Firstname)
|
||||||
|
LogSanitizeChange(providerType, providerName, "lastname", u.Lastname,
|
||||||
|
func() string { return SanitizeString(u.Lastname, 128) }, &u.Lastname)
|
||||||
|
LogSanitizeChange(providerType, providerName, "phone", u.Phone,
|
||||||
|
func() string { return SanitizePhone(u.Phone, 50) }, &u.Phone)
|
||||||
|
LogSanitizeChange(providerType, providerName, "department", u.Department,
|
||||||
|
func() string { return SanitizeString(u.Department, 128) }, &u.Department)
|
||||||
|
|
||||||
|
u.UserGroups = sanitizeGroups(providerType, providerName, u.UserGroups)
|
||||||
|
|
||||||
|
if u.Identifier == "" {
|
||||||
|
return fmt.Errorf("empty user identifier: %w", ErrInvalidData)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeGroups sanitizes group names, dropping any that were modified by sanitization.
|
||||||
|
func sanitizeGroups(providerType, providerName string, rawGroups []string) []string {
|
||||||
|
if len(rawGroups) == 0 {
|
||||||
|
return rawGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
groups := make([]string, 0, len(rawGroups))
|
||||||
|
for _, rawGroup := range rawGroups {
|
||||||
|
sanitized := rawGroup
|
||||||
|
LogSanitizeChange(providerType, providerName, "user_group", rawGroup,
|
||||||
|
func() string { return SanitizeString(rawGroup, 256) }, &sanitized)
|
||||||
|
if sanitized == "" || sanitized != strings.TrimSpace(rawGroup) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
groups = append(groups, sanitized)
|
||||||
|
}
|
||||||
|
|
||||||
|
return groups
|
||||||
|
}
|
||||||
|
|||||||
103
internal/domain/auth_sanitize_test.go
Normal file
103
internal/domain/auth_sanitize_test.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/h44z/wg-portal/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthenticatorUserInfo_Sanitize_NullByteInFirstname(t *testing.T) {
|
||||||
|
info := &AuthenticatorUserInfo{
|
||||||
|
Identifier: "alice",
|
||||||
|
Email: "alice@example.com",
|
||||||
|
Firstname: "Ali\x00ce",
|
||||||
|
Lastname: "Smith",
|
||||||
|
}
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
err := info.Sanitize("ldap", "test-provider")
|
||||||
|
records := restore()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "Alice", info.Firstname)
|
||||||
|
|
||||||
|
warnCount := testutil.CountWarnEntries(records)
|
||||||
|
assert.Equal(t, 1, warnCount)
|
||||||
|
|
||||||
|
_, found := testutil.FindWarnWithField(records, "firstname")
|
||||||
|
assert.True(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticatorUserInfo_Sanitize_AllFieldsClean(t *testing.T) {
|
||||||
|
info := &AuthenticatorUserInfo{
|
||||||
|
Identifier: "alice",
|
||||||
|
Email: "alice@example.com",
|
||||||
|
Firstname: "Alice",
|
||||||
|
Lastname: "Smith",
|
||||||
|
Phone: "+1 555-1234",
|
||||||
|
Department: "Engineering",
|
||||||
|
}
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
err := info.Sanitize("ldap", "test-provider")
|
||||||
|
records := restore()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, UserIdentifier("alice"), info.Identifier)
|
||||||
|
assert.Equal(t, 0, testutil.CountWarnEntries(records))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticatorUserInfo_Sanitize_IdentifierAll(t *testing.T) {
|
||||||
|
info := &AuthenticatorUserInfo{
|
||||||
|
Identifier: "all",
|
||||||
|
Email: "all@example.com",
|
||||||
|
Firstname: "Alice",
|
||||||
|
Lastname: "Smith",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := info.Sanitize("ldap", "test-provider")
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, ErrInvalidData))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticatorUserInfo_Sanitize_CRLFInEmail(t *testing.T) {
|
||||||
|
info := &AuthenticatorUserInfo{
|
||||||
|
Identifier: "user123",
|
||||||
|
Email: "user\r\n@example.com",
|
||||||
|
Firstname: "Alice",
|
||||||
|
Lastname: "Smith",
|
||||||
|
}
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
err := info.Sanitize("oauth", "test-provider")
|
||||||
|
records := restore()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "", info.Email)
|
||||||
|
|
||||||
|
_, found := testutil.FindWarnWithField(records, "email")
|
||||||
|
assert.True(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticatorUserInfo_Sanitize_GroupsWithZeroWidthChars(t *testing.T) {
|
||||||
|
info := &AuthenticatorUserInfo{
|
||||||
|
Identifier: "user123",
|
||||||
|
Email: "user@example.com",
|
||||||
|
UserGroups: []string{"wgportal-\u200badmins"},
|
||||||
|
}
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
err := info.Sanitize("oidc", "test-provider")
|
||||||
|
records := restore()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, info.UserGroups)
|
||||||
|
|
||||||
|
_, found := testutil.FindWarnWithField(records, "user_group")
|
||||||
|
assert.True(t, found)
|
||||||
|
}
|
||||||
151
internal/domain/sanitize.go
Normal file
151
internal/domain/sanitize.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"net/mail"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"golang.org/x/text/unicode/norm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogSanitizeChange applies sanitizeFn to raw, logs when the value changes, and writes
|
||||||
|
// the sanitized value to dest. Raw and sanitized values are intentionally omitted.
|
||||||
|
func LogSanitizeChange(
|
||||||
|
providerType string,
|
||||||
|
providerName string,
|
||||||
|
field string,
|
||||||
|
raw string,
|
||||||
|
sanitizeFn func() string,
|
||||||
|
dest *string,
|
||||||
|
) {
|
||||||
|
sanitized := sanitizeFn()
|
||||||
|
if sanitized != raw {
|
||||||
|
message := "sanitization modified field value from external provider"
|
||||||
|
if sanitized == "" {
|
||||||
|
message = "sanitization cleared field value from external provider"
|
||||||
|
}
|
||||||
|
slog.Warn(message,
|
||||||
|
"provider_type", SanitizeString(providerType, 64),
|
||||||
|
"provider", SanitizeString(providerName, 128),
|
||||||
|
"field", SanitizeString(field, 64),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
*dest = sanitized
|
||||||
|
}
|
||||||
|
|
||||||
|
var reservedUserIdentifiers = map[string]struct{}{
|
||||||
|
"all": {},
|
||||||
|
"new": {},
|
||||||
|
"id": {},
|
||||||
|
CtxSystemAdminId: {},
|
||||||
|
CtxUnknownUserId: {},
|
||||||
|
CtxSystemLdapSyncer: {},
|
||||||
|
CtxSystemWgImporter: {},
|
||||||
|
CtxSystemV1Migrator: {},
|
||||||
|
CtxSystemDBMigrator: {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeString normalizes to NFC, trims leading and trailing whitespace, strips Unicode
|
||||||
|
// control and format characters, drops invalid UTF-8 bytes, and truncates the result to
|
||||||
|
// maxLen runes. If maxLen <= 0, returns "".
|
||||||
|
func SanitizeString(s string, maxLen int) string {
|
||||||
|
if maxLen <= 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
s = norm.NFC.String(strings.TrimSpace(s))
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.Grow(len(s))
|
||||||
|
for len(s) > 0 {
|
||||||
|
r, size := utf8.DecodeRuneInString(s)
|
||||||
|
s = s[size:]
|
||||||
|
if r == utf8.RuneError && size == 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !unicode.IsControl(r) && !unicode.Is(unicode.Cf, r) {
|
||||||
|
b.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s = b.String()
|
||||||
|
|
||||||
|
if utf8.RuneCountInString(s) > maxLen {
|
||||||
|
runes := []rune(s)
|
||||||
|
s = string(runes[:maxLen])
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeEmail applies SanitizeString first, then returns "" if the original s
|
||||||
|
// contains CR/LF or if the sanitized result is not a plain email address.
|
||||||
|
func SanitizeEmail(s string, maxLen int) string {
|
||||||
|
if strings.ContainsRune(s, '\r') || strings.ContainsRune(s, '\n') {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
sanitized := SanitizeString(s, maxLen)
|
||||||
|
|
||||||
|
if sanitized == "" || strings.Count(sanitized, "@") != 1 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
addr, err := mail.ParseAddress(sanitized)
|
||||||
|
if err != nil || addr.Name != "" || addr.Address != sanitized {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizePhone applies SanitizeString first, then removes all characters not in the
|
||||||
|
// set [0-9+\-() .]. Returns "" if the result after filtering is empty.
|
||||||
|
func SanitizePhone(s string, maxLen int) string {
|
||||||
|
sanitized := SanitizeString(s, maxLen)
|
||||||
|
|
||||||
|
// Remove all characters not in [0-9+\-() .]
|
||||||
|
var b strings.Builder
|
||||||
|
b.Grow(len(sanitized))
|
||||||
|
for _, r := range sanitized {
|
||||||
|
if isAllowedPhoneRune(r) {
|
||||||
|
b.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result := strings.TrimSpace(b.String())
|
||||||
|
|
||||||
|
if result == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAllowedPhoneRune reports whether r is in the allowed phone character set [0-9+\-() .].
|
||||||
|
func isAllowedPhoneRune(r rune) bool {
|
||||||
|
switch {
|
||||||
|
case r >= '0' && r <= '9':
|
||||||
|
return true
|
||||||
|
case r == '+', r == '-', r == '(', r == ')', r == ' ', r == '.':
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeIdentifier applies SanitizeString first, then returns "" if the result equals
|
||||||
|
// a reserved user identifier (case-sensitive, exact match).
|
||||||
|
func SanitizeIdentifier(s string, maxLen int) string {
|
||||||
|
sanitized := SanitizeString(s, maxLen)
|
||||||
|
|
||||||
|
if IsReservedUserIdentifier(UserIdentifier(sanitized)) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsReservedUserIdentifier(identifier UserIdentifier) bool {
|
||||||
|
_, reserved := reservedUserIdentifiers[string(identifier)]
|
||||||
|
return reserved
|
||||||
|
}
|
||||||
503
internal/domain/sanitize_test.go
Normal file
503
internal/domain/sanitize_test.go
Normal file
@@ -0,0 +1,503 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/mail"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"unicode"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"pgregory.net/rapid"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSanitizeString(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
maxLen int
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "null byte removed",
|
||||||
|
input: "\x00",
|
||||||
|
maxLen: 64,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CR removed",
|
||||||
|
input: "\r",
|
||||||
|
maxLen: 64,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LF removed",
|
||||||
|
input: "\n",
|
||||||
|
maxLen: 64,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tab removed",
|
||||||
|
input: "\t",
|
||||||
|
maxLen: 64,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "leading and trailing whitespace trimmed",
|
||||||
|
input: " hello ",
|
||||||
|
maxLen: 64,
|
||||||
|
want: "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi-byte UTF-8 truncation at rune boundary",
|
||||||
|
input: "héllo",
|
||||||
|
maxLen: 3,
|
||||||
|
want: "hél", // 3 runes, not 3 bytes
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty input",
|
||||||
|
input: "",
|
||||||
|
maxLen: 64,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "maxLen zero returns empty",
|
||||||
|
input: "hello",
|
||||||
|
maxLen: 0,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string longer than maxLen truncated",
|
||||||
|
input: "abcdefgh",
|
||||||
|
maxLen: 4,
|
||||||
|
want: "abcd",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed control chars and normal chars",
|
||||||
|
input: "hel\x00lo\r\nworld",
|
||||||
|
maxLen: 64,
|
||||||
|
want: "helloworld",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only whitespace returns empty",
|
||||||
|
input: " ",
|
||||||
|
maxLen: 64,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string exactly at maxLen not truncated",
|
||||||
|
input: "abc",
|
||||||
|
maxLen: 3,
|
||||||
|
want: "abc",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative maxLen returns empty",
|
||||||
|
input: "hello",
|
||||||
|
maxLen: -1,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DEL control removed",
|
||||||
|
input: "hel\x7flo",
|
||||||
|
maxLen: 64,
|
||||||
|
want: "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero-width format character removed",
|
||||||
|
input: "ali\u200bce",
|
||||||
|
maxLen: 64,
|
||||||
|
want: "alice",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid UTF-8 byte removed",
|
||||||
|
input: "a\xffb",
|
||||||
|
maxLen: 64,
|
||||||
|
want: "ab",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unicode normalized to NFC",
|
||||||
|
input: "e\u0301",
|
||||||
|
maxLen: 64,
|
||||||
|
want: "é",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := SanitizeString(tc.input, tc.maxLen)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("SanitizeString(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeEmail(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
maxLen int
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid email passes through unchanged",
|
||||||
|
input: "user@example.com",
|
||||||
|
maxLen: 254,
|
||||||
|
want: "user@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CR in email returns empty",
|
||||||
|
input: "user\r@example.com",
|
||||||
|
maxLen: 254,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LF in email returns empty",
|
||||||
|
input: "user\n@example.com",
|
||||||
|
maxLen: 254,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing @ returns empty",
|
||||||
|
input: "userexample.com",
|
||||||
|
maxLen: 254,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "whitespace-only returns empty",
|
||||||
|
input: " ",
|
||||||
|
maxLen: 254,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "email with leading/trailing whitespace trimmed and returned",
|
||||||
|
input: " user@example.com ",
|
||||||
|
maxLen: 254,
|
||||||
|
want: "user@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty input returns empty",
|
||||||
|
input: "",
|
||||||
|
maxLen: 254,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "display-name address rejected",
|
||||||
|
input: "User <user@example.com>",
|
||||||
|
maxLen: 254,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple at signs rejected",
|
||||||
|
input: "user@@example.com",
|
||||||
|
maxLen: 254,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid address rejected",
|
||||||
|
input: "user@",
|
||||||
|
maxLen: 254,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := SanitizeEmail(tc.input, tc.maxLen)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("SanitizeEmail(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizePhone(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
maxLen int
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid phone passes through unchanged",
|
||||||
|
input: "+1 (555) 123-4567",
|
||||||
|
maxLen: 50,
|
||||||
|
want: "+1 (555) 123-4567",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-allowed chars stripped",
|
||||||
|
input: "abc+1def",
|
||||||
|
maxLen: 50,
|
||||||
|
want: "+1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all-stripped input returns empty",
|
||||||
|
input: "abc",
|
||||||
|
maxLen: 50,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed allowed and non-allowed chars",
|
||||||
|
input: "+49 (0) 123-456.789",
|
||||||
|
maxLen: 50,
|
||||||
|
want: "+49 (0) 123-456.789",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty input returns empty",
|
||||||
|
input: "",
|
||||||
|
maxLen: 50,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only digits passes through",
|
||||||
|
input: "1234567890",
|
||||||
|
maxLen: 50,
|
||||||
|
want: "1234567890",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := SanitizePhone(tc.input, tc.maxLen)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("SanitizePhone(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeIdentifier(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
maxLen int
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "reserved value all returns empty",
|
||||||
|
input: "all",
|
||||||
|
maxLen: 256,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all with surrounding whitespace returns empty",
|
||||||
|
input: " all ",
|
||||||
|
maxLen: 256,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reserved value new returns empty",
|
||||||
|
input: "new",
|
||||||
|
maxLen: 256,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reserved value id returns empty",
|
||||||
|
input: "id",
|
||||||
|
maxLen: 256,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "system admin identifier returns empty",
|
||||||
|
input: string(CtxSystemAdminId),
|
||||||
|
maxLen: 256,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown user identifier returns empty",
|
||||||
|
input: string(CtxUnknownUserId),
|
||||||
|
maxLen: 256,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LDAP syncer identifier returns empty",
|
||||||
|
input: string(CtxSystemLdapSyncer),
|
||||||
|
maxLen: 256,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ALL uppercase passes through (case-sensitive)",
|
||||||
|
input: "ALL",
|
||||||
|
maxLen: 256,
|
||||||
|
want: "ALL",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid email identifier passes through",
|
||||||
|
input: "alice@example.com",
|
||||||
|
maxLen: 256,
|
||||||
|
want: "alice@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normal identifier passes through",
|
||||||
|
input: "alice",
|
||||||
|
maxLen: 256,
|
||||||
|
want: "alice",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty input returns empty",
|
||||||
|
input: "",
|
||||||
|
maxLen: 256,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := SanitizeIdentifier(tc.input, tc.maxLen)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("SanitizeIdentifier(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeXSSPayload(t *testing.T) {
|
||||||
|
// XSS payload: null byte removed, angle brackets preserved
|
||||||
|
input := "<script>\x00alert(1)</script>"
|
||||||
|
want := "<script>alert(1)</script>"
|
||||||
|
got := SanitizeString(input, 256)
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("SanitizeString(%q, 256) = %q; want %q", input, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Property 1: SanitizeString output invariants
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Feature: external-identity-sanitization, Property 1: SanitizeString output is free of control characters and bounded in length
|
||||||
|
func TestPropertySanitizeStringOutputInvariants(t *testing.T) {
|
||||||
|
rapid.Check(t, func(t *rapid.T) {
|
||||||
|
s := rapid.String().Draw(t, "s")
|
||||||
|
maxLen := rapid.IntRange(0, 512).Draw(t, "maxLen")
|
||||||
|
result := SanitizeString(s, maxLen)
|
||||||
|
|
||||||
|
// No control or format runes in result
|
||||||
|
for _, r := range result {
|
||||||
|
if unicode.IsControl(r) || unicode.Is(unicode.Cf, r) {
|
||||||
|
t.Fatalf("result %q contains unsafe character %U", result, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !utf8.ValidString(result) {
|
||||||
|
t.Fatalf("result %q is not valid UTF-8", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// No leading or trailing whitespace
|
||||||
|
if result != strings.TrimSpace(result) {
|
||||||
|
t.Fatalf("result %q has leading or trailing whitespace", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rune count <= maxLen
|
||||||
|
runeCount := utf8.RuneCountInString(result)
|
||||||
|
if runeCount > maxLen {
|
||||||
|
t.Fatalf("result %q has %d runes, exceeds maxLen %d", result, runeCount, maxLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Property 2: SanitizeString is idempotent
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Feature: external-identity-sanitization, Property 2: SanitizeString is idempotent
|
||||||
|
func TestPropertySanitizeStringIdempotent(t *testing.T) {
|
||||||
|
rapid.Check(t, func(t *rapid.T) {
|
||||||
|
s := rapid.String().Draw(t, "s")
|
||||||
|
maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen")
|
||||||
|
|
||||||
|
once := SanitizeString(s, maxLen)
|
||||||
|
twice := SanitizeString(once, maxLen)
|
||||||
|
|
||||||
|
if once != twice {
|
||||||
|
t.Fatalf("SanitizeString is not idempotent: once=%q, twice=%q (input=%q, maxLen=%d)",
|
||||||
|
once, twice, s, maxLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Property 3: SanitizeEmail rejection rules
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Feature: external-identity-sanitization, Property 3: SanitizeEmail rejects strings without "@" or containing CR/LF
|
||||||
|
func TestPropertySanitizeEmailRejectionRules(t *testing.T) {
|
||||||
|
rapid.Check(t, func(t *rapid.T) {
|
||||||
|
s := rapid.String().Draw(t, "s")
|
||||||
|
maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen")
|
||||||
|
result := SanitizeEmail(s, maxLen)
|
||||||
|
|
||||||
|
sanitized := SanitizeString(s, maxLen)
|
||||||
|
addr, parseErr := mail.ParseAddress(sanitized)
|
||||||
|
reject := strings.ContainsAny(s, "\r\n") ||
|
||||||
|
sanitized == "" ||
|
||||||
|
strings.Count(sanitized, "@") != 1 ||
|
||||||
|
parseErr != nil ||
|
||||||
|
addr.Name != "" ||
|
||||||
|
addr.Address != sanitized
|
||||||
|
if reject {
|
||||||
|
if result != "" {
|
||||||
|
t.Fatalf("SanitizeEmail(%q, %d) = %q; expected empty string (contains CR/LF or no @)",
|
||||||
|
s, maxLen, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Property 4: SanitizePhone allowed character set
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// isAllowedPhoneCharTest mirrors the internal isAllowedPhoneRune logic for test assertions.
|
||||||
|
func isAllowedPhoneCharTest(r rune) bool {
|
||||||
|
switch {
|
||||||
|
case r >= '0' && r <= '9':
|
||||||
|
return true
|
||||||
|
case r == '+', r == '-', r == '(', r == ')', r == ' ', r == '.':
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Feature: external-identity-sanitization, Property 4: SanitizePhone output contains only allowed characters
|
||||||
|
func TestPropertySanitizePhoneAllowedChars(t *testing.T) {
|
||||||
|
rapid.Check(t, func(t *rapid.T) {
|
||||||
|
s := rapid.String().Draw(t, "s")
|
||||||
|
maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen")
|
||||||
|
result := SanitizePhone(s, maxLen)
|
||||||
|
|
||||||
|
for _, r := range result {
|
||||||
|
if !isAllowedPhoneCharTest(r) {
|
||||||
|
t.Fatalf("SanitizePhone(%q, %d) = %q; contains disallowed rune %U (%c)",
|
||||||
|
s, maxLen, result, r, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Property 5: SanitizeIdentifier rejects reserved identifiers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Feature: external-identity-sanitization, Property 5: SanitizeIdentifier rejects reserved values
|
||||||
|
func TestPropertySanitizeIdentifierRejectsReservedValues(t *testing.T) {
|
||||||
|
rapid.Check(t, func(t *rapid.T) {
|
||||||
|
s := rapid.String().Draw(t, "s")
|
||||||
|
maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen")
|
||||||
|
result := SanitizeIdentifier(s, maxLen)
|
||||||
|
sanitized := SanitizeString(s, maxLen)
|
||||||
|
|
||||||
|
_, reserved := reservedUserIdentifiers[sanitized]
|
||||||
|
if reserved {
|
||||||
|
if result != "" {
|
||||||
|
t.Fatalf("SanitizeIdentifier(%q, %d) = %q; expected empty string when sanitized is reserved",
|
||||||
|
s, maxLen, result)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if result != sanitized {
|
||||||
|
t.Fatalf("SanitizeIdentifier(%q, %d) = %q; expected %q (== SanitizeString result)",
|
||||||
|
s, maxLen, result, sanitized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -282,6 +282,32 @@ func (u *User) CreateDefaultPeers() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SanitizeExternalData sanitizes user profile fields received from an external identity provider.
|
||||||
|
// Returns ErrInvalidData if the identifier becomes empty after sanitization.
|
||||||
|
func (u *User) SanitizeExternalData(providerType, providerName string) error {
|
||||||
|
identifier := string(u.Identifier)
|
||||||
|
LogSanitizeChange(providerType, providerName, "identifier", identifier,
|
||||||
|
func() string { return SanitizeIdentifier(identifier, 256) }, &identifier)
|
||||||
|
u.Identifier = UserIdentifier(identifier)
|
||||||
|
|
||||||
|
LogSanitizeChange(providerType, providerName, "email", u.Email,
|
||||||
|
func() string { return SanitizeEmail(u.Email, 254) }, &u.Email)
|
||||||
|
LogSanitizeChange(providerType, providerName, "firstname", u.Firstname,
|
||||||
|
func() string { return SanitizeString(u.Firstname, 128) }, &u.Firstname)
|
||||||
|
LogSanitizeChange(providerType, providerName, "lastname", u.Lastname,
|
||||||
|
func() string { return SanitizeString(u.Lastname, 128) }, &u.Lastname)
|
||||||
|
LogSanitizeChange(providerType, providerName, "phone", u.Phone,
|
||||||
|
func() string { return SanitizePhone(u.Phone, 50) }, &u.Phone)
|
||||||
|
LogSanitizeChange(providerType, providerName, "department", u.Department,
|
||||||
|
func() string { return SanitizeString(u.Department, 128) }, &u.Department)
|
||||||
|
|
||||||
|
if u.Identifier == "" {
|
||||||
|
return fmt.Errorf("empty user identifier: %w", ErrInvalidData)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// region webauthn
|
// region webauthn
|
||||||
|
|
||||||
func (u *User) WebAuthnID() []byte {
|
func (u *User) WebAuthnID() []byte {
|
||||||
|
|||||||
64
internal/domain/user_sanitize_test.go
Normal file
64
internal/domain/user_sanitize_test.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/h44z/wg-portal/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUser_SanitizeExternalData_NullByteInFirstname(t *testing.T) {
|
||||||
|
u := &User{
|
||||||
|
Identifier: "alice",
|
||||||
|
Email: "alice@example.com",
|
||||||
|
Firstname: "Ali\x00ce",
|
||||||
|
Lastname: "Smith",
|
||||||
|
}
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
err := u.SanitizeExternalData("ldap", "test-provider")
|
||||||
|
records := restore()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "Alice", u.Firstname)
|
||||||
|
assert.Equal(t, 1, testutil.CountWarnEntries(records))
|
||||||
|
|
||||||
|
_, found := testutil.FindWarnWithField(records, "firstname")
|
||||||
|
assert.True(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_SanitizeExternalData_IdentifierAll(t *testing.T) {
|
||||||
|
u := &User{
|
||||||
|
Identifier: "all",
|
||||||
|
Email: "all@example.com",
|
||||||
|
Firstname: "Alice",
|
||||||
|
Lastname: "Smith",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := u.SanitizeExternalData("ldap", "test-provider")
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, ErrInvalidData))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_SanitizeExternalData_AllFieldsClean(t *testing.T) {
|
||||||
|
u := &User{
|
||||||
|
Identifier: "alice",
|
||||||
|
Email: "alice@example.com",
|
||||||
|
Firstname: "Alice",
|
||||||
|
Lastname: "Smith",
|
||||||
|
Phone: "+1 555-1234",
|
||||||
|
Department: "Engineering",
|
||||||
|
}
|
||||||
|
|
||||||
|
restore := testutil.CaptureWarnLogs(t)
|
||||||
|
err := u.SanitizeExternalData("ldap", "test-provider")
|
||||||
|
records := restore()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, UserIdentifier("alice"), u.Identifier)
|
||||||
|
assert.Equal(t, 0, testutil.CountWarnEntries(records))
|
||||||
|
}
|
||||||
32
internal/sanitize/log.go
Normal file
32
internal/sanitize/log.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package sanitize
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogChange applies sanitizeFn to raw, logs when the value changes, and writes
|
||||||
|
// the sanitized value to dest. Raw and sanitized values are intentionally omitted.
|
||||||
|
func LogChange(
|
||||||
|
providerType string,
|
||||||
|
providerName string,
|
||||||
|
field string,
|
||||||
|
raw string,
|
||||||
|
sanitizeFn func() string,
|
||||||
|
dest *string,
|
||||||
|
) {
|
||||||
|
sanitized := sanitizeFn()
|
||||||
|
if sanitized != raw {
|
||||||
|
message := "sanitization modified field value from external provider"
|
||||||
|
if sanitized == "" {
|
||||||
|
message = "sanitization cleared field value from external provider"
|
||||||
|
}
|
||||||
|
slog.Warn(message,
|
||||||
|
"provider_type", domain.SanitizeString(providerType, 64),
|
||||||
|
"provider", domain.SanitizeString(providerName, 128),
|
||||||
|
"field", domain.SanitizeString(field, 64),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
*dest = sanitized
|
||||||
|
}
|
||||||
50
internal/testutil/testutil.go
Normal file
50
internal/testutil/testutil.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package testutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"log/slog"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CaptureWarnLogs(t *testing.T) (restore func() []map[string]any) {
|
||||||
|
t.Helper()
|
||||||
|
original := slog.Default()
|
||||||
|
var buf bytes.Buffer
|
||||||
|
handler := slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelWarn})
|
||||||
|
slog.SetDefault(slog.New(handler))
|
||||||
|
|
||||||
|
return func() []map[string]any {
|
||||||
|
slog.SetDefault(original)
|
||||||
|
var records []map[string]any
|
||||||
|
decoder := json.NewDecoder(&buf)
|
||||||
|
for decoder.More() {
|
||||||
|
var rec map[string]any
|
||||||
|
if err := decoder.Decode(&rec); err == nil {
|
||||||
|
records = append(records, rec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CountWarnEntries(records []map[string]any) int {
|
||||||
|
count := 0
|
||||||
|
for _, r := range records {
|
||||||
|
if lvl, ok := r["level"].(string); ok && lvl == "WARN" {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindWarnWithField(records []map[string]any, fieldName string) (map[string]any, bool) {
|
||||||
|
for _, r := range records {
|
||||||
|
if lvl, ok := r["level"].(string); ok && lvl == "WARN" {
|
||||||
|
if f, ok := r["field"].(string); ok && f == fieldName {
|
||||||
|
return r, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user